Fuse some Kokkos kernels to reduce launch latency for small systems
This commit is contained in:
@ -29,6 +29,7 @@ FixNVEKokkos<DeviceType>::FixNVEKokkos(LAMMPS *lmp, int narg, char **arg) :
|
||||
FixNVE(lmp, narg, arg)
|
||||
{
|
||||
kokkosable = 1;
|
||||
fuse_integrate_flag = 1;
|
||||
atomKK = (AtomKokkos *) atom;
|
||||
execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
|
||||
|
||||
@ -168,6 +169,35 @@ void FixNVEKokkos<DeviceType>::cleanup_copy()
|
||||
vatom = nullptr;
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
allow for both per-type and per-atom mass
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
template<class DeviceType>
|
||||
void FixNVEKokkos<DeviceType>::fused_integrate()
|
||||
{
|
||||
atomKK->sync(execution_space,datamask_read);
|
||||
atomKK->modified(execution_space,datamask_modify);
|
||||
|
||||
x = atomKK->k_x.view<DeviceType>();
|
||||
v = atomKK->k_v.view<DeviceType>();
|
||||
f = atomKK->k_f.view<DeviceType>();
|
||||
rmass = atomKK->k_rmass.view<DeviceType>();
|
||||
mass = atomKK->k_mass.view<DeviceType>();
|
||||
type = atomKK->k_type.view<DeviceType>();
|
||||
mask = atomKK->k_mask.view<DeviceType>();
|
||||
int nlocal = atomKK->nlocal;
|
||||
if (igroup == atomKK->firstgroup) nlocal = atomKK->nfirst;
|
||||
|
||||
if (rmass.data()) {
|
||||
FixNVEKokkosFusedIntegrateFunctor<DeviceType,1> functor(this);
|
||||
Kokkos::parallel_for(nlocal,functor);
|
||||
} else {
|
||||
FixNVEKokkosFusedIntegrateFunctor<DeviceType,0> functor(this);
|
||||
Kokkos::parallel_for(nlocal,functor);
|
||||
}
|
||||
}
|
||||
|
||||
namespace LAMMPS_NS {
|
||||
template class FixNVEKokkos<LMPDeviceType>;
|
||||
#ifdef LMP_KOKKOS_GPU
|
||||
|
||||
@ -46,6 +46,7 @@ class FixNVEKokkos : public FixNVE {
|
||||
void init() override;
|
||||
void initial_integrate(int) override;
|
||||
void final_integrate() override;
|
||||
void fused_integrate() override;
|
||||
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void initial_integrate_item(int) const;
|
||||
@ -96,6 +97,25 @@ struct FixNVEKokkosFinalIntegrateFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <class DeviceType, int RMass>
|
||||
struct FixNVEKokkosFusedIntegrateFunctor {
|
||||
typedef DeviceType device_type ;
|
||||
FixNVEKokkos<DeviceType> c;
|
||||
|
||||
FixNVEKokkosFusedIntegrateFunctor(FixNVEKokkos<DeviceType>* c_ptr):
|
||||
c(*c_ptr) {c.cleanup_copy();};
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void operator()(const int i) const {
|
||||
if (RMass) {
|
||||
c.final_integrate_rmass_item(i);
|
||||
c.initial_integrate_rmass_item(i);
|
||||
} else {
|
||||
c.final_integrate_item(i);
|
||||
c.initial_integrate_item(i);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -392,6 +392,25 @@ void ModifyKokkos::final_integrate()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
2nd half of integrate call, only for relevant fixes
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
void ModifyKokkos::fused_integrate()
|
||||
{
|
||||
for (int i = 0; i < n_final_integrate; i++) {
|
||||
atomKK->sync(fix[list_final_integrate[i]]->execution_space,
|
||||
fix[list_final_integrate[i]]->datamask_read);
|
||||
int prev_auto_sync = lmp->kokkos->auto_sync;
|
||||
if (!fix[list_final_integrate[i]]->kokkosable) lmp->kokkos->auto_sync = 1;
|
||||
fix[list_final_integrate[i]]->fused_integrate();
|
||||
lmp->kokkos->auto_sync = prev_auto_sync;
|
||||
atomKK->modified(fix[list_final_integrate[i]]->execution_space,
|
||||
fix[list_final_integrate[i]]->datamask_modify);
|
||||
}
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
end-of-timestep call, only for relevant fixes
|
||||
only call fix->end_of_step() on timesteps that are multiples of nevery
|
||||
|
||||
@ -39,6 +39,7 @@ class ModifyKokkos : public Modify {
|
||||
void pre_reverse(int,int) override;
|
||||
void post_force(int) override;
|
||||
void final_integrate() override;
|
||||
void fused_integrate() override;
|
||||
void end_of_step() override;
|
||||
double energy_couple() override;
|
||||
double energy_global() override;
|
||||
|
||||
@ -280,6 +280,12 @@ struct PairComputeFunctor {
|
||||
const X_FLOAT ztmp = c.x(i,2);
|
||||
const int itype = c.type(i);
|
||||
|
||||
Kokkos::single(Kokkos::PerThread(team), [&] (){
|
||||
f(i,0) = 0.0;
|
||||
f(i,1) = 0.0;
|
||||
f(i,2) = 0.0;
|
||||
});
|
||||
|
||||
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
|
||||
const int jnum = list.d_numneigh[i];
|
||||
|
||||
@ -337,6 +343,12 @@ struct PairComputeFunctor {
|
||||
const int itype = c.type(i);
|
||||
const F_FLOAT qtmp = c.q(i);
|
||||
|
||||
Kokkos::single(Kokkos::PerThread(team), [&] (){
|
||||
f(i,0) = 0.0;
|
||||
f(i,1) = 0.0;
|
||||
f(i,2) = 0.0;
|
||||
});
|
||||
|
||||
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
|
||||
const int jnum = list.d_numneigh[i];
|
||||
|
||||
@ -399,6 +411,12 @@ struct PairComputeFunctor {
|
||||
const X_FLOAT ztmp = c.x(i,2);
|
||||
const int itype = c.type(i);
|
||||
|
||||
Kokkos::single(Kokkos::PerThread(team), [&] (){
|
||||
f(i,0) = 0.0;
|
||||
f(i,1) = 0.0;
|
||||
f(i,2) = 0.0;
|
||||
});
|
||||
|
||||
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
|
||||
const int jnum = list.d_numneigh[i];
|
||||
|
||||
@ -495,6 +513,12 @@ struct PairComputeFunctor {
|
||||
const int itype = c.type(i);
|
||||
const F_FLOAT qtmp = c.q(i);
|
||||
|
||||
Kokkos::single(Kokkos::PerThread(team), [&] (){
|
||||
f(i,0) = 0.0;
|
||||
f(i,1) = 0.0;
|
||||
f(i,2) = 0.0;
|
||||
});
|
||||
|
||||
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
|
||||
const int jnum = list.d_numneigh[i];
|
||||
|
||||
@ -743,6 +767,8 @@ EV_FLOAT pair_compute_neighlist (PairStyle* fpair, typename std::enable_if<(NEIG
|
||||
fpair->lmp->kokkos->neigh_thread = 1;
|
||||
|
||||
if (fpair->lmp->kokkos->neigh_thread) {
|
||||
fpair->fuse_force_clear_flag = 1;
|
||||
|
||||
int vector_length = 8;
|
||||
int atoms_per_team = 32;
|
||||
|
||||
|
||||
@ -296,6 +296,8 @@ void VerletKokkos::run(int n)
|
||||
// initial time integration
|
||||
|
||||
timer->stamp();
|
||||
fuse_check(i,n);
|
||||
if (!fuse_integrate)
|
||||
modify->initial_integrate(vflag);
|
||||
if (n_post_integrate) modify->post_integrate();
|
||||
timer->stamp(Timer::MODIFY);
|
||||
@ -362,6 +364,7 @@ void VerletKokkos::run(int n)
|
||||
// since some bonded potentials tally pairwise energy/virial
|
||||
// and Pair:ev_tally() needs to be called before any tallying
|
||||
|
||||
if (!fuse_force_clear)
|
||||
force_clear();
|
||||
|
||||
timer->stamp();
|
||||
@ -494,7 +497,10 @@ void VerletKokkos::run(int n)
|
||||
// force modifications, final time integration, diagnostics
|
||||
|
||||
if (n_post_force) modify->post_force(vflag);
|
||||
modify->final_integrate();
|
||||
|
||||
if (fuse_integrate) modify->fused_integrate();
|
||||
else modify->final_integrate();
|
||||
|
||||
if (n_end_of_step) modify->end_of_step();
|
||||
timer->stamp(Timer::MODIFY);
|
||||
|
||||
@ -593,3 +599,35 @@ void VerletKokkos::force_clear()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
check if can fuse force_clear() with pair compute()
|
||||
Requirements:
|
||||
- no pre_force fixes
|
||||
- no torques, SPIN forces, or includegroup set
|
||||
- pair compute() must be called
|
||||
- pair_style must support fusing
|
||||
|
||||
check if can fuse initial_integrate() with final_integrate()
|
||||
Requirements:
|
||||
- no end_of_step fixes
|
||||
- not on first, last, or output step
|
||||
- no timers to break out of loop
|
||||
- integrate fix style must support fusing
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
void VerletKokkos::fuse_check(int i, int n)
|
||||
{
|
||||
fuse_force_clear = 0;
|
||||
if (modify->n_pre_force) fuse_force_clear = 0;
|
||||
if (torqueflag || extraflag || neighbor->includegroup) fuse_force_clear = 0;
|
||||
if (!pair_compute_flag) fuse_force_clear = 0;
|
||||
if (!force->pair->fuse_force_clear_flag) fuse_force_clear = 0;
|
||||
|
||||
fuse_integrate = 0;
|
||||
if (modify->n_end_of_step) fuse_integrate = 0;
|
||||
if (i == 0 || i == n-1) fuse_integrate = 0;
|
||||
if (update->ntimestep == output->next) fuse_integrate = 0;
|
||||
if (timer->has_timeout()) fuse_integrate = 0;
|
||||
if (!modify->check_fuse_integrate()) fuse_integrate = 0;
|
||||
}
|
||||
|
||||
@ -46,6 +46,9 @@ class VerletKokkos : public Verlet {
|
||||
|
||||
protected:
|
||||
DAT::t_f_array f_merge_copy,f;
|
||||
int fuse_force_clear,fuse_integrate;
|
||||
|
||||
void fuse_check(int, int);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -102,15 +102,15 @@ Fix::Fix(LAMMPS *lmp, int /*narg*/, char **arg) :
|
||||
vflag_atom = cvflag_atom = 0;
|
||||
centroidstressflag = CENTROID_SAME;
|
||||
|
||||
// KOKKOS per-fix data masks
|
||||
// KOKKOS package
|
||||
|
||||
execution_space = Host;
|
||||
datamask_read = ALL_MASK;
|
||||
datamask_modify = ALL_MASK;
|
||||
|
||||
kokkosable = 0;
|
||||
kokkosable = copymode = 0;
|
||||
forward_comm_device = exchange_comm_device = 0;
|
||||
copymode = 0;
|
||||
fuse_integrate_flag = 0;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
@ -127,11 +127,12 @@ class Fix : protected Pointers {
|
||||
|
||||
int restart_reset; // 1 if restart just re-initialized fix
|
||||
|
||||
// KOKKOS host/device flag and data masks
|
||||
// KOKKOS flags and variables
|
||||
|
||||
int kokkosable; // 1 if Kokkos fix
|
||||
int forward_comm_device; // 1 if forward comm on Device
|
||||
int exchange_comm_device; // 1 if exchange comm on Device
|
||||
int fuse_integrate_flag; // 1 if can fuse initial integrate with final integrate
|
||||
ExecutionSpace execution_space;
|
||||
unsigned int datamask_read, datamask_modify;
|
||||
|
||||
@ -152,6 +153,7 @@ class Fix : protected Pointers {
|
||||
virtual void setup_pre_reverse(int, int) {}
|
||||
virtual void min_setup(int) {}
|
||||
virtual void initial_integrate(int) {}
|
||||
virtual void fused_integrate() {}
|
||||
virtual void post_integrate() {}
|
||||
virtual void pre_exchange() {}
|
||||
virtual void pre_neighbor() {}
|
||||
|
||||
@ -475,6 +475,17 @@ void Modify::final_integrate()
|
||||
for (int i = 0; i < n_final_integrate; i++) fix[list_final_integrate[i]]->final_integrate();
|
||||
}
|
||||
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
2nd half of integrate call, only for relevant fixes
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
void Modify::fused_integrate()
|
||||
{
|
||||
for (int i = 0; i < n_final_integrate; i++)
|
||||
fix[list_final_integrate[i]]->fused_integrate();
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
end-of-timestep call, only for relevant fixes
|
||||
only call fix->end_of_step() on timesteps that are multiples of nevery
|
||||
@ -1799,3 +1810,22 @@ double Modify::memory_usage()
|
||||
for (int i = 0; i < ncompute; i++) bytes += compute[i]->memory_usage();
|
||||
return bytes;
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
check if initial and final integrate can be fused
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
int Modify::check_fuse_integrate()
|
||||
{
|
||||
int fuse_integrate_flag = 1;
|
||||
|
||||
for (int i = 0; i < n_initial_integrate; i++)
|
||||
if (!fix[list_initial_integrate[i]]->fuse_integrate_flag)
|
||||
fuse_integrate_flag = 0;
|
||||
|
||||
for (int i = 0; i < n_final_integrate; i++)
|
||||
if (!fix[list_final_integrate[i]]->fuse_integrate_flag)
|
||||
fuse_integrate_flag = 0;
|
||||
|
||||
return fuse_integrate_flag;
|
||||
}
|
||||
|
||||
@ -61,6 +61,7 @@ class Modify : protected Pointers {
|
||||
virtual void setup_pre_force(int);
|
||||
virtual void setup_pre_reverse(int, int);
|
||||
virtual void initial_integrate(int);
|
||||
virtual void fused_integrate();
|
||||
virtual void post_integrate();
|
||||
virtual void pre_exchange();
|
||||
virtual void pre_neighbor();
|
||||
@ -150,6 +151,8 @@ class Modify : protected Pointers {
|
||||
|
||||
double memory_usage();
|
||||
|
||||
int check_fuse_integrate();
|
||||
|
||||
protected:
|
||||
// internal fix counts
|
||||
|
||||
|
||||
@ -116,15 +116,14 @@ Pair::Pair(LAMMPS *lmp) :
|
||||
nondefault_history_transfer = 0;
|
||||
beyond_contact = 0;
|
||||
|
||||
// KOKKOS per-fix data masks
|
||||
// KOKKOS package
|
||||
|
||||
execution_space = Host;
|
||||
datamask_read = ALL_MASK;
|
||||
datamask_modify = ALL_MASK;
|
||||
|
||||
kokkosable = 0;
|
||||
reverse_comm_device = 0;
|
||||
copymode = 0;
|
||||
kokkosable = copymode = 0;
|
||||
reverse_comm_device = fuse_force_clear_flag = 0;
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
@ -119,12 +119,13 @@ class Pair : protected Pointers {
|
||||
|
||||
int beyond_contact, nondefault_history_transfer; // for granular styles
|
||||
|
||||
// KOKKOS host/device flag and data masks
|
||||
// KOKKOS flags and variables
|
||||
|
||||
ExecutionSpace execution_space;
|
||||
unsigned int datamask_read, datamask_modify;
|
||||
int kokkosable; // 1 if Kokkos pair
|
||||
int reverse_comm_device; // 1 if reverse comm on Device
|
||||
int fuse_force_clear_flag; // 1 if can fuse force clear with force compute
|
||||
|
||||
Pair(class LAMMPS *);
|
||||
~Pair() override;
|
||||
|
||||
@ -63,6 +63,7 @@ class Timer : protected Pointers {
|
||||
bool has_normal() const { return (_level >= NORMAL); }
|
||||
bool has_full() const { return (_level >= FULL); }
|
||||
bool has_sync() const { return (_sync != OFF); }
|
||||
bool has_timeout() const { return (_timeout >= 0.0); }
|
||||
|
||||
// flag if wallclock time is expired
|
||||
bool is_timeout() const { return (_timeout == 0.0); }
|
||||
|
||||
Reference in New Issue
Block a user