Avoid using a host pointer in device code

This commit is contained in:
Stan Gerald Moore
2022-10-10 12:08:31 -06:00
parent e661297838
commit 0f33ff1fc1
2 changed files with 130 additions and 100 deletions

View File

@ -123,6 +123,7 @@ void PairMEAMKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagPairMEAMOffsets>(0,inum_half),*this,n);
meam_inst_kk->meam_dens_setup(atom->nmax, nall, n);
update_meam_views();
x = atomKK->k_x.view<DeviceType>();
f = atomKK->k_f.view<DeviceType>();
@ -324,35 +325,35 @@ KOKKOS_INLINE_FUNCTION
void PairMEAMKokkos<DeviceType>::operator()(TagPairMEAMPackForwardComm, const int &i) const {
int j = d_sendlist(iswap, i);
int m = i*38;
v_buf[m++] = meam_inst_kk->d_rho0[j];
v_buf[m++] = meam_inst_kk->d_rho1[j];
v_buf[m++] = meam_inst_kk->d_rho2[j];
v_buf[m++] = meam_inst_kk->d_rho3[j];
v_buf[m++] = meam_inst_kk->d_frhop[j];
v_buf[m++] = meam_inst_kk->d_gamma[j];
v_buf[m++] = meam_inst_kk->d_dgamma1[j];
v_buf[m++] = meam_inst_kk->d_dgamma2[j];
v_buf[m++] = meam_inst_kk->d_dgamma3[j];
v_buf[m++] = meam_inst_kk->d_arho2b[j];
v_buf[m++] = meam_inst_kk->d_arho1(j,0);
v_buf[m++] = meam_inst_kk->d_arho1(j,1);
v_buf[m++] = meam_inst_kk->d_arho1(j,2);
v_buf[m++] = meam_inst_kk->d_arho2(j,0);
v_buf[m++] = meam_inst_kk->d_arho2(j,1);
v_buf[m++] = meam_inst_kk->d_arho2(j,2);
v_buf[m++] = meam_inst_kk->d_arho2(j,3);
v_buf[m++] = meam_inst_kk->d_arho2(j,4);
v_buf[m++] = meam_inst_kk->d_arho2(j,5);
for (int k = 0; k < 10; k++) v_buf[m++] = meam_inst_kk->d_arho3(j,k);
v_buf[m++] = meam_inst_kk->d_arho3b(j,0);
v_buf[m++] = meam_inst_kk->d_arho3b(j,1);
v_buf[m++] = meam_inst_kk->d_arho3b(j,2);
v_buf[m++] = meam_inst_kk->d_t_ave(j,0);
v_buf[m++] = meam_inst_kk->d_t_ave(j,1);
v_buf[m++] = meam_inst_kk->d_t_ave(j,2);
v_buf[m++] = meam_inst_kk->d_tsq_ave(j,0);
v_buf[m++] = meam_inst_kk->d_tsq_ave(j,1);
v_buf[m++] = meam_inst_kk->d_tsq_ave(j,2);
v_buf[m++] = d_rho0[j];
v_buf[m++] = d_rho1[j];
v_buf[m++] = d_rho2[j];
v_buf[m++] = d_rho3[j];
v_buf[m++] = d_frhop[j];
v_buf[m++] = d_gamma[j];
v_buf[m++] = d_dgamma1[j];
v_buf[m++] = d_dgamma2[j];
v_buf[m++] = d_dgamma3[j];
v_buf[m++] = d_arho2b[j];
v_buf[m++] = d_arho1(j,0);
v_buf[m++] = d_arho1(j,1);
v_buf[m++] = d_arho1(j,2);
v_buf[m++] = d_arho2(j,0);
v_buf[m++] = d_arho2(j,1);
v_buf[m++] = d_arho2(j,2);
v_buf[m++] = d_arho2(j,3);
v_buf[m++] = d_arho2(j,4);
v_buf[m++] = d_arho2(j,5);
for (int k = 0; k < 10; k++) v_buf[m++] = d_arho3(j,k);
v_buf[m++] = d_arho3b(j,0);
v_buf[m++] = d_arho3b(j,1);
v_buf[m++] = d_arho3b(j,2);
v_buf[m++] = d_t_ave(j,0);
v_buf[m++] = d_t_ave(j,1);
v_buf[m++] = d_t_ave(j,2);
v_buf[m++] = d_tsq_ave(j,0);
v_buf[m++] = d_tsq_ave(j,1);
v_buf[m++] = d_tsq_ave(j,2);
}
/* ---------------------------------------------------------------------- */
@ -372,35 +373,35 @@ KOKKOS_INLINE_FUNCTION
void PairMEAMKokkos<DeviceType>::operator()(TagPairMEAMUnpackForwardComm, const int &i) const{
int m = i*38;
meam_inst_kk->d_rho0[i+first] = v_buf[m++];
meam_inst_kk->d_rho1[i+first] = v_buf[m++];
meam_inst_kk->d_rho2[i+first] = v_buf[m++];
meam_inst_kk->d_rho3[i+first] = v_buf[m++];
meam_inst_kk->d_frhop[i+first] = v_buf[m++];
meam_inst_kk->d_gamma[i+first] = v_buf[m++];
meam_inst_kk->d_dgamma1[i+first] = v_buf[m++];
meam_inst_kk->d_dgamma2[i+first] = v_buf[m++];
meam_inst_kk->d_dgamma3[i+first] = v_buf[m++];
meam_inst_kk->d_arho2b[i+first] = v_buf[m++];
meam_inst_kk->d_arho1(i+first,0) = v_buf[m++];
meam_inst_kk->d_arho1(i+first,1) = v_buf[m++];
meam_inst_kk->d_arho1(i+first,2) = v_buf[m++];
meam_inst_kk->d_arho2(i+first,0) = v_buf[m++];
meam_inst_kk->d_arho2(i+first,1) = v_buf[m++];
meam_inst_kk->d_arho2(i+first,2) = v_buf[m++];
meam_inst_kk->d_arho2(i+first,3) = v_buf[m++];
meam_inst_kk->d_arho2(i+first,4) = v_buf[m++];
meam_inst_kk->d_arho2(i+first,5) = v_buf[m++];
for (int k = 0; k < 10; k++) meam_inst_kk->d_arho3(i+first,k) = v_buf[m++];
meam_inst_kk->d_arho3b(i+first,0) = v_buf[m++];
meam_inst_kk->d_arho3b(i+first,1) = v_buf[m++];
meam_inst_kk->d_arho3b(i+first,2) = v_buf[m++];
meam_inst_kk->d_t_ave(i+first,0) = v_buf[m++];
meam_inst_kk->d_t_ave(i+first,1) = v_buf[m++];
meam_inst_kk->d_t_ave(i+first,2) = v_buf[m++];
meam_inst_kk->d_tsq_ave(i+first,0) = v_buf[m++];
meam_inst_kk->d_tsq_ave(i+first,1) = v_buf[m++];
meam_inst_kk->d_tsq_ave(i+first,2) = v_buf[m++];
d_rho0[i+first] = v_buf[m++];
d_rho1[i+first] = v_buf[m++];
d_rho2[i+first] = v_buf[m++];
d_rho3[i+first] = v_buf[m++];
d_frhop[i+first] = v_buf[m++];
d_gamma[i+first] = v_buf[m++];
d_dgamma1[i+first] = v_buf[m++];
d_dgamma2[i+first] = v_buf[m++];
d_dgamma3[i+first] = v_buf[m++];
d_arho2b[i+first] = v_buf[m++];
d_arho1(i+first,0) = v_buf[m++];
d_arho1(i+first,1) = v_buf[m++];
d_arho1(i+first,2) = v_buf[m++];
d_arho2(i+first,0) = v_buf[m++];
d_arho2(i+first,1) = v_buf[m++];
d_arho2(i+first,2) = v_buf[m++];
d_arho2(i+first,3) = v_buf[m++];
d_arho2(i+first,4) = v_buf[m++];
d_arho2(i+first,5) = v_buf[m++];
for (int k = 0; k < 10; k++) d_arho3(i+first,k) = v_buf[m++];
d_arho3b(i+first,0) = v_buf[m++];
d_arho3b(i+first,1) = v_buf[m++];
d_arho3b(i+first,2) = v_buf[m++];
d_t_ave(i+first,0) = v_buf[m++];
d_t_ave(i+first,1) = v_buf[m++];
d_t_ave(i+first,2) = v_buf[m++];
d_tsq_ave(i+first,0) = v_buf[m++];
d_tsq_ave(i+first,1) = v_buf[m++];
d_tsq_ave(i+first,2) = v_buf[m++];
}
/* ---------------------------------------------------------------------- */
@ -555,27 +556,27 @@ KOKKOS_INLINE_FUNCTION
void PairMEAMKokkos<DeviceType>::operator()(TagPairMEAMPackReverseComm, const int &i) const {
int m = i*30;
v_buf[m++] = meam_inst_kk->d_rho0[i+first];
v_buf[m++] = meam_inst_kk->d_arho2b[i+first];
v_buf[m++] = meam_inst_kk->d_arho1(i+first,0);
v_buf[m++] = meam_inst_kk->d_arho1(i+first,1);
v_buf[m++] = meam_inst_kk->d_arho1(i+first,2);
v_buf[m++] = meam_inst_kk->d_arho2(i+first,0);
v_buf[m++] = meam_inst_kk->d_arho2(i+first,1);
v_buf[m++] = meam_inst_kk->d_arho2(i+first,2);
v_buf[m++] = meam_inst_kk->d_arho2(i+first,3);
v_buf[m++] = meam_inst_kk->d_arho2(i+first,4);
v_buf[m++] = meam_inst_kk->d_arho2(i+first,5);
for (int k = 0; k < 10; k++) v_buf[m++] = meam_inst_kk->d_arho3(i+first,k);
v_buf[m++] = meam_inst_kk->d_arho3b(i+first,0);
v_buf[m++] = meam_inst_kk->d_arho3b(i+first,1);
v_buf[m++] = meam_inst_kk->d_arho3b(i+first,2);
v_buf[m++] = meam_inst_kk->d_t_ave(i+first,0);
v_buf[m++] = meam_inst_kk->d_t_ave(i+first,1);
v_buf[m++] = meam_inst_kk->d_t_ave(i+first,2);
v_buf[m++] = meam_inst_kk->d_tsq_ave(i+first,0);
v_buf[m++] = meam_inst_kk->d_tsq_ave(i+first,1);
v_buf[m++] = meam_inst_kk->d_tsq_ave(i+first,2);
v_buf[m++] = d_rho0[i+first];
v_buf[m++] = d_arho2b[i+first];
v_buf[m++] = d_arho1(i+first,0);
v_buf[m++] = d_arho1(i+first,1);
v_buf[m++] = d_arho1(i+first,2);
v_buf[m++] = d_arho2(i+first,0);
v_buf[m++] = d_arho2(i+first,1);
v_buf[m++] = d_arho2(i+first,2);
v_buf[m++] = d_arho2(i+first,3);
v_buf[m++] = d_arho2(i+first,4);
v_buf[m++] = d_arho2(i+first,5);
for (int k = 0; k < 10; k++) v_buf[m++] = d_arho3(i+first,k);
v_buf[m++] = d_arho3b(i+first,0);
v_buf[m++] = d_arho3b(i+first,1);
v_buf[m++] = d_arho3b(i+first,2);
v_buf[m++] = d_t_ave(i+first,0);
v_buf[m++] = d_t_ave(i+first,1);
v_buf[m++] = d_t_ave(i+first,2);
v_buf[m++] = d_tsq_ave(i+first,0);
v_buf[m++] = d_tsq_ave(i+first,1);
v_buf[m++] = d_tsq_ave(i+first,2);
}
/* ---------------------------------------------------------------------- */
@ -640,27 +641,27 @@ void PairMEAMKokkos<DeviceType>::operator()(TagPairMEAMUnpackReverseComm, const
int j = d_sendlist(iswap, i);
int m = i*30;
meam_inst_kk->d_rho0[j] += v_buf[m++];
meam_inst_kk->d_arho2b[j] += v_buf[m++];
meam_inst_kk->d_arho1(j,0) += v_buf[m++];
meam_inst_kk->d_arho1(j,1) += v_buf[m++];
meam_inst_kk->d_arho1(j,2) += v_buf[m++];
meam_inst_kk->d_arho2(j,0) += v_buf[m++];
meam_inst_kk->d_arho2(j,1) += v_buf[m++];
meam_inst_kk->d_arho2(j,2) += v_buf[m++];
meam_inst_kk->d_arho2(j,3) += v_buf[m++];
meam_inst_kk->d_arho2(j,4) += v_buf[m++];
meam_inst_kk->d_arho2(j,5) += v_buf[m++];
for (int k = 0; k < 10; k++) meam_inst_kk->d_arho3(j,k) += v_buf[m++];
meam_inst_kk->d_arho3b(j,0) += v_buf[m++];
meam_inst_kk->d_arho3b(j,1) += v_buf[m++];
meam_inst_kk->d_arho3b(j,2) += v_buf[m++];
meam_inst_kk->d_t_ave(j,0) += v_buf[m++];
meam_inst_kk->d_t_ave(j,1) += v_buf[m++];
meam_inst_kk->d_t_ave(j,2) += v_buf[m++];
meam_inst_kk->d_tsq_ave(j,0) += v_buf[m++];
meam_inst_kk->d_tsq_ave(j,1) += v_buf[m++];
meam_inst_kk->d_tsq_ave(j,2) += v_buf[m++];
d_rho0[j] += v_buf[m++];
d_arho2b[j] += v_buf[m++];
d_arho1(j,0) += v_buf[m++];
d_arho1(j,1) += v_buf[m++];
d_arho1(j,2) += v_buf[m++];
d_arho2(j,0) += v_buf[m++];
d_arho2(j,1) += v_buf[m++];
d_arho2(j,2) += v_buf[m++];
d_arho2(j,3) += v_buf[m++];
d_arho2(j,4) += v_buf[m++];
d_arho2(j,5) += v_buf[m++];
for (int k = 0; k < 10; k++) d_arho3(j,k) += v_buf[m++];
d_arho3b(j,0) += v_buf[m++];
d_arho3b(j,1) += v_buf[m++];
d_arho3b(j,2) += v_buf[m++];
d_t_ave(j,0) += v_buf[m++];
d_t_ave(j,1) += v_buf[m++];
d_t_ave(j,2) += v_buf[m++];
d_tsq_ave(j,0) += v_buf[m++];
d_tsq_ave(j,1) += v_buf[m++];
d_tsq_ave(j,2) += v_buf[m++];
}
/* ---------------------------------------------------------------------- */
@ -744,6 +745,29 @@ void PairMEAMKokkos<DeviceType>::operator()(TagPairMEAMOffsets, const int ii, in
/* ---------------------------------------------------------------------- */
template<class DeviceType>
void PairMEAMKokkos<DeviceType>::update_meam_views()
{
d_rho0 = meam_inst_kk->d_rho0;
d_rho1 = meam_inst_kk->d_rho1;
d_rho2 = meam_inst_kk->d_rho2;
d_rho3 = meam_inst_kk->d_rho3;
d_frhop = meam_inst_kk->d_frhop;
d_gamma = meam_inst_kk->d_gamma;
d_dgamma1 = meam_inst_kk->d_dgamma1;
d_dgamma2 = meam_inst_kk->d_dgamma2;
d_dgamma3 = meam_inst_kk->d_dgamma3;
d_arho1 = meam_inst_kk->d_arho1;
d_arho2 = meam_inst_kk->d_arho2;
d_arho3 = meam_inst_kk->d_arho3;
d_arho2b = meam_inst_kk->d_arho2b;
d_arho3b = meam_inst_kk->d_arho3b;
d_t_ave = meam_inst_kk->d_t_ave;
d_tsq_ave = meam_inst_kk->d_tsq_ave;
}
/* ---------------------------------------------------------------------- */
namespace LAMMPS_NS {
template class PairMEAMKokkos<LMPDeviceType>;
#ifdef KOKKOS_ENABLE_CUDA

View File

@ -114,6 +114,12 @@ class PairMEAMKokkos : public PairMEAM, public KokkosBase {
int iswap,first;
int neighflag,nlocal,nall,eflag,vflag;
typename ArrayTypes<DeviceType>::t_ffloat_1d d_rho, d_rho0, d_rho1, d_rho2, d_rho3, d_frhop;
typename ArrayTypes<DeviceType>::t_ffloat_1d d_gamma, d_dgamma1, d_dgamma2, d_dgamma3, d_arho2b;
typename ArrayTypes<DeviceType>::t_ffloat_2d d_arho1, d_arho2, d_arho3, d_arho3b, d_t_ave, d_tsq_ave;
void update_meam_views();
friend void pair_virial_fdotr_compute<PairMEAMKokkos>(PairMEAMKokkos*);
};