diff --git a/src/KOKKOS/fix_langevin_kokkos.cpp b/src/KOKKOS/fix_langevin_kokkos.cpp index b7305644c9..14eb0f1ab7 100644 --- a/src/KOKKOS/fix_langevin_kokkos.cpp +++ b/src/KOKKOS/fix_langevin_kokkos.cpp @@ -44,6 +44,7 @@ FixLangevinKokkos::FixLangevinKokkos(LAMMPS *lmp, int narg, char **a FixLangevin(lmp, narg, arg),rand_pool(seed + comm->me) { kokkosable = 1; + fuse_integrate_flag = 1; atomKK = (AtomKokkos *) atom; int ntypes = atomKK->ntypes; @@ -169,6 +170,14 @@ void FixLangevinKokkos::initial_integrate_item(int i) const /* ---------------------------------------------------------------------- */ +template +void FixLangevinKokkos::fused_integrate(int vflag) +{ + initial_integrate(vflag); +} + +/* ---------------------------------------------------------------------- */ + template void FixLangevinKokkos::post_force(int /*vflag*/) { diff --git a/src/KOKKOS/fix_langevin_kokkos.h b/src/KOKKOS/fix_langevin_kokkos.h index f7142e6286..0bd628270e 100644 --- a/src/KOKKOS/fix_langevin_kokkos.h +++ b/src/KOKKOS/fix_langevin_kokkos.h @@ -69,6 +69,7 @@ namespace LAMMPS_NS { void cleanup_copy(); void init() override; void initial_integrate(int) override; + void fused_integrate(int) override; void post_force(int) override; void reset_dt() override; void grow_arrays(int) override; diff --git a/src/KOKKOS/fix_nve_kokkos.cpp b/src/KOKKOS/fix_nve_kokkos.cpp index c26e26a02c..b8236c2657 100644 --- a/src/KOKKOS/fix_nve_kokkos.cpp +++ b/src/KOKKOS/fix_nve_kokkos.cpp @@ -160,21 +160,12 @@ void FixNVEKokkos::final_integrate_rmass_item(int i) const } } -/* ---------------------------------------------------------------------- */ - -template -void FixNVEKokkos::cleanup_copy() -{ - id = style = nullptr; - vatom = nullptr; -} - /* ---------------------------------------------------------------------- allow for both per-type and per-atom mass ------------------------------------------------------------------------- */ template -void FixNVEKokkos::fused_integrate() +void FixNVEKokkos::fused_integrate(int /*vflag*/) { atomKK->sync(execution_space,datamask_read); @@ -199,6 +190,47 @@ void FixNVEKokkos::fused_integrate() atomKK->modified(execution_space,datamask_modify); } +template +KOKKOS_INLINE_FUNCTION +void FixNVEKokkos::fused_integrate_item(int i) const +{ + if (mask[i] & groupbit) { + const double dtfm = 2.0 * dtf / mass[type[i]]; + v(i,0) += dtfm * f(i,0); + v(i,1) += dtfm * f(i,1); + v(i,2) += dtfm * f(i,2); + x(i,0) += dtv * v(i,0); + x(i,1) += dtv * v(i,1); + x(i,2) += dtv * v(i,2); + } +} + +template +KOKKOS_INLINE_FUNCTION +void FixNVEKokkos::fused_integrate_rmass_item(int i) const +{ + if (mask[i] & groupbit) { + const double dtfm = 2.0 * dtf / rmass[i]; + v(i,0) += dtfm * f(i,0); + v(i,1) += dtfm * f(i,1); + v(i,2) += dtfm * f(i,2); + x(i,0) += dtv * v(i,0); + x(i,1) += dtv * v(i,1); + x(i,2) += dtv * v(i,2); + } +} + +/* ---------------------------------------------------------------------- */ + +template +void FixNVEKokkos::cleanup_copy() +{ + id = style = nullptr; + vatom = nullptr; +} + +/* ---------------------------------------------------------------------- */ + namespace LAMMPS_NS { template class FixNVEKokkos; #ifdef LMP_KOKKOS_GPU diff --git a/src/KOKKOS/fix_nve_kokkos.h b/src/KOKKOS/fix_nve_kokkos.h index a10f846cba..a1e8e1398c 100644 --- a/src/KOKKOS/fix_nve_kokkos.h +++ b/src/KOKKOS/fix_nve_kokkos.h @@ -46,7 +46,7 @@ class FixNVEKokkos : public FixNVE { void init() override; void initial_integrate(int) override; void final_integrate() override; - void fused_integrate() override; + void fused_integrate(int) override; KOKKOS_INLINE_FUNCTION void initial_integrate_item(int) const; @@ -56,6 +56,10 @@ class FixNVEKokkos : public FixNVE { void final_integrate_item(int) const; KOKKOS_INLINE_FUNCTION void final_integrate_rmass_item(int) const; + KOKKOS_INLINE_FUNCTION + void fused_integrate_item(int) const; + KOKKOS_INLINE_FUNCTION + void fused_integrate_rmass_item(int) const; private: @@ -106,13 +110,10 @@ struct FixNVEKokkosFusedIntegrateFunctor { c(*c_ptr) {c.cleanup_copy();}; KOKKOS_INLINE_FUNCTION void operator()(const int i) const { - if (RMass) { - c.final_integrate_rmass_item(i); - c.initial_integrate_rmass_item(i); - } else { - c.final_integrate_item(i); - c.initial_integrate_item(i); - } + if (RMass) + c.fused_integrate_rmass_item(i); + else + c.fused_integrate_item(i); } }; diff --git a/src/KOKKOS/modify_kokkos.cpp b/src/KOKKOS/modify_kokkos.cpp index 44fcb48727..0b81a1cabb 100644 --- a/src/KOKKOS/modify_kokkos.cpp +++ b/src/KOKKOS/modify_kokkos.cpp @@ -396,14 +396,14 @@ void ModifyKokkos::final_integrate() fused initial and final integrate call, only for relevant fixes ------------------------------------------------------------------------- */ -void ModifyKokkos::fused_integrate() +void ModifyKokkos::fused_integrate(int vflag) { for (int i = 0; i < n_final_integrate; i++) { atomKK->sync(fix[list_final_integrate[i]]->execution_space, fix[list_final_integrate[i]]->datamask_read); int prev_auto_sync = lmp->kokkos->auto_sync; if (!fix[list_final_integrate[i]]->kokkosable) lmp->kokkos->auto_sync = 1; - fix[list_final_integrate[i]]->fused_integrate(); + fix[list_final_integrate[i]]->fused_integrate(vflag); lmp->kokkos->auto_sync = prev_auto_sync; atomKK->modified(fix[list_final_integrate[i]]->execution_space, fix[list_final_integrate[i]]->datamask_modify); diff --git a/src/KOKKOS/modify_kokkos.h b/src/KOKKOS/modify_kokkos.h index 8dd9e6d9df..527518219c 100644 --- a/src/KOKKOS/modify_kokkos.h +++ b/src/KOKKOS/modify_kokkos.h @@ -39,7 +39,7 @@ class ModifyKokkos : public Modify { void pre_reverse(int,int) override; void post_force(int) override; void final_integrate() override; - void fused_integrate() override; + void fused_integrate(int) override; void end_of_step() override; double energy_couple() override; double energy_global() override; diff --git a/src/KOKKOS/verlet_kokkos.cpp b/src/KOKKOS/verlet_kokkos.cpp index bd15859644..3b88b34b38 100644 --- a/src/KOKKOS/verlet_kokkos.cpp +++ b/src/KOKKOS/verlet_kokkos.cpp @@ -498,7 +498,7 @@ void VerletKokkos::run(int n) if (n_post_force) modify->post_force(vflag); - if (fuse_integrate) modify->fused_integrate(); + if (fuse_integrate) modify->fused_integrate(vflag); else modify->final_integrate(); if (n_end_of_step) modify->end_of_step(); diff --git a/src/fix.h b/src/fix.h index 23b120d989..9676651afb 100644 --- a/src/fix.h +++ b/src/fix.h @@ -153,7 +153,6 @@ class Fix : protected Pointers { virtual void setup_pre_reverse(int, int) {} virtual void min_setup(int) {} virtual void initial_integrate(int) {} - virtual void fused_integrate() {} virtual void post_integrate() {} virtual void pre_exchange() {} virtual void pre_neighbor() {} @@ -162,6 +161,7 @@ class Fix : protected Pointers { virtual void pre_reverse(int, int) {} virtual void post_force(int) {} virtual void final_integrate() {} + virtual void fused_integrate(int) {} virtual void end_of_step() {} virtual void post_run() {} virtual void write_restart(FILE *) {} diff --git a/src/modify.h b/src/modify.h index 9686a9a7ec..6ca4b4ad26 100644 --- a/src/modify.h +++ b/src/modify.h @@ -61,7 +61,6 @@ class Modify : protected Pointers { virtual void setup_pre_force(int); virtual void setup_pre_reverse(int, int); virtual void initial_integrate(int); - virtual void fused_integrate() {} virtual void post_integrate(); virtual void pre_exchange(); virtual void pre_neighbor(); @@ -70,6 +69,7 @@ class Modify : protected Pointers { virtual void pre_reverse(int, int); virtual void post_force(int); virtual void final_integrate(); + virtual void fused_integrate(int) {} virtual void end_of_step(); virtual double energy_couple(); virtual double energy_global();