Fix more GPU data movement issues with fix langevin/kk and gjf option
This commit is contained in:
@ -63,22 +63,16 @@ FixLangevinKokkos<DeviceType>::FixLangevinKokkos(LAMMPS *lmp, int narg, char **a
|
||||
k_ratio.modify_host();
|
||||
|
||||
if (gjfflag) {
|
||||
memory->destroy(franprev);
|
||||
memory->destroy(lv);
|
||||
grow_arrays(atomKK->nmax);
|
||||
atom->add_callback(Atom::GROW);
|
||||
k_franprev.sync_host();
|
||||
k_lv.sync_host();
|
||||
|
||||
// initialize franprev to zero
|
||||
for (int i = 0; i < atomKK->nlocal; i++) {
|
||||
franprev[i][0] = 0.0;
|
||||
franprev[i][1] = 0.0;
|
||||
franprev[i][2] = 0.0;
|
||||
lv[i][0] = 0.0;
|
||||
lv[i][1] = 0.0;
|
||||
lv[i][2] = 0.0;
|
||||
}
|
||||
k_franprev.modify_host();
|
||||
k_lv.modify_host();
|
||||
|
||||
Kokkos::deep_copy(d_franprev,0.0);
|
||||
Kokkos::deep_copy(d_lv,0.0);
|
||||
}
|
||||
|
||||
if (zeroflag) {
|
||||
k_fsumall = tdual_double_1d_3n("langevin:fsumall");
|
||||
h_fsumall = k_fsumall.template view<LMPHostType>();
|
||||
@ -99,8 +93,10 @@ FixLangevinKokkos<DeviceType>::~FixLangevinKokkos()
|
||||
memoryKK->destroy_kokkos(k_gfactor2,gfactor2);
|
||||
memoryKK->destroy_kokkos(k_ratio,ratio);
|
||||
memoryKK->destroy_kokkos(k_flangevin,flangevin);
|
||||
if (gjfflag) memoryKK->destroy_kokkos(k_franprev,franprev);
|
||||
if (gjfflag) memoryKK->destroy_kokkos(k_lv,lv);
|
||||
if (gjfflag) {
|
||||
memoryKK->destroy_kokkos(k_franprev,franprev);
|
||||
memoryKK->destroy_kokkos(k_lv,lv);
|
||||
}
|
||||
memoryKK->destroy_kokkos(k_tforce,tforce);
|
||||
}
|
||||
|
||||
@ -126,6 +122,160 @@ void FixLangevinKokkos<DeviceType>::init()
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
template<class DeviceType>
|
||||
void FixLangevinKokkos<DeviceType>::setup(int vflag)
|
||||
{
|
||||
if (gjfflag) {
|
||||
double dt = update->dt;
|
||||
double ftm2v = force->ftm2v;
|
||||
auto v = atomKK->k_v.view<DeviceType>();
|
||||
auto f = atomKK->k_f.view<DeviceType>();
|
||||
auto mask = atomKK->k_mask.view<DeviceType>();
|
||||
int nlocal = atom->nlocal;
|
||||
auto rmass = atomKK->k_rmass.view<DeviceType>();
|
||||
auto mass = atomKK->k_mass.view<DeviceType>();
|
||||
auto type = atomKK->k_type.view<DeviceType>();
|
||||
|
||||
if (atom->rmass) {
|
||||
atomKK->sync(execution_space,V_MASK|F_MASK|MASK_MASK|RMASS_MASK);
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal), KOKKOS_LAMBDA(const int &i) {
|
||||
if (mask[i] & groupbit) {
|
||||
const double dtfm = ftm2v * 0.5 * dt / rmass[i];
|
||||
v(i,0) -= dtfm * f(i,0);
|
||||
v(i,1) -= dtfm * f(i,1);
|
||||
v(i,2) -= dtfm * f(i,2);
|
||||
}
|
||||
});
|
||||
|
||||
if (tbiasflag) {
|
||||
// account for bias velocity
|
||||
if (temperature->kokkosable) {
|
||||
temperature->compute_scalar();
|
||||
temperature->remove_bias_all_kk();
|
||||
} else {
|
||||
atomKK->sync(temperature->execution_space,temperature->datamask_read);
|
||||
temperature->compute_scalar();
|
||||
temperature->remove_bias_all();
|
||||
atomKK->modified(temperature->execution_space,temperature->datamask_modify);
|
||||
atomKK->sync(execution_space,temperature->datamask_modify);
|
||||
}
|
||||
}
|
||||
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal), KOKKOS_LAMBDA(const int &i) {
|
||||
if (mask[i] & groupbit) {
|
||||
v(i,0) /= gjfa * gjfsib * gjfsib;
|
||||
v(i,1) /= gjfa * gjfsib * gjfsib;
|
||||
v(i,2) /= gjfa * gjfsib * gjfsib;
|
||||
}
|
||||
});
|
||||
|
||||
if (tbiasflag) {
|
||||
if (temperature->kokkosable) temperature->restore_bias_all();
|
||||
else {
|
||||
atomKK->sync(temperature->execution_space,temperature->datamask_read);
|
||||
temperature->restore_bias_all();
|
||||
atomKK->modified(temperature->execution_space,temperature->datamask_modify);
|
||||
atomKK->sync(execution_space,temperature->datamask_modify);
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
atomKK->sync(execution_space,V_MASK|F_MASK|MASK_MASK|TYPE_MASK);
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal), KOKKOS_LAMBDA(const int &i) {
|
||||
if (mask[i] & groupbit) {
|
||||
const double dtfm = ftm2v * 0.5 * dt / mass[type[i]];
|
||||
v(i,0) -= dtfm * f(i,0);
|
||||
v(i,1) -= dtfm * f(i,1);
|
||||
v(i,2) -= dtfm * f(i,2);
|
||||
}
|
||||
});
|
||||
|
||||
if (tbiasflag) {
|
||||
// account for bias velocity
|
||||
if (temperature->kokkosable) {
|
||||
temperature->compute_scalar();
|
||||
temperature->remove_bias_all_kk();
|
||||
} else {
|
||||
atomKK->sync(temperature->execution_space,temperature->datamask_read);
|
||||
temperature->compute_scalar();
|
||||
temperature->remove_bias_all();
|
||||
atomKK->modified(temperature->execution_space,temperature->datamask_modify);
|
||||
atomKK->sync(execution_space,temperature->datamask_modify);
|
||||
}
|
||||
}
|
||||
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal), KOKKOS_LAMBDA(const int &i) {
|
||||
if (mask[i] & groupbit) {
|
||||
v(i,0) /= gjfa * gjfsib * gjfsib;
|
||||
v(i,1) /= gjfa * gjfsib * gjfsib;
|
||||
v(i,2) /= gjfa * gjfsib * gjfsib;
|
||||
}
|
||||
});
|
||||
|
||||
if (tbiasflag) {
|
||||
if (temperature->kokkosable) temperature->restore_bias_all();
|
||||
else {
|
||||
atomKK->sync(temperature->execution_space,temperature->datamask_read);
|
||||
temperature->restore_bias_all();
|
||||
atomKK->modified(temperature->execution_space,temperature->datamask_modify);
|
||||
atomKK->sync(execution_space,temperature->datamask_modify);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
atomKK->modified(execution_space,V_MASK);
|
||||
}
|
||||
|
||||
post_force(vflag);
|
||||
|
||||
if (gjfflag) {
|
||||
double dt = update->dt;
|
||||
double ftm2v = force->ftm2v;
|
||||
auto f = atomKK->k_f.view<DeviceType>();
|
||||
auto v = atomKK->k_v.view<DeviceType>();
|
||||
auto mask = atomKK->k_mask.view<DeviceType>();
|
||||
int nlocal = atom->nlocal;
|
||||
auto rmass = atomKK->k_rmass.view<DeviceType>();
|
||||
auto mass = atomKK->k_mass.view<DeviceType>();
|
||||
auto type = atomKK->k_type.view<DeviceType>();
|
||||
|
||||
k_lv.template sync<DeviceType>();
|
||||
auto l_lv = d_lv;
|
||||
|
||||
if (atom->rmass) {
|
||||
atomKK->sync(execution_space,V_MASK|F_MASK|MASK_MASK|RMASS_MASK);
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal), KOKKOS_LAMBDA(const int &i) {
|
||||
if (mask[i] & groupbit) {
|
||||
const double dtfm = ftm2v * 0.5 * dt / rmass[i];
|
||||
v(i,0) += dtfm * f(i,0);
|
||||
v(i,1) += dtfm * f(i,1);
|
||||
v(i,2) += dtfm * f(i,2);
|
||||
l_lv(i,0) = v(i,0);
|
||||
l_lv(i,1) = v(i,1);
|
||||
l_lv(i,2) = v(i,2);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
atomKK->sync(execution_space,V_MASK|F_MASK|MASK_MASK|TYPE_MASK);
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType>(0,nlocal), KOKKOS_LAMBDA(const int &i) {
|
||||
if (mask[i] & groupbit) {
|
||||
const double dtfm = ftm2v * 0.5 * dt / mass[type[i]];
|
||||
v(i,0) += dtfm * f(i,0);
|
||||
v(i,1) += dtfm * f(i,1);
|
||||
v(i,2) += dtfm * f(i,2);
|
||||
l_lv(i,0) = v(i,0);
|
||||
l_lv(i,1) = v(i,1);
|
||||
l_lv(i,2) = v(i,2);
|
||||
}
|
||||
});
|
||||
}
|
||||
atomKK->modified(execution_space,V_MASK);
|
||||
k_lv.template modify<DeviceType>();
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
template<class DeviceType>
|
||||
void FixLangevinKokkos<DeviceType>::grow_arrays(int nmax)
|
||||
{
|
||||
@ -143,7 +293,6 @@ template<class DeviceType>
|
||||
void FixLangevinKokkos<DeviceType>::initial_integrate(int /*vflag*/)
|
||||
{
|
||||
atomKK->sync(execution_space,datamask_read);
|
||||
atomKK->modified(execution_space,datamask_modify);
|
||||
|
||||
v = atomKK->k_v.view<DeviceType>();
|
||||
f = atomKK->k_f.view<DeviceType>();
|
||||
@ -152,6 +301,8 @@ void FixLangevinKokkos<DeviceType>::initial_integrate(int /*vflag*/)
|
||||
|
||||
FixLangevinKokkosInitialIntegrateFunctor<DeviceType> functor(this);
|
||||
Kokkos::parallel_for(nlocal,functor);
|
||||
|
||||
atomKK->modified(execution_space,datamask_modify);
|
||||
}
|
||||
|
||||
template<class DeviceType>
|
||||
|
||||
@ -68,6 +68,7 @@ namespace LAMMPS_NS {
|
||||
|
||||
void cleanup_copy();
|
||||
void init() override;
|
||||
void setup(int) override;
|
||||
void initial_integrate(int) override;
|
||||
void fused_integrate(int) override;
|
||||
void post_force(int) override;
|
||||
|
||||
Reference in New Issue
Block a user