Fixing CUDA runtime issues in fix_shardlow_kokkos

This commit is contained in:
Stan Moore
2017-06-09 09:31:37 -06:00
parent 86497949f2
commit c51cadcc6c
2 changed files with 51 additions and 32 deletions

View File

@ -436,7 +436,7 @@ template<bool STACKPARAMS>
KOKKOS_INLINE_FUNCTION
void FixShardlowKokkos<DeviceType>::ssa_update_dpde(
int start_ii, int count, int id
)
) const
{
#ifdef DPD_USE_RAN_MARS
class RanMars *pRNG = pp_random[id];
@ -682,26 +682,18 @@ void FixShardlowKokkos<DeviceType>::initial_integrate(int vflag)
dt = update->dt;
// process neighbors in the local AIR
for (int workPhase = 0; workPhase < ssa_phaseCt; ++workPhase) {
for (workPhase = 0; workPhase < ssa_phaseCt; ++workPhase) {
int workItemCt = h_ssa_phaseLen[workPhase];
if(atom->ntypes > MAX_TYPES_STACKPARAMS) {
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
int ct = ssa_itemLen(workPhase, workItem);
int ii = ssa_itemLoc(workPhase, workItem);
ssa_update_dpde<false>(ii, ct, workItem);
});
} else {
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
int ct = ssa_itemLen(workPhase, workItem);
int ii = ssa_itemLoc(workPhase, workItem);
ssa_update_dpde<true>(ii, ct, workItem);
});
}
if(atom->ntypes > MAX_TYPES_STACKPARAMS)
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDE<false> >(0,workItemCt),*this);
else
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDE<true> >(0,workItemCt),*this);
}
//Loop over all 13 outward directions (7 stages)
for (int workPhase = 0; workPhase < ssa_gphaseCt; ++workPhase) {
for (workPhase = 0; workPhase < ssa_gphaseCt; ++workPhase) {
// int airnum = workPhase + 1;
int workItemCt = h_ssa_gphaseLen[workPhase];
@ -713,27 +705,21 @@ void FixShardlowKokkos<DeviceType>::initial_integrate(int vflag)
// memset(&(atom->uCond[nlocal]), 0, sizeof(double)*nghost);
// memset(&(atom->uMech[nlocal]), 0, sizeof(double)*nghost);
// must capture local variables, not class variables
auto l_uCond = uCond;
auto l_uMech = uMech;
Kokkos::parallel_for(Kokkos::RangePolicy<LMPDeviceType>(nlocal,nlocal+nghost), LAMMPS_LAMBDA (const int i) {
uCond(i) = 0.0;
uMech(i) = 0.0;
l_uCond(i) = 0.0;
l_uMech(i) = 0.0;
});
DeviceType::fence();
}
// process neighbors in this AIR
if(atom->ntypes > MAX_TYPES_STACKPARAMS) {
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
int ct = ssa_gitemLen(workPhase, workItem);
int ii = ssa_gitemLoc(workPhase, workItem);
ssa_update_dpde<false>(ii, ct, workItem);
});
} else {
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
int ct = ssa_gitemLen(workPhase, workItem);
int ii = ssa_gitemLoc(workPhase, workItem);
ssa_update_dpde<true>(ii, ct, workItem);
});
}
if(atom->ntypes > MAX_TYPES_STACKPARAMS)
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDEGhost<false> >(0,workItemCt),*this);
else
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDEGhost<true> >(0,workItemCt),*this);
// Communicate the ghost deltas to the atom owners
comm->reverse_comm_fix(this);
@ -755,6 +741,24 @@ fprintf(stdout, "\n%6d %6d,%6d %6d: "
copymode = 0;
}
// Tagged functor entry point for the TagFixShardlowSSAUpdateDPDE range.
// For the given workItem of the current workPhase, reads the item's start
// location and length from ssa_itemLoc / ssa_itemLen and forwards them,
// together with workItem itself, to ssa_update_dpde<STACKPARAMS>.
template<class DeviceType>
template<bool STACKPARAMS>
KOKKOS_INLINE_FUNCTION
void FixShardlowKokkos<DeviceType>::operator()(TagFixShardlowSSAUpdateDPDE<STACKPARAMS>, const int &workItem) const
{
  // workPhase is a class member; it is read here rather than passed in
  const int itemStart = ssa_itemLoc(workPhase, workItem);
  const int itemCount = ssa_itemLen(workPhase, workItem);

  ssa_update_dpde<STACKPARAMS>(itemStart, itemCount, workItem);
}
// Tagged functor entry point for the TagFixShardlowSSAUpdateDPDEGhost range.
// Same shape as the non-ghost overload, but the start location and length
// of the work item come from ssa_gitemLoc / ssa_gitemLen instead.
template<class DeviceType>
template<bool STACKPARAMS>
KOKKOS_INLINE_FUNCTION
void FixShardlowKokkos<DeviceType>::operator()(TagFixShardlowSSAUpdateDPDEGhost<STACKPARAMS>, const int &workItem) const
{
  // workPhase is a class member; it is read here rather than passed in
  const int ghostStart = ssa_gitemLoc(workPhase, workItem);
  const int ghostCount = ssa_gitemLen(workPhase, workItem);

  ssa_update_dpde<STACKPARAMS>(ghostStart, ghostCount, workItem);
}
/* ---------------------------------------------------------------------- */
template<class DeviceType>

View File

@ -30,6 +30,12 @@ FixStyle(shardlow/kk/host,FixShardlowKokkos<LMPHostType>)
namespace LAMMPS_NS {
// Empty tag type. Used as the work tag of a Kokkos::RangePolicy so that
// parallel_for invokes the FixShardlowKokkos::operator() overload taking
// this tag; STACKPARAMS is forwarded by that overload to ssa_update_dpde.
template <bool STACKPARAMS>
struct TagFixShardlowSSAUpdateDPDE
{
};
// Empty tag type. Used as the work tag of a Kokkos::RangePolicy so that
// parallel_for invokes the FixShardlowKokkos::operator() overload taking
// this tag (the one reading ssa_gitemLen / ssa_gitemLoc); STACKPARAMS is
// forwarded by that overload to ssa_update_dpde.
template <bool STACKPARAMS>
struct TagFixShardlowSSAUpdateDPDEGhost
{
};
template<class DeviceType>
class FixShardlowKokkos : public FixShardlow {
public:
@ -60,6 +66,14 @@ class FixShardlowKokkos : public FixShardlow {
F_FLOAT cutinv,halfsigma,kappa,alpha;
};
template<bool STACKPARAMS>
KOKKOS_INLINE_FUNCTION
void operator()(TagFixShardlowSSAUpdateDPDE<STACKPARAMS>, const int&) const;
template<bool STACKPARAMS>
KOKKOS_INLINE_FUNCTION
void operator()(TagFixShardlowSSAUpdateDPDEGhost<STACKPARAMS>, const int&) const;
#ifdef DEBUG_PAIR_CT
typename AT::t_int_2d d_counters;
typename HAT::t_int_2d h_counters;
@ -68,6 +82,7 @@ class FixShardlowKokkos : public FixShardlow {
#endif
protected:
int workPhase;
double boltz_inv,ftm2v,dt;
// class PairDPDfdt *pairDPD;
@ -127,7 +142,7 @@ class FixShardlowKokkos : public FixShardlow {
// void ssa_update_dpd(int, int); // Constant Temperature
template<bool STACKPARAMS>
KOKKOS_INLINE_FUNCTION
void ssa_update_dpde(int, int, int); // Constant Energy
void ssa_update_dpde(int, int, int) const; // Constant Energy
};