Merge pull request #3758 from stanmoore1/kk_fuse

Fuse some Kokkos kernels to reduce launch latency for small systems
This commit is contained in:
Axel Kohlmeyer
2023-05-02 20:36:48 -04:00
committed by GitHub
17 changed files with 306 additions and 13 deletions

View File

@ -44,6 +44,7 @@ FixLangevinKokkos<DeviceType>::FixLangevinKokkos(LAMMPS *lmp, int narg, char **a
FixLangevin(lmp, narg, arg),rand_pool(seed + comm->me)
{
kokkosable = 1;
fuse_integrate_flag = 1;
sort_device = 1;
atomKK = (AtomKokkos *) atom;
int ntypes = atomKK->ntypes;
@ -170,6 +171,14 @@ void FixLangevinKokkos<DeviceType>::initial_integrate_item(int i) const
/* ---------------------------------------------------------------------- */
// Fused-integrate entry point for this fix: forwards to initial_integrate()
// with the same vflag and does nothing else.
// NOTE(review): presumably this fix has no final_integrate() half to merge
// with -- confirm against the base class.
template<class DeviceType>
void FixLangevinKokkos<DeviceType>::fused_integrate(int vflag)
{
  this->initial_integrate(vflag);
}
/* ---------------------------------------------------------------------- */
template<class DeviceType>
void FixLangevinKokkos<DeviceType>::post_force(int /*vflag*/)
{

View File

@ -70,6 +70,7 @@ namespace LAMMPS_NS {
void cleanup_copy();
void init() override;
void initial_integrate(int) override;
void fused_integrate(int) override;
void post_force(int) override;
void reset_dt() override;
void grow_arrays(int) override;

View File

@ -29,6 +29,7 @@ FixNVEKokkos<DeviceType>::FixNVEKokkos(LAMMPS *lmp, int narg, char **arg) :
FixNVE(lmp, narg, arg)
{
kokkosable = 1;
fuse_integrate_flag = 1;
atomKK = (AtomKokkos *) atom;
execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
@ -159,6 +160,66 @@ void FixNVEKokkos<DeviceType>::final_integrate_rmass_item(int i) const
}
}
/* ----------------------------------------------------------------------
   fused velocity + position update, launched as one parallel_for
   allow for both per-type and per-atom mass
------------------------------------------------------------------------- */

template<class DeviceType>
void FixNVEKokkos<DeviceType>::fused_integrate(int /*vflag*/)
{
  // sync the fields selected by datamask_read for this execution space
  atomKK->sync(execution_space,datamask_read);

  // refresh the views that the per-atom kernel bodies read and write
  mask = atomKK->k_mask.view<DeviceType>();
  type = atomKK->k_type.view<DeviceType>();
  mass = atomKK->k_mass.view<DeviceType>();
  rmass = atomKK->k_rmass.view<DeviceType>();
  f = atomKK->k_f.view<DeviceType>();
  v = atomKK->k_v.view<DeviceType>();
  x = atomKK->k_x.view<DeviceType>();

  // loop over nfirst atoms when this fix's group is the atom "first" group,
  // otherwise over all nlocal atoms
  int ncount = (igroup == atomKK->firstgroup) ? atomKK->nfirst : atomKK->nlocal;

  // a null rmass view selects the per-type mass kernel (RMass = 0),
  // a non-null one selects the per-atom mass kernel (RMass = 1)
  if (!rmass.data()) {
    FixNVEKokkosFusedIntegrateFunctor<DeviceType,0> per_type_kernel(this);
    Kokkos::parallel_for(ncount,per_type_kernel);
  } else {
    FixNVEKokkosFusedIntegrateFunctor<DeviceType,1> per_atom_kernel(this);
    Kokkos::parallel_for(ncount,per_atom_kernel);
  }

  // flag the fields selected by datamask_modify as modified in this execution space
  atomKK->modified(execution_space,datamask_modify);
}
// Kernel body for atom i using the per-type mass table mass[type[i]].
// Updates v from f, then x from the updated v; atoms outside the fix group
// are left untouched.
// NOTE(review): the factor 2.0 presumably merges the final_integrate velocity
// update with the next initial_integrate one -- confirm against those kernels.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixNVEKokkos<DeviceType>::fused_integrate_item(int i) const
{
  if (!(mask[i] & groupbit)) return;

  const double kick = 2.0 * dtf / mass[type[i]];

  // all three velocity components first, then positions from the new velocities
  for (int d = 0; d < 3; d++) v(i,d) += kick * f(i,d);
  for (int d = 0; d < 3; d++) x(i,d) += dtv * v(i,d);
}
// Kernel body for atom i using the per-atom mass rmass[i].
// Updates v from f, then x from the updated v; atoms outside the fix group
// are left untouched.
// NOTE(review): the factor 2.0 presumably merges the final_integrate velocity
// update with the next initial_integrate one -- confirm against those kernels.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixNVEKokkos<DeviceType>::fused_integrate_rmass_item(int i) const
{
  if (!(mask[i] & groupbit)) return;

  const double kick = 2.0 * dtf / rmass[i];

  // all three velocity components first, then positions from the new velocities
  for (int d = 0; d < 3; d++) v(i,d) += kick * f(i,d);
  for (int d = 0; d < 3; d++) x(i,d) += dtv * v(i,d);
}
/* ---------------------------------------------------------------------- */
template<class DeviceType>
@ -168,6 +229,8 @@ void FixNVEKokkos<DeviceType>::cleanup_copy()
vatom = nullptr;
}
/* ---------------------------------------------------------------------- */
namespace LAMMPS_NS {
template class FixNVEKokkos<LMPDeviceType>;
#ifdef LMP_KOKKOS_GPU

View File

@ -46,6 +46,7 @@ class FixNVEKokkos : public FixNVE {
void init() override;
void initial_integrate(int) override;
void final_integrate() override;
void fused_integrate(int) override;
KOKKOS_INLINE_FUNCTION
void initial_integrate_item(int) const;
@ -55,6 +56,10 @@ class FixNVEKokkos : public FixNVE {
void final_integrate_item(int) const;
KOKKOS_INLINE_FUNCTION
void final_integrate_rmass_item(int) const;
KOKKOS_INLINE_FUNCTION
void fused_integrate_item(int) const;
KOKKOS_INLINE_FUNCTION
void fused_integrate_rmass_item(int) const;
private:
@ -96,6 +101,22 @@ struct FixNVEKokkosFinalIntegrateFunctor {
}
};
// Functor handed to Kokkos::parallel_for by FixNVEKokkos::fused_integrate().
// It holds its own copy of the fix (cleanup_copy() is called on that copy)
// and dispatches each index to the per-atom-mass kernel body when RMass is
// non-zero, or to the per-type-mass kernel body when RMass is zero.
template <class DeviceType, int RMass>
struct FixNVEKokkosFusedIntegrateFunctor {
  typedef DeviceType device_type;
  FixNVEKokkos<DeviceType> c;

  FixNVEKokkosFusedIntegrateFunctor(FixNVEKokkos<DeviceType>* fix_ptr) : c(*fix_ptr)
  {
    c.cleanup_copy();
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const int i) const
  {
    if (!RMass) {
      c.fused_integrate_item(i);
    } else {
      c.fused_integrate_rmass_item(i);
    }
  }
};
}
#endif

View File

@ -28,6 +28,7 @@ FixNVESphereKokkos<DeviceType>::FixNVESphereKokkos(LAMMPS *lmp, int narg, char *
FixNVESphere(lmp, narg, arg)
{
kokkosable = 1;
fuse_integrate_flag = 1;
atomKK = (AtomKokkos *)atom;
execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
@ -164,6 +165,73 @@ void FixNVESphereKokkos<DeviceType>::final_integrate_item(const int i) const
}
}
/* ---------------------------------------------------------------------- */
// Host-side driver for the fused sphere integrator: syncs the fields the
// kernel reads, refreshes the views, launches one parallel_for over the
// local atoms and flags the written fields as modified. The mu (dipole)
// field is added to both masks only when extra == DIPOLE.
template<class DeviceType>
void FixNVESphereKokkos<DeviceType>::fused_integrate(int /*vflag*/)
{
  const bool with_dipole = (extra == DIPOLE);

  unsigned int read_mask = X_MASK | V_MASK | OMEGA_MASK| F_MASK | TORQUE_MASK | RMASS_MASK | RADIUS_MASK | MASK_MASK;
  if (with_dipole) read_mask |= MU_MASK;
  atomKK->sync(execution_space, read_mask);

  // refresh the views that the per-atom kernel body reads and writes
  mask = atomKK->k_mask.view<DeviceType>();
  rmass = atomKK->k_rmass.view<DeviceType>();
  radius = atomKK->k_radius.view<DeviceType>();
  f = atomKK->k_f.view<DeviceType>();
  torque = atomKK->k_torque.view<DeviceType>();
  x = atomKK->k_x.view<DeviceType>();
  v = atomKK->k_v.view<DeviceType>();
  omega = atomKK->k_omega.view<DeviceType>();
  mu = atomKK->k_mu.view<DeviceType>();

  // loop over nfirst atoms when this fix's group is the atom "first" group,
  // otherwise over all nlocal atoms
  int ncount = (igroup == atom->firstgroup) ? atom->nfirst : atom->nlocal;

  // named 'functor' so the local does not hide the force view member 'f'
  FixNVESphereKokkosFusedIntegrateFunctor<DeviceType> functor(this);
  Kokkos::parallel_for(ncount,functor);

  unsigned int write_mask = X_MASK | V_MASK | OMEGA_MASK;
  if (with_dipole) write_mask |= MU_MASK;
  atomKK->modified(execution_space, write_mask);
}
/* ---------------------------------------------------------------------- */
// Kernel body for atom i. For atoms in the fix group it updates v from f,
// x from the updated v, and omega from torque; when extra == DIPOLE it also
// advances mu using the updated omega and rescales it to the length stored
// in mu(i,3). Atoms outside the group are left untouched.
// NOTE(review): the factor 2.0 on both updates presumably merges the
// final_integrate step with the next initial_integrate step -- confirm
// against those kernels.
template <class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixNVESphereKokkos<DeviceType>::fused_integrate_item(const int i) const
{
  if (!(mask(i) & groupbit)) return;

  // translational part: velocities first, then positions from the new velocities
  const double kick = 2.0 * dtf / rmass(i);
  for (int d = 0; d < 3; d++) v(i,d) += kick * f(i,d);
  for (int d = 0; d < 3; d++) x(i,d) += dtv * v(i,d);

  // rotational part
  const double dtfrotate = dtf / inertia;
  const double spin_kick = 2.0 * dtfrotate / (radius(i)*radius(i)*rmass(i));
  for (int d = 0; d < 3; d++) omega(i,d) += spin_kick * torque(i,d);

  if (extra != DIPOLE) return;

  // all three components are computed from the old mu before any is overwritten
  double g[3];
  g[0] = mu(i,0) + dtv * (omega(i,1) * mu(i,2) - omega(i,2) * mu(i,1));
  g[1] = mu(i,1) + dtv * (omega(i,2) * mu(i,0) - omega(i,0) * mu(i,2));
  g[2] = mu(i,2) + dtv * (omega(i,0) * mu(i,1) - omega(i,1) * mu(i,0));

  // rescale so the vector length equals mu(i,3)
  const double msq = g[0]*g[0] + g[1]*g[1] + g[2]*g[2];
  const double scale = mu(i,3)/sqrt(msq);
  for (int d = 0; d < 3; d++) mu(i,d) = g[d]*scale;
}
namespace LAMMPS_NS {
template class FixNVESphereKokkos<LMPDeviceType>;
#ifdef LMP_KOKKOS_GPU

View File

@ -37,11 +37,14 @@ class FixNVESphereKokkos : public FixNVESphere {
void init() override;
void initial_integrate(int) override;
void final_integrate() override;
void fused_integrate(int) override;
KOKKOS_INLINE_FUNCTION
void initial_integrate_item(const int i) const;
KOKKOS_INLINE_FUNCTION
void final_integrate_item(const int i) const;
KOKKOS_INLINE_FUNCTION
void fused_integrate_item(int) const;
private:
typename ArrayTypes<DeviceType>::t_x_array x;
@ -77,6 +80,17 @@ struct FixNVESphereKokkosFinalIntegrateFunctor {
}
};
// Functor handed to Kokkos::parallel_for by FixNVESphereKokkos::fused_integrate().
// It holds its own copy of the fix (cleanup_copy() is called on that copy)
// and forwards each index to fused_integrate_item().
template <class DeviceType>
struct FixNVESphereKokkosFusedIntegrateFunctor {
  typedef DeviceType device_type;
  FixNVESphereKokkos<DeviceType> c;

  FixNVESphereKokkosFusedIntegrateFunctor(FixNVESphereKokkos<DeviceType> *fix_ptr) : c(*fix_ptr)
  {
    c.cleanup_copy();
  }

  KOKKOS_INLINE_FUNCTION
  void operator()(const int i) const
  {
    c.fused_integrate_item(i);
  }
};
} // namespace LAMMPS_NS
#endif // LMP_FIX_NVE_SPHERE_KOKKOS_H

View File

@ -392,6 +392,24 @@ void ModifyKokkos::final_integrate()
}
}
/* ----------------------------------------------------------------------
fused initial and final integrate call, only for relevant fixes
------------------------------------------------------------------------- */
void ModifyKokkos::fused_integrate(int vflag)
{
for (int i = 0; i < n_final_integrate; i++) {
atomKK->sync(fix[list_final_integrate[i]]->execution_space,
fix[list_final_integrate[i]]->datamask_read);
int prev_auto_sync = lmp->kokkos->auto_sync;
if (!fix[list_final_integrate[i]]->kokkosable) lmp->kokkos->auto_sync = 1;
fix[list_final_integrate[i]]->fused_integrate(vflag);
lmp->kokkos->auto_sync = prev_auto_sync;
atomKK->modified(fix[list_final_integrate[i]]->execution_space,
fix[list_final_integrate[i]]->datamask_modify);
}
}
/* ----------------------------------------------------------------------
end-of-timestep call, only for relevant fixes
only call fix->end_of_step() on timesteps that are multiples of nevery
@ -881,3 +899,22 @@ int ModifyKokkos::min_reset_ref()
}
return itmpall;
}
/* ----------------------------------------------------------------------
   check if initial and final integrate can be fused
   return 1 only if every fix in both the initial_integrate and the
   final_integrate lists has fuse_integrate_flag set, else return 0
------------------------------------------------------------------------- */

int ModifyKokkos::check_fuse_integrate()
{
  for (int i = 0; i < n_initial_integrate; i++) {
    if (!fix[list_initial_integrate[i]]->fuse_integrate_flag) return 0;
  }

  for (int i = 0; i < n_final_integrate; i++) {
    if (!fix[list_final_integrate[i]]->fuse_integrate_flag) return 0;
  }

  return 1;
}

View File

@ -39,6 +39,7 @@ class ModifyKokkos : public Modify {
void pre_reverse(int,int) override;
void post_force(int) override;
void final_integrate() override;
void fused_integrate(int) override;
void end_of_step() override;
double energy_couple() override;
double energy_global() override;
@ -69,6 +70,8 @@ class ModifyKokkos : public Modify {
int min_dof() override;
int min_reset_ref() override;
int check_fuse_integrate();
protected:
};

View File

@ -280,6 +280,12 @@ struct PairComputeFunctor {
const X_FLOAT ztmp = c.x(i,2);
const int itype = c.type(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i];
@ -337,6 +343,12 @@ struct PairComputeFunctor {
const int itype = c.type(i);
const F_FLOAT qtmp = c.q(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i];
@ -399,6 +411,12 @@ struct PairComputeFunctor {
const X_FLOAT ztmp = c.x(i,2);
const int itype = c.type(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i];
@ -495,6 +513,12 @@ struct PairComputeFunctor {
const int itype = c.type(i);
const F_FLOAT qtmp = c.q(i);
Kokkos::single(Kokkos::PerThread(team), [&] (){
f(i,0) = 0.0;
f(i,1) = 0.0;
f(i,2) = 0.0;
});
const AtomNeighborsConst neighbors_i = list.get_neighbors_const(i);
const int jnum = list.d_numneigh[i];
@ -743,6 +767,8 @@ EV_FLOAT pair_compute_neighlist (PairStyle* fpair, typename std::enable_if<(NEIG
fpair->lmp->kokkos->neigh_thread = 1;
if (fpair->lmp->kokkos->neigh_thread) {
fpair->fuse_force_clear_flag = 1;
int vector_length = 8;
int atoms_per_team = 32;

View File

@ -27,7 +27,7 @@
#include "kspace.h"
#include "output.h"
#include "update.h"
#include "modify.h"
#include "modify_kokkos.h"
#include "timer.h"
#include "memory_kokkos.h"
#include "kokkos.h"
@ -276,6 +276,9 @@ void VerletKokkos::run(int n)
lmp->kokkos->auto_sync = 0;
fuse_integrate = 0;
fuse_force_clear = 0;
if (atomKK->sortfreq > 0) sortflag = 1;
else sortflag = 0;
@ -296,7 +299,8 @@ void VerletKokkos::run(int n)
// initial time integration
timer->stamp();
modify->initial_integrate(vflag);
if (!fuse_integrate)
modify->initial_integrate(vflag);
if (n_post_integrate) modify->post_integrate();
timer->stamp(Timer::MODIFY);
@ -357,12 +361,17 @@ void VerletKokkos::run(int n)
}
}
// check if kernels can be fused, must come after initial_integrate
fuse_check(i,n);
// force computations
// important for pair to come before bonded contributions
// since some bonded potentials tally pairwise energy/virial
// and Pair:ev_tally() needs to be called before any tallying
force_clear();
if (!fuse_force_clear)
force_clear();
timer->stamp();
@ -494,7 +503,10 @@ void VerletKokkos::run(int n)
// force modifications, final time integration, diagnostics
if (n_post_force) modify->post_force(vflag);
modify->final_integrate();
if (fuse_integrate) modify->fused_integrate(vflag);
else modify->final_integrate();
if (n_end_of_step) modify->end_of_step();
timer->stamp(Timer::MODIFY);
@ -593,3 +605,35 @@ void VerletKokkos::force_clear()
}
}
}
/* ----------------------------------------------------------------------
check if can fuse force_clear() with pair compute()
Requirements:
- no pre_force fixes
- no torques, SPIN forces, or includegroup set
- pair compute() must be called
- pair_style must support fusing
check if can fuse initial_integrate() with final_integrate()
Requirements:
- no end_of_step fixes
- not on last or output step
- no timers to break out of loop
- integrate fix style must support fusing
------------------------------------------------------------------------- */
void VerletKokkos::fuse_check(int i, int n)
{
fuse_force_clear = 1;
if (modify->n_pre_force) fuse_force_clear = 0;
else if (torqueflag || extraflag || neighbor->includegroup) fuse_force_clear = 0;
else if (!force->pair || !pair_compute_flag) fuse_force_clear = 0;
else if (!force->pair->fuse_force_clear_flag) fuse_force_clear = 0;
fuse_integrate = 1;
if (modify->n_end_of_step) fuse_integrate = 0;
else if (i == n-1) fuse_integrate = 0;
else if (update->ntimestep == output->next) fuse_integrate = 0;
else if (timer->has_timeout()) fuse_integrate = 0;
else if (!((ModifyKokkos*)modify)->check_fuse_integrate()) fuse_integrate = 0;
}

View File

@ -46,6 +46,9 @@ class VerletKokkos : public Verlet {
protected:
DAT::t_f_array f_merge_copy,f;
int fuse_force_clear,fuse_integrate;
void fuse_check(int, int);
};
}

View File

@ -102,15 +102,15 @@ Fix::Fix(LAMMPS *lmp, int /*narg*/, char **arg) :
vflag_atom = cvflag_atom = 0;
centroidstressflag = CENTROID_SAME;
// KOKKOS per-fix data masks
// KOKKOS package
execution_space = Host;
datamask_read = ALL_MASK;
datamask_modify = ALL_MASK;
kokkosable = 0;
kokkosable = copymode = 0;
forward_comm_device = exchange_comm_device = sort_device = 0;
copymode = 0;
fuse_integrate_flag = 0;
}
/* ---------------------------------------------------------------------- */

View File

@ -127,11 +127,12 @@ class Fix : protected Pointers {
int restart_reset; // 1 if restart just re-initialized fix
// KOKKOS host/device flag and data masks
// KOKKOS flags and variables
int kokkosable; // 1 if Kokkos fix
int forward_comm_device; // 1 if forward comm on Device
int exchange_comm_device; // 1 if exchange comm on Device
int fuse_integrate_flag; // 1 if can fuse initial integrate with final integrate
int sort_device; // 1 if sort on Device
ExecutionSpace execution_space;
unsigned int datamask_read, datamask_modify;
@ -161,6 +162,7 @@ class Fix : protected Pointers {
virtual void pre_reverse(int, int) {}
virtual void post_force(int) {}
virtual void final_integrate() {}
virtual void fused_integrate(int) {}
virtual void end_of_step() {}
virtual void post_run() {}
virtual void write_restart(FILE *) {}

View File

@ -69,6 +69,7 @@ class Modify : protected Pointers {
virtual void pre_reverse(int, int);
virtual void post_force(int);
virtual void final_integrate();
virtual void fused_integrate(int) {}
virtual void end_of_step();
virtual double energy_couple();
virtual double energy_global();

View File

@ -116,15 +116,14 @@ Pair::Pair(LAMMPS *lmp) :
nondefault_history_transfer = 0;
beyond_contact = 0;
// KOKKOS per-fix data masks
// KOKKOS package
execution_space = Host;
datamask_read = ALL_MASK;
datamask_modify = ALL_MASK;
kokkosable = 0;
reverse_comm_device = 0;
copymode = 0;
kokkosable = copymode = 0;
reverse_comm_device = fuse_force_clear_flag = 0;
}
/* ---------------------------------------------------------------------- */

View File

@ -119,12 +119,13 @@ class Pair : protected Pointers {
int beyond_contact, nondefault_history_transfer; // for granular styles
// KOKKOS host/device flag and data masks
// KOKKOS flags and variables
ExecutionSpace execution_space;
unsigned int datamask_read, datamask_modify;
int kokkosable; // 1 if Kokkos pair
int reverse_comm_device; // 1 if reverse comm on Device
int fuse_force_clear_flag; // 1 if can fuse force clear with force compute
Pair(class LAMMPS *);
~Pair() override;

View File

@ -63,6 +63,7 @@ class Timer : protected Pointers {
bool has_normal() const { return (_level >= NORMAL); }
bool has_full() const { return (_level >= FULL); }
bool has_sync() const { return (_sync != OFF); }
bool has_timeout() const { return (_timeout >= 0.0); }
// flag if wallclock time is expired
bool is_timeout() const { return (_timeout == 0.0); }