Merge pull request #4047 from stanmoore1/kk_fix_exchange_bug

Fix bug in some Kokkos fixes' unpack exchange on device
This commit is contained in:
Axel Kohlmeyer
2024-01-18 14:58:09 -05:00
committed by GitHub
12 changed files with 41 additions and 7 deletions

View File

@ -931,6 +931,7 @@ void CommKokkos::exchange_device()
if (nextrarecv) {
kkbase->unpack_exchange_kokkos(
k_buf_recv,k_indices,nrecv/data_size,
nrecv1/data_size,nextrarecv1,
ExecutionSpaceFromDevice<DeviceType>::space);
DeviceType().fence();
}

View File

@ -453,8 +453,12 @@ KOKKOS_INLINE_FUNCTION
void FixNeighHistoryKokkos<DeviceType>::operator()(TagFixNeighHistoryUnpackExchange, const int &i) const
{
int index = d_indices(i);
if (index > -1) {
int m = (int) d_ubuf(d_buf(i)).i;
if (i >= nrecv1)
m = nextrarecv1 + (int) d_ubuf(d_buf(nextrarecv1 + i - nrecv1)).i;
int n = (int) d_ubuf(d_buf(m++)).i;
d_npartner(index) = n;
for (int p = 0; p < n; p++) {
@ -471,6 +475,7 @@ void FixNeighHistoryKokkos<DeviceType>::operator()(TagFixNeighHistoryUnpackExcha
template<class DeviceType>
void FixNeighHistoryKokkos<DeviceType>::unpack_exchange_kokkos(
DAT::tdual_xfloat_2d &k_buf, DAT::tdual_int_1d &k_indices, int nrecv,
int nrecv1, int nextrarecv1,
ExecutionSpace /*space*/)
{
d_buf = typename AT::t_xfloat_1d_um(
@ -478,6 +483,9 @@ void FixNeighHistoryKokkos<DeviceType>::unpack_exchange_kokkos(
k_buf.extent(0)*k_buf.extent(1));
d_indices = k_indices.view<DeviceType>();
this->nrecv1 = nrecv1;
this->nextrarecv1 = nextrarecv1;
d_npartner = k_npartner.template view<DeviceType>();
d_partner = k_partner.template view<DeviceType>();
d_valuepartner = k_valuepartner.template view<DeviceType>();

View File

@ -72,12 +72,14 @@ class FixNeighHistoryKokkos : public FixNeighHistory, public KokkosBase {
// Unpack this fix's received exchange data from k_buf on the device.
// The implementation stores nrecv1 and nextrarecv1 in members of the same
// name; the unpack functor uses them to compute the buffer offset for items
// whose index is >= nrecv1.
// Parameter names match the definition and the KokkosBase declaration.
void unpack_exchange_kokkos(DAT::tdual_xfloat_2d &k_buf,
                            DAT::tdual_int_1d &indices,int nrecv,
                            int nrecv1,int nextrarecv1,
                            ExecutionSpace space) override;
typename DAT::tdual_int_2d k_firstflag;
typename DAT::tdual_float_2d k_firstvalue;
private:
int nrecv1,nextrarecv1;
int nlocal,nsend,beyond_contact;
typename AT::t_tagint_1d tag;

View File

@ -1416,6 +1416,7 @@ KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqUnpackExchange, const int &i) const
{
int index = d_indices(i);
if (index > -1) {
for (int m = 0; m < nprev; m++) d_s_hist(index,m) = d_buf(i*nprev*2 + m);
for (int m = 0; m < nprev; m++) d_t_hist(index,m) = d_buf(i*nprev*2 + nprev+m);
@ -1427,6 +1428,7 @@ void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqUnpackExchange, const int
template <class DeviceType>
void FixQEqReaxFFKokkos<DeviceType>::unpack_exchange_kokkos(
DAT::tdual_xfloat_2d &k_buf, DAT::tdual_int_1d &k_indices, int nrecv,
int /*nrecv1*/, int /*nextrarecv1*/,
ExecutionSpace /*space*/)
{
k_buf.sync<DeviceType>();

View File

@ -143,6 +143,7 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
// Unpack this fix's received exchange data from k_buf on the device.
// nrecv1 and nextrarecv1 are part of the shared KokkosBase signature; the
// definition shown for this class leaves both parameters unnamed (unused).
void unpack_exchange_kokkos(DAT::tdual_xfloat_2d &k_buf,
DAT::tdual_int_1d &indices,int nrecv,
int nrecv1,int nextrarecv1,
ExecutionSpace space) override;
struct params_qeq{

View File

@ -1581,8 +1581,8 @@ void FixShakeKokkos<DeviceType>::pack_exchange_item(const int &mysend, int &offs
else offset++;
} else {
d_buf[mysend] = nsend + offset;
int m = nsend + offset;
d_buf[mysend] = m;
d_buf[m++] = flag;
if (flag == 1) {
d_buf[m++] = d_shake_atom(i,0);
@ -1703,6 +1703,8 @@ void FixShakeKokkos<DeviceType>::operator()(TagFixShakeUnpackExchange, const int
if (index > -1) {
int m = d_buf[i];
if (i >= nrecv1)
m = nextrarecv1 + d_buf[nextrarecv1 + i - nrecv1];
int flag = d_shake_flag[index] = static_cast<int> (d_buf[m++]);
if (flag == 1) {
@ -1739,6 +1741,7 @@ void FixShakeKokkos<DeviceType>::operator()(TagFixShakeUnpackExchange, const int
template <class DeviceType>
void FixShakeKokkos<DeviceType>::unpack_exchange_kokkos(
DAT::tdual_xfloat_2d &k_buf, DAT::tdual_int_1d &k_indices, int nrecv,
int nrecv1, int nextrarecv1,
ExecutionSpace /*space*/)
{
k_buf.sync<DeviceType>();
@ -1749,6 +1752,9 @@ void FixShakeKokkos<DeviceType>::unpack_exchange_kokkos(
k_buf.extent(0)*k_buf.extent(1));
d_indices = k_indices.view<DeviceType>();
this->nrecv1 = nrecv1;
this->nextrarecv1 = nextrarecv1;
k_shake_flag.template sync<DeviceType>();
k_shake_atom.template sync<DeviceType>();
k_shake_type.template sync<DeviceType>();

View File

@ -110,9 +110,12 @@ class FixShakeKokkos : public FixShake, public KokkosBase {
// Unpack this fix's received exchange data from k_buf on the device.
// The implementation stores nrecv1 and nextrarecv1 in members of the same
// name; the unpack functor uses them to compute the buffer offset for items
// whose index is >= nrecv1.
// Parameter names match the definition and the KokkosBase declaration.
void unpack_exchange_kokkos(DAT::tdual_xfloat_2d &k_buf,
                            DAT::tdual_int_1d &indices,int nrecv,
                            int nrecv1,int nextrarecv1,
                            ExecutionSpace space) override;
protected:
int nrecv1,nextrarecv1;
typename AT::t_x_array d_x;
typename AT::t_v_array d_v;
typename AT::t_f_array d_f;
@ -257,4 +260,3 @@ struct FixShakeKokkosPackExchangeFunctor {
#endif
#endif

View File

@ -188,8 +188,8 @@ void FixSpringSelfKokkos<DeviceType>::pack_exchange_item(const int &mysend, int
{
const int i = d_exchange_sendlist(mysend);
d_buf[mysend] = nsend + offset;
int m = nsend + offset;
d_buf[mysend] = m;
d_buf[m++] = d_xoriginal(i,0);
d_buf[m++] = d_xoriginal(i,1);
d_buf[m++] = d_xoriginal(i,2);
@ -258,6 +258,8 @@ void FixSpringSelfKokkos<DeviceType>::operator()(TagFixSpringSelfUnpackExchange,
if (index > -1) {
int m = d_buf[i];
if (i >= nrecv1)
m = nextrarecv1 + d_buf[nextrarecv1 + i - nrecv1];
d_xoriginal(index,0) = static_cast<tagint> (d_buf[m++]);
d_xoriginal(index,1) = static_cast<tagint> (d_buf[m++]);
@ -270,6 +272,7 @@ void FixSpringSelfKokkos<DeviceType>::operator()(TagFixSpringSelfUnpackExchange,
template <class DeviceType>
void FixSpringSelfKokkos<DeviceType>::unpack_exchange_kokkos(
DAT::tdual_xfloat_2d &k_buf, DAT::tdual_int_1d &k_indices, int nrecv,
int nrecv1, int nextrarecv1,
ExecutionSpace /*space*/)
{
k_buf.sync<DeviceType>();
@ -280,6 +283,9 @@ void FixSpringSelfKokkos<DeviceType>::unpack_exchange_kokkos(
k_buf.extent(0)*k_buf.extent(1));
d_indices = k_indices.view<DeviceType>();
this->nrecv1 = nrecv1;
this->nextrarecv1 = nextrarecv1;
k_xoriginal.template sync<DeviceType>();
copymode = 1;

View File

@ -58,6 +58,7 @@ class FixSpringSelfKokkos : public FixSpringSelf, public KokkosBase {
// Unpack this fix's received exchange data from k_buf on the device.
// The implementation stores nrecv1 and nextrarecv1 in members of the same
// name; the unpack functor uses them to compute the buffer offset for items
// whose index is >= nrecv1.
// Parameter names match the definition and the KokkosBase declaration.
void unpack_exchange_kokkos(DAT::tdual_xfloat_2d &k_buf,
                            DAT::tdual_int_1d &indices,int nrecv,
                            int nrecv1,int nextrarecv1,
                            ExecutionSpace space) override;
@ -65,6 +66,8 @@ class FixSpringSelfKokkos : public FixSpringSelf, public KokkosBase {
int unpack_exchange(int, double *) override;
protected:
int nrecv1,nextrarecv1;
DAT::tdual_x_array k_xoriginal;
typename AT::t_x_array d_xoriginal;

View File

@ -419,6 +419,7 @@ void FixWallGranKokkos<DeviceType>::operator()(TagFixWallGranUnpackExchange, con
template<class DeviceType>
void FixWallGranKokkos<DeviceType>::unpack_exchange_kokkos(
DAT::tdual_xfloat_2d &k_buf, DAT::tdual_int_1d &k_indices, int nrecv,
int /*nrecv1*/, int /*nextrarecv1*/,
ExecutionSpace /*space*/)
{
d_buf = typename ArrayTypes<DeviceType>::t_xfloat_1d_um(
@ -430,7 +431,6 @@ void FixWallGranKokkos<DeviceType>::unpack_exchange_kokkos(
copymode = 1;
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixWallGranUnpackExchange>(0,nrecv),*this);
copymode = 0;

View File

@ -62,12 +62,13 @@ class FixWallGranKokkos : public FixWallGranOld, public KokkosBase {
void operator()(TagFixWallGranUnpackExchange, const int&) const;
int pack_exchange_kokkos(const int &nsend,DAT::tdual_xfloat_2d &buf,
DAT::tdual_int_1d k_sendlist,
DAT::tdual_int_1d k_copylist,
ExecutionSpace space) override;
DAT::tdual_int_1d k_sendlist,
DAT::tdual_int_1d k_copylist,
ExecutionSpace space) override;
// Unpack this fix's received exchange data from k_buf on the device.
// nrecv1 and nextrarecv1 are part of the shared KokkosBase signature; the
// definition for this class leaves both parameters unnamed (unused).
// Parameter names match the definition and the KokkosBase declaration.
void unpack_exchange_kokkos(DAT::tdual_xfloat_2d &k_buf,
                            DAT::tdual_int_1d &indices,int nrecv,
                            int nrecv1,int nextrarecv1,
                            ExecutionSpace space) override;
private:
@ -91,6 +92,7 @@ class FixWallGranKokkos : public FixWallGranOld, public KokkosBase {
typename AT::t_int_1d d_copylist;
typename AT::t_int_1d d_indices;
};
}
#endif

View File

@ -47,6 +47,7 @@ class KokkosBase {
ExecutionSpace /*space*/) { return 0; }
// Default hook for unpacking exchange data: every parameter is unnamed and
// the body is empty, so the base implementation does nothing.
// Classes deriving from KokkosBase override this with their own unpack logic.
virtual void unpack_exchange_kokkos(DAT::tdual_xfloat_2d & /*k_buf*/,
DAT::tdual_int_1d & /*indices*/, int /*nrecv*/,
int /*nrecv1*/, int /*nextrarecv1*/,
ExecutionSpace /*space*/) {}
// Region