diff --git a/src/KOKKOS/fix_nve_kokkos.cpp b/src/KOKKOS/fix_nve_kokkos.cpp
index 5dcb611d41..f812c575bc 100644
--- a/src/KOKKOS/fix_nve_kokkos.cpp
+++ b/src/KOKKOS/fix_nve_kokkos.cpp
@@ -29,6 +29,7 @@ FixNVEKokkos<DeviceType>::FixNVEKokkos(LAMMPS *lmp, int narg, char **arg) :
   FixNVE(lmp, narg, arg)
 {
   kokkosable = 1;
+  fuse_integrate_flag = 1;
   atomKK = (AtomKokkos *) atom;
   execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
 
@@ -168,6 +169,35 @@ void FixNVEKokkos<DeviceType>::cleanup_copy()
   vatom = nullptr;
 }
 
+/* ----------------------------------------------------------------------
+   allow for both per-type and per-atom mass
+------------------------------------------------------------------------- */
+
+template<class DeviceType>
+void FixNVEKokkos<DeviceType>::fused_integrate()
+{
+  atomKK->sync(execution_space,datamask_read);
+  atomKK->modified(execution_space,datamask_modify);
+
+  x = atomKK->k_x.view<DeviceType>();
+  v = atomKK->k_v.view<DeviceType>();
+  f = atomKK->k_f.view<DeviceType>();
+  rmass = atomKK->k_rmass.view<DeviceType>();
+  mass = atomKK->k_mass.view<DeviceType>();
+  type = atomKK->k_type.view<DeviceType>();
+  mask = atomKK->k_mask.view<DeviceType>();
+  int nlocal = atomKK->nlocal;
+  if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst;
+
+  if (rmass.data()) {
+    FixNVEKokkosFusedIntegrateFunctor<DeviceType,1> functor(this);
+    Kokkos::parallel_for(nlocal,functor);
+  } else {
+    FixNVEKokkosFusedIntegrateFunctor<DeviceType,0> functor(this);
+    Kokkos::parallel_for(nlocal,functor);
+  }
+}
+
 namespace LAMMPS_NS {
 template class FixNVEKokkos<LMPDeviceType>;
 #ifdef LMP_KOKKOS_GPU
diff --git a/src/KOKKOS/fix_nve_kokkos.h b/src/KOKKOS/fix_nve_kokkos.h
index 215aacb4a0..a10f846cba 100644
--- a/src/KOKKOS/fix_nve_kokkos.h
+++ b/src/KOKKOS/fix_nve_kokkos.h
@@ -46,6 +46,7 @@ class FixNVEKokkos : public FixNVE {
   void init() override;
   void initial_integrate(int) override;
   void final_integrate() override;
+  void fused_integrate() override;
 
   KOKKOS_INLINE_FUNCTION
   void initial_integrate_item(int) const;
@@ -96,6 +97,25 @@ struct FixNVEKokkosFinalIntegrateFunctor {
   }
 };
 
+template <class DeviceType, int RMass>
+struct FixNVEKokkosFusedIntegrateFunctor {
+  typedef DeviceType device_type ;
+  FixNVEKokkos<DeviceType> c;
+
+  FixNVEKokkosFusedIntegrateFunctor(FixNVEKokkos<DeviceType>* c_ptr):
+    c(*c_ptr) {c.cleanup_copy();};
+  KOKKOS_INLINE_FUNCTION
+  void operator()(const int i) const {
+    if (RMass) {
+      c.final_integrate_rmass_item(i);
+      c.initial_integrate_rmass_item(i);
+    } else {
+      c.final_integrate_item(i);
+      c.initial_integrate_item(i);
+    }
+  }
+};
+
 }
 
 #endif
diff --git a/src/KOKKOS/modify_kokkos.cpp b/src/KOKKOS/modify_kokkos.cpp
index 9d8c16603e..581f6598f1 100644
--- a/src/KOKKOS/modify_kokkos.cpp
+++ b/src/KOKKOS/modify_kokkos.cpp
@@ -392,6 +392,25 @@ void ModifyKokkos::final_integrate()
   }
 }
 
+
+/* ----------------------------------------------------------------------
+   fused integrate call, made on each fix in the final_integrate list
+------------------------------------------------------------------------- */
+
+void ModifyKokkos::fused_integrate()
+{
+  for (int i = 0; i < n_final_integrate; i++) {
+    atomKK->sync(fix[list_final_integrate[i]]->execution_space,
+                 fix[list_final_integrate[i]]->datamask_read);
+    int prev_auto_sync = lmp->kokkos->auto_sync;
+    if (!fix[list_final_integrate[i]]->kokkosable) lmp->kokkos->auto_sync = 1;
+    fix[list_final_integrate[i]]->fused_integrate();
+    lmp->kokkos->auto_sync = prev_auto_sync;
+    atomKK->modified(fix[list_final_integrate[i]]->execution_space,
+                     fix[list_final_integrate[i]]->datamask_modify);
+  }
+}
+
 /* ----------------------------------------------------------------------
    end-of-timestep call, only for relevant fixes
    only call fix->end_of_step() on timesteps that are multiples of nevery
diff --git a/src/KOKKOS/modify_kokkos.h b/src/KOKKOS/modify_kokkos.h
index 5edf5cd662..e440693fbb 100644
--- a/src/KOKKOS/modify_kokkos.h
+++ b/src/KOKKOS/modify_kokkos.h
@@ -39,6 +39,7 @@ class ModifyKokkos : public Modify {
   void pre_reverse(int,int) override;
   void post_force(int) override;
   void final_integrate() override;
+  void fused_integrate() override;
   void end_of_step() override;
   double energy_couple() override;
   double energy_global() override;
diff --git a/src/KOKKOS/pair_kokkos.h b/src/KOKKOS/pair_kokkos.h
index 0ff244f67d..7551d03f89 100644
--- a/src/KOKKOS/pair_kokkos.h
+++ b/src/KOKKOS/pair_kokkos.h
@@ -280,6 +280,12 @@ struct PairComputeFunctor {
       const X_FLOAT ztmp = c.x(i,2);
       const int itype = c.type(i);
 
+      Kokkos::single(Kokkos::PerThread(team), [&] (){
+        f(i,0) = 0.0;
+        f(i,1) = 0.0;
+        f(i,2) = 0.0;
+      });
+
       const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
       const int jnum = list.d_numneigh[i];
 
@@ -337,6 +343,12 @@ struct PairComputeFunctor {
       const int itype = c.type(i);
       const F_FLOAT qtmp = c.q(i);
 
+      Kokkos::single(Kokkos::PerThread(team), [&] (){
+        f(i,0) = 0.0;
+        f(i,1) = 0.0;
+        f(i,2) = 0.0;
+      });
+
       const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
       const int jnum = list.d_numneigh[i];
 
@@ -399,6 +411,12 @@ struct PairComputeFunctor {
       const X_FLOAT ztmp = c.x(i,2);
       const int itype = c.type(i);
 
+      Kokkos::single(Kokkos::PerThread(team), [&] (){
+        f(i,0) = 0.0;
+        f(i,1) = 0.0;
+        f(i,2) = 0.0;
+      });
+
       const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
       const int jnum = list.d_numneigh[i];
 
@@ -495,6 +513,12 @@ struct PairComputeFunctor {
       const int itype = c.type(i);
       const F_FLOAT qtmp = c.q(i);
 
+      Kokkos::single(Kokkos::PerThread(team), [&] (){
+        f(i,0) = 0.0;
+        f(i,1) = 0.0;
+        f(i,2) = 0.0;
+      });
+
       const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
       const int jnum = list.d_numneigh[i];
 
@@ -743,6 +767,8 @@ EV_FLOAT pair_compute_neighlist (PairStyle* fpair, typename std::enable_if<(NEIG
       fpair->lmp->kokkos->neigh_thread = 1;
 
   if (fpair->lmp->kokkos->neigh_thread) {
+    fpair->fuse_force_clear_flag = 1;
+
     int vector_length = 8;
     int atoms_per_team = 32;
 
diff --git a/src/KOKKOS/verlet_kokkos.cpp b/src/KOKKOS/verlet_kokkos.cpp
index 01401542cd..0339df663b 100644
--- a/src/KOKKOS/verlet_kokkos.cpp
+++ b/src/KOKKOS/verlet_kokkos.cpp
@@ -296,7 +296,9 @@ void VerletKokkos::run(int n)
     // initial time integration
 
     timer->stamp();
-    modify->initial_integrate(vflag);
+    fuse_check(i,n);
+    if (!fuse_integrate)
+      modify->initial_integrate(vflag);
     if (n_post_integrate) modify->post_integrate();
     timer->stamp(Timer::MODIFY);
 
@@ -362,7 +364,8 @@ void VerletKokkos::run(int n)
     // since some bonded potentials tally pairwise energy/virial
     // and Pair:ev_tally() needs to be called before any tallying
 
-    force_clear();
+    if (!fuse_force_clear)
+      force_clear();
 
     timer->stamp();
 
@@ -494,7 +497,10 @@ void VerletKokkos::run(int n)
     // force modifications, final time integration, diagnostics
 
     if (n_post_force) modify->post_force(vflag);
-    modify->final_integrate();
+
+    if (fuse_integrate) modify->fused_integrate();
+    else modify->final_integrate();
+
     if (n_end_of_step) modify->end_of_step();
     timer->stamp(Timer::MODIFY);
 
@@ -593,3 +599,35 @@ void VerletKokkos::force_clear()
     }
   }
 }
+
+/* ----------------------------------------------------------------------
+   check if can fuse force_clear() with pair compute()
+   Requirements:
+   - no pre_force fixes
+   - no torques, SPIN forces, or includegroup set
+   - pair compute() must be called
+   - pair_style must support fusing
+
+   check if can fuse initial_integrate() with final_integrate()
+   Requirements:
+   - no end_of_step fixes
+   - not on first, last, or output step
+   - no timers to break out of loop
+   - integrate fix style must support fusing
+------------------------------------------------------------------------- */
+
+void VerletKokkos::fuse_check(int i, int n)
+{
+  fuse_force_clear = 0;  // NOTE(review): starts at 0, so this function never enables fusing
+  if (modify->n_pre_force) fuse_force_clear = 0;
+  if (torqueflag || extraflag || neighbor->includegroup) fuse_force_clear = 0;
+  if (!force->pair || !pair_compute_flag) fuse_force_clear = 0;  // pair may be a null pointer
+  else if (!force->pair->fuse_force_clear_flag) fuse_force_clear = 0;
+
+  fuse_integrate = 0;  // NOTE(review): starts at 0, so this function never enables fusing
+  if (modify->n_end_of_step) fuse_integrate = 0;
+  if (i == 0 || i == n-1) fuse_integrate = 0;
+  if (update->ntimestep == output->next) fuse_integrate = 0;
+  if (timer->has_timeout()) fuse_integrate = 0;
+  if (!modify->check_fuse_integrate()) fuse_integrate = 0;
+}
diff --git a/src/KOKKOS/verlet_kokkos.h b/src/KOKKOS/verlet_kokkos.h
index 067df54f4f..c71211c542 100644
--- a/src/KOKKOS/verlet_kokkos.h
+++ b/src/KOKKOS/verlet_kokkos.h
@@ -46,6 +46,9 @@ class VerletKokkos : public Verlet {
 
  protected:
   DAT::t_f_array f_merge_copy,f;
+  int fuse_force_clear,fuse_integrate;
+
+  void fuse_check(int, int);
 };
 
 }
diff --git a/src/fix.cpp b/src/fix.cpp
index 1d41ad3943..02adcbd016 100644
--- a/src/fix.cpp
+++ b/src/fix.cpp
@@ -102,15 +102,15 @@ Fix::Fix(LAMMPS *lmp, int /*narg*/, char **arg) :
   vflag_atom = cvflag_atom = 0;
   centroidstressflag = CENTROID_SAME;
 
-  // KOKKOS per-fix data masks
+  // KOKKOS package
 
   execution_space = Host;
   datamask_read = ALL_MASK;
   datamask_modify = ALL_MASK;
 
-  kokkosable = 0;
+  kokkosable = copymode = 0;
   forward_comm_device = exchange_comm_device = 0;
-  copymode = 0;
+  fuse_integrate_flag = 0;
 }
 
 /* ---------------------------------------------------------------------- */
diff --git a/src/fix.h b/src/fix.h
index b47cfb2f4a..23b120d989 100644
--- a/src/fix.h
+++ b/src/fix.h
@@ -127,11 +127,12 @@ class Fix : protected Pointers {
 
   int restart_reset;    // 1 if restart just re-initialized fix
 
-  // KOKKOS host/device flag and data masks
+  // KOKKOS flags and variables
 
   int kokkosable;              // 1 if Kokkos fix
   int forward_comm_device;     // 1 if forward comm on Device
   int exchange_comm_device;    // 1 if exchange comm on Device
+  int fuse_integrate_flag;     // 1 if can fuse initial integrate with final integrate
   ExecutionSpace execution_space;
   unsigned int datamask_read, datamask_modify;
 
@@ -152,6 +153,7 @@ class Fix : protected Pointers {
   virtual void setup_pre_reverse(int, int) {}
   virtual void min_setup(int) {}
   virtual void initial_integrate(int) {}
+  virtual void fused_integrate() {}
   virtual void post_integrate() {}
   virtual void pre_exchange() {}
   virtual void pre_neighbor() {}
diff --git a/src/modify.cpp b/src/modify.cpp
index d0656d3895..71f7cb8889 100644
--- a/src/modify.cpp
+++ b/src/modify.cpp
@@ -475,6 +475,17 @@ void Modify::final_integrate()
   for (int i = 0; i < n_final_integrate; i++)
     fix[list_final_integrate[i]]->final_integrate();
 }
+
+/* ----------------------------------------------------------------------
+   fused integrate call, made on each fix in the final_integrate list
+------------------------------------------------------------------------- */
+
+void Modify::fused_integrate()
+{
+  for (int i = 0; i < n_final_integrate; i++)
+    fix[list_final_integrate[i]]->fused_integrate();
+}
+
 /* ----------------------------------------------------------------------
    end-of-timestep call, only for relevant fixes
    only call fix->end_of_step() on timesteps that are multiples of nevery
@@ -1799,3 +1810,22 @@ double Modify::memory_usage()
   for (int i = 0; i < ncompute; i++) bytes += compute[i]->memory_usage();
   return bytes;
 }
+
+/* ----------------------------------------------------------------------
+   check if initial and final integrate can be fused
+------------------------------------------------------------------------- */
+
+int Modify::check_fuse_integrate()
+{
+  int fuse_integrate_flag = 1;
+
+  for (int i = 0; i < n_initial_integrate; i++)
+    if (!fix[list_initial_integrate[i]]->fuse_integrate_flag)
+      fuse_integrate_flag = 0;
+
+  for (int i = 0; i < n_final_integrate; i++)
+    if (!fix[list_final_integrate[i]]->fuse_integrate_flag)
+      fuse_integrate_flag = 0;
+
+  return fuse_integrate_flag;
+}
diff --git a/src/modify.h b/src/modify.h
index 7a3f54c277..da3cab55d2 100644
--- a/src/modify.h
+++ b/src/modify.h
@@ -61,6 +61,7 @@ class Modify : protected Pointers {
   virtual void setup_pre_force(int);
   virtual void setup_pre_reverse(int, int);
   virtual void initial_integrate(int);
+  virtual void fused_integrate();
   virtual void post_integrate();
   virtual void pre_exchange();
   virtual void pre_neighbor();
@@ -150,6 +151,8 @@ class Modify : protected Pointers {
 
   double memory_usage();
 
+  int check_fuse_integrate();
+
  protected:
   // internal fix counts
 
diff --git a/src/pair.cpp b/src/pair.cpp
index 9ae24e1e93..34c8e4978e 100644
--- a/src/pair.cpp
+++ b/src/pair.cpp
@@ -116,15 +116,14 @@ Pair::Pair(LAMMPS *lmp) :
   nondefault_history_transfer = 0;
   beyond_contact = 0;
 
-  // KOKKOS per-fix data masks
+  // KOKKOS package
 
   execution_space = Host;
   datamask_read = ALL_MASK;
   datamask_modify = ALL_MASK;
 
-  kokkosable = 0;
-  reverse_comm_device = 0;
-  copymode = 0;
+  kokkosable = copymode = 0;
+  reverse_comm_device = fuse_force_clear_flag = 0;
 }
 
 /* ---------------------------------------------------------------------- */
diff --git a/src/pair.h b/src/pair.h
index 8c856660e9..d36b48f340 100644
--- a/src/pair.h
+++ b/src/pair.h
@@ -119,12 +119,13 @@ class Pair : protected Pointers {
 
   int beyond_contact, nondefault_history_transfer;    // for granular styles
 
-  // KOKKOS host/device flag and data masks
+  // KOKKOS flags and variables
 
   ExecutionSpace execution_space;
   unsigned int datamask_read, datamask_modify;
   int kokkosable;             // 1 if Kokkos pair
   int reverse_comm_device;    // 1 if reverse comm on Device
+  int fuse_force_clear_flag;  // 1 if can fuse force clear with force compute
 
   Pair(class LAMMPS *);
   ~Pair() override;
diff --git a/src/timer.h b/src/timer.h
index 5c100db1c0..f7efa5ac64 100644
--- a/src/timer.h
+++ b/src/timer.h
@@ -63,6 +63,7 @@ class Timer : protected Pointers {
   bool has_normal() const { return (_level >= NORMAL); }
   bool has_full() const { return (_level >= FULL); }
   bool has_sync() const { return (_sync != OFF); }
+  bool has_timeout() const { return (_timeout >= 0.0); }
 
   // flag if wallclock time is expired
   bool is_timeout() const { return (_timeout == 0.0); }