Fuse some Kokkos kernels to reduce launch latency for small systems

This commit is contained in:
Stan Moore
2023-04-28 14:40:59 -06:00
parent 4a608dced6
commit 235372d6e8
14 changed files with 185 additions and 12 deletions

View File

@ -29,6 +29,7 @@ FixNVEKokkos<DeviceType>::FixNVEKokkos(LAMMPS *lmp, int narg, char **arg) :
FixNVE(lmp, narg, arg) FixNVE(lmp, narg, arg)
{ {
kokkosable = 1; kokkosable = 1;
fuse_integrate_flag = 1;
atomKK = (AtomKokkos *) atom; atomKK = (AtomKokkos *) atom;
execution_space = ExecutionSpaceFromDevice<DeviceType>::space; execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
@ -168,6 +169,35 @@ void FixNVEKokkos<DeviceType>::cleanup_copy()
vatom = nullptr; vatom = nullptr;
} }
/* ----------------------------------------------------------------------
allow for both per-type and per-atom mass
------------------------------------------------------------------------- */
template<class DeviceType>
void FixNVEKokkos<DeviceType>::fused_integrate()
{
atomKK->sync(execution_space,datamask_read);
atomKK->modified(execution_space,datamask_modify);
x = atomKK->k_x.view<DeviceType>();
v = atomKK->k_v.view<DeviceType>();
f = atomKK->k_f.view<DeviceType>();
rmass = atomKK->k_rmass.view<DeviceType>();
mass = atomKK->k_mass.view<DeviceType>();
type = atomKK->k_type.view<DeviceType>();
mask = atomKK->k_mask.view<DeviceType>();
int nlocal = atomKK->nlocal;
if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst;
if (rmass.data()) {
FixNVEKokkosFusedIntegrateFunctor<DeviceType,1> functor(this);
Kokkos::parallel_for(nlocal,functor);
} else {
FixNVEKokkosFusedIntegrateFunctor<DeviceType,0> functor(this);
Kokkos::parallel_for(nlocal,functor);
}
}
namespace LAMMPS_NS { namespace LAMMPS_NS {
template class FixNVEKokkos<LMPDeviceType>; template class FixNVEKokkos<LMPDeviceType>;
#ifdef LMP_KOKKOS_GPU #ifdef LMP_KOKKOS_GPU

View File

@ -46,6 +46,7 @@ class FixNVEKokkos : public FixNVE {
void init() override; void init() override;
void initial_integrate(int) override; void initial_integrate(int) override;
void final_integrate() override; void final_integrate() override;
void fused_integrate() override;
KOKKOS_INLINE_FUNCTION KOKKOS_INLINE_FUNCTION
void initial_integrate_item(int) const; void initial_integrate_item(int) const;
@ -96,6 +97,25 @@ struct FixNVEKokkosFinalIntegrateFunctor {
} }
}; };
template <class DeviceType, int RMass>
struct FixNVEKokkosFusedIntegrateFunctor {
typedef DeviceType device_type ;
FixNVEKokkos<DeviceType> c;
FixNVEKokkosFusedIntegrateFunctor(FixNVEKokkos<DeviceType>* c_ptr):
c(*c_ptr) {c.cleanup_copy();};
KOKKOS_INLINE_FUNCTION
void operator()(const int i) const {
if (RMass) {
c.final_integrate_rmass_item(i);
c.initial_integrate_rmass_item(i);
} else {
c.final_integrate_item(i);
c.initial_integrate_item(i);
}
}
};
} }
#endif #endif

View File

@ -392,6 +392,25 @@ void ModifyKokkos::final_integrate()
} }
} }
/* ----------------------------------------------------------------------
2nd half of integrate call, only for relevant fixes
------------------------------------------------------------------------- */
void ModifyKokkos::fused_integrate()
{
for (int i = 0; i < n_final_integrate; i++) {
atomKK->sync(fix[list_final_integrate[i]]->execution_space,
fix[list_final_integrate[i]]->datamask_read);
int prev_auto_sync = lmp->kokkos->auto_sync;
if (!fix[list_final_integrate[i]]->kokkosable) lmp->kokkos->auto_sync = 1;
fix[list_final_integrate[i]]->fused_integrate();
lmp->kokkos->auto_sync = prev_auto_sync;
atomKK->modified(fix[list_final_integrate[i]]->execution_space,
fix[list_final_integrate[i]]->datamask_modify);
}
}
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
end-of-timestep call, only for relevant fixes end-of-timestep call, only for relevant fixes
only call fix->end_of_step() on timesteps that are multiples of nevery only call fix->end_of_step() on timesteps that are multiples of nevery

View File

@ -39,6 +39,7 @@ class ModifyKokkos : public Modify {
void pre_reverse(int,int) override; void pre_reverse(int,int) override;
void post_force(int) override; void post_force(int) override;
void final_integrate() override; void final_integrate() override;
void fused_integrate() override;
void end_of_step() override; void end_of_step() override;
double energy_couple() override; double energy_couple() override;
double energy_global() override; double energy_global() override;

View File

@ -280,6 +280,12 @@ struct PairComputeFunctor {
const X_FLOAT ztmp = c.x(i,2); const X_FLOAT ztmp = c.x(i,2);
const int itype = c.type(i); const int itype = c.type(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i); const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i]; const int jnum = list.d_numneigh[i];
@ -337,6 +343,12 @@ struct PairComputeFunctor {
const int itype = c.type(i); const int itype = c.type(i);
const F_FLOAT qtmp = c.q(i); const F_FLOAT qtmp = c.q(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i); const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i]; const int jnum = list.d_numneigh[i];
@ -399,6 +411,12 @@ struct PairComputeFunctor {
const X_FLOAT ztmp = c.x(i,2); const X_FLOAT ztmp = c.x(i,2);
const int itype = c.type(i); const int itype = c.type(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i); const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i]; const int jnum = list.d_numneigh[i];
@ -495,6 +513,12 @@ struct PairComputeFunctor {
const int itype = c.type(i); const int itype = c.type(i);
const F_FLOAT qtmp = c.q(i); const F_FLOAT qtmp = c.q(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i); const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i]; const int jnum = list.d_numneigh[i];
@ -743,6 +767,8 @@ EV_FLOAT pair_compute_neighlist (PairStyle* fpair, typename std::enable_if<(NEIG
fpair->lmp->kokkos->neigh_thread = 1; fpair->lmp->kokkos->neigh_thread = 1;
if (fpair->lmp->kokkos->neigh_thread) { if (fpair->lmp->kokkos->neigh_thread) {
fpair->fuse_force_clear_flag = 1;
int vector_length = 8; int vector_length = 8;
int atoms_per_team = 32; int atoms_per_team = 32;

View File

@ -296,6 +296,8 @@ void VerletKokkos::run(int n)
// initial time integration // initial time integration
timer->stamp(); timer->stamp();
fuse_check(i,n);
if (!fuse_integrate)
modify->initial_integrate(vflag); modify->initial_integrate(vflag);
if (n_post_integrate) modify->post_integrate(); if (n_post_integrate) modify->post_integrate();
timer->stamp(Timer::MODIFY); timer->stamp(Timer::MODIFY);
@ -362,6 +364,7 @@ void VerletKokkos::run(int n)
// since some bonded potentials tally pairwise energy/virial // since some bonded potentials tally pairwise energy/virial
// and Pair:ev_tally() needs to be called before any tallying // and Pair:ev_tally() needs to be called before any tallying
if (!fuse_force_clear)
force_clear(); force_clear();
timer->stamp(); timer->stamp();
@ -494,7 +497,10 @@ void VerletKokkos::run(int n)
// force modifications, final time integration, diagnostics // force modifications, final time integration, diagnostics
if (n_post_force) modify->post_force(vflag); if (n_post_force) modify->post_force(vflag);
modify->final_integrate();
if (fuse_integrate) modify->fused_integrate();
else modify->final_integrate();
if (n_end_of_step) modify->end_of_step(); if (n_end_of_step) modify->end_of_step();
timer->stamp(Timer::MODIFY); timer->stamp(Timer::MODIFY);
@ -593,3 +599,35 @@ void VerletKokkos::force_clear()
} }
} }
} }
/* ----------------------------------------------------------------------
check if can fuse force_clear() with pair compute()
Requirements:
- no pre_force fixes
- no torques, SPIN forces, or includegroup set
- pair compute() must be called
- pair_style must support fusing
check if can fuse initial_integrate() with final_integrate()
Requirements:
- no end_of_step fixes
- not on first, last, or output step
- no timers to break out of loop
- integrate fix style must support fusing
------------------------------------------------------------------------- */
void VerletKokkos::fuse_check(int i, int n)
{
fuse_force_clear = 0;
if (modify->n_pre_force) fuse_force_clear = 0;
if (torqueflag || extraflag || neighbor->includegroup) fuse_force_clear = 0;
if (!pair_compute_flag) fuse_force_clear = 0;
if (!force->pair->fuse_force_clear_flag) fuse_force_clear = 0;
fuse_integrate = 0;
if (modify->n_end_of_step) fuse_integrate = 0;
if (i == 0 || i == n-1) fuse_integrate = 0;
if (update->ntimestep == output->next) fuse_integrate = 0;
if (timer->has_timeout()) fuse_integrate = 0;
if (!modify->check_fuse_integrate()) fuse_integrate = 0;
}

View File

@ -46,6 +46,9 @@ class VerletKokkos : public Verlet {
protected: protected:
DAT::t_f_array f_merge_copy,f; DAT::t_f_array f_merge_copy,f;
int fuse_force_clear,fuse_integrate;
void fuse_check(int, int);
}; };
} }

View File

@ -102,15 +102,15 @@ Fix::Fix(LAMMPS *lmp, int /*narg*/, char **arg) :
vflag_atom = cvflag_atom = 0; vflag_atom = cvflag_atom = 0;
centroidstressflag = CENTROID_SAME; centroidstressflag = CENTROID_SAME;
// KOKKOS per-fix data masks // KOKKOS package
execution_space = Host; execution_space = Host;
datamask_read = ALL_MASK; datamask_read = ALL_MASK;
datamask_modify = ALL_MASK; datamask_modify = ALL_MASK;
kokkosable = 0; kokkosable = copymode = 0;
forward_comm_device = exchange_comm_device = 0; forward_comm_device = exchange_comm_device = 0;
copymode = 0; fuse_integrate_flag = 0;
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */

View File

@ -127,11 +127,12 @@ class Fix : protected Pointers {
int restart_reset; // 1 if restart just re-initialized fix int restart_reset; // 1 if restart just re-initialized fix
// KOKKOS host/device flag and data masks // KOKKOS flags and variables
int kokkosable; // 1 if Kokkos fix int kokkosable; // 1 if Kokkos fix
int forward_comm_device; // 1 if forward comm on Device int forward_comm_device; // 1 if forward comm on Device
int exchange_comm_device; // 1 if exchange comm on Device int exchange_comm_device; // 1 if exchange comm on Device
int fuse_integrate_flag; // 1 if can fuse initial integrate with final integrate
ExecutionSpace execution_space; ExecutionSpace execution_space;
unsigned int datamask_read, datamask_modify; unsigned int datamask_read, datamask_modify;
@ -152,6 +153,7 @@ class Fix : protected Pointers {
virtual void setup_pre_reverse(int, int) {} virtual void setup_pre_reverse(int, int) {}
virtual void min_setup(int) {} virtual void min_setup(int) {}
virtual void initial_integrate(int) {} virtual void initial_integrate(int) {}
virtual void fused_integrate() {}
virtual void post_integrate() {} virtual void post_integrate() {}
virtual void pre_exchange() {} virtual void pre_exchange() {}
virtual void pre_neighbor() {} virtual void pre_neighbor() {}

View File

@ -475,6 +475,17 @@ void Modify::final_integrate()
for (int i = 0; i < n_final_integrate; i++) fix[list_final_integrate[i]]->final_integrate(); for (int i = 0; i < n_final_integrate; i++) fix[list_final_integrate[i]]->final_integrate();
} }
/* ----------------------------------------------------------------------
2nd half of integrate call, only for relevant fixes
------------------------------------------------------------------------- */
void Modify::fused_integrate()
{
for (int i = 0; i < n_final_integrate; i++)
fix[list_final_integrate[i]]->fused_integrate();
}
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
end-of-timestep call, only for relevant fixes end-of-timestep call, only for relevant fixes
only call fix->end_of_step() on timesteps that are multiples of nevery only call fix->end_of_step() on timesteps that are multiples of nevery
@ -1799,3 +1810,22 @@ double Modify::memory_usage()
for (int i = 0; i < ncompute; i++) bytes += compute[i]->memory_usage(); for (int i = 0; i < ncompute; i++) bytes += compute[i]->memory_usage();
return bytes; return bytes;
} }
/* ----------------------------------------------------------------------
check if initial and final integrate can be fused
------------------------------------------------------------------------- */
int Modify::check_fuse_integrate()
{
int fuse_integrate_flag = 1;
for (int i = 0; i < n_initial_integrate; i++)
if (!fix[list_initial_integrate[i]]->fuse_integrate_flag)
fuse_integrate_flag = 0;
for (int i = 0; i < n_final_integrate; i++)
if (!fix[list_final_integrate[i]]->fuse_integrate_flag)
fuse_integrate_flag = 0;
return fuse_integrate_flag;
}

View File

@ -61,6 +61,7 @@ class Modify : protected Pointers {
virtual void setup_pre_force(int); virtual void setup_pre_force(int);
virtual void setup_pre_reverse(int, int); virtual void setup_pre_reverse(int, int);
virtual void initial_integrate(int); virtual void initial_integrate(int);
virtual void fused_integrate();
virtual void post_integrate(); virtual void post_integrate();
virtual void pre_exchange(); virtual void pre_exchange();
virtual void pre_neighbor(); virtual void pre_neighbor();
@ -150,6 +151,8 @@ class Modify : protected Pointers {
double memory_usage(); double memory_usage();
int check_fuse_integrate();
protected: protected:
// internal fix counts // internal fix counts

View File

@ -116,15 +116,14 @@ Pair::Pair(LAMMPS *lmp) :
nondefault_history_transfer = 0; nondefault_history_transfer = 0;
beyond_contact = 0; beyond_contact = 0;
// KOKKOS per-fix data masks // KOKKOS package
execution_space = Host; execution_space = Host;
datamask_read = ALL_MASK; datamask_read = ALL_MASK;
datamask_modify = ALL_MASK; datamask_modify = ALL_MASK;
kokkosable = 0; kokkosable = copymode = 0;
reverse_comm_device = 0; reverse_comm_device = fuse_force_clear_flag = 0;
copymode = 0;
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */

View File

@ -119,12 +119,13 @@ class Pair : protected Pointers {
int beyond_contact, nondefault_history_transfer; // for granular styles int beyond_contact, nondefault_history_transfer; // for granular styles
// KOKKOS host/device flag and data masks // KOKKOS flags and variables
ExecutionSpace execution_space; ExecutionSpace execution_space;
unsigned int datamask_read, datamask_modify; unsigned int datamask_read, datamask_modify;
int kokkosable; // 1 if Kokkos pair int kokkosable; // 1 if Kokkos pair
int reverse_comm_device; // 1 if reverse comm on Device int reverse_comm_device; // 1 if reverse comm on Device
int fuse_force_clear_flag; // 1 if can fuse force clear with force compute
Pair(class LAMMPS *); Pair(class LAMMPS *);
~Pair() override; ~Pair() override;

View File

@ -63,6 +63,7 @@ class Timer : protected Pointers {
bool has_normal() const { return (_level >= NORMAL); } bool has_normal() const { return (_level >= NORMAL); }
bool has_full() const { return (_level >= FULL); } bool has_full() const { return (_level >= FULL); }
bool has_sync() const { return (_sync != OFF); } bool has_sync() const { return (_sync != OFF); }
bool has_timeout() const { return (_timeout >= 0.0); }
// flag if wallclock time is expired // flag if wallclock time is expired
bool is_timeout() const { return (_timeout == 0.0); } bool is_timeout() const { return (_timeout == 0.0); }