Fixing CUDA runtime issues in fix_shardlow_kokkos
This commit is contained in:
@ -436,7 +436,7 @@ template<bool STACKPARAMS>
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void FixShardlowKokkos<DeviceType>::ssa_update_dpde(
|
||||
int start_ii, int count, int id
|
||||
)
|
||||
) const
|
||||
{
|
||||
#ifdef DPD_USE_RAN_MARS
|
||||
class RanMars *pRNG = pp_random[id];
|
||||
@ -682,26 +682,18 @@ void FixShardlowKokkos<DeviceType>::initial_integrate(int vflag)
|
||||
dt = update->dt;
|
||||
|
||||
// process neighbors in the local AIR
|
||||
for (int workPhase = 0; workPhase < ssa_phaseCt; ++workPhase) {
|
||||
for (workPhase = 0; workPhase < ssa_phaseCt; ++workPhase) {
|
||||
int workItemCt = h_ssa_phaseLen[workPhase];
|
||||
|
||||
if(atom->ntypes > MAX_TYPES_STACKPARAMS) {
|
||||
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
|
||||
int ct = ssa_itemLen(workPhase, workItem);
|
||||
int ii = ssa_itemLoc(workPhase, workItem);
|
||||
ssa_update_dpde<false>(ii, ct, workItem);
|
||||
});
|
||||
} else {
|
||||
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
|
||||
int ct = ssa_itemLen(workPhase, workItem);
|
||||
int ii = ssa_itemLoc(workPhase, workItem);
|
||||
ssa_update_dpde<true>(ii, ct, workItem);
|
||||
});
|
||||
}
|
||||
|
||||
if(atom->ntypes > MAX_TYPES_STACKPARAMS)
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDE<false> >(0,workItemCt),*this);
|
||||
else
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDE<true> >(0,workItemCt),*this);
|
||||
}
|
||||
|
||||
//Loop over all 13 outward directions (7 stages)
|
||||
for (int workPhase = 0; workPhase < ssa_gphaseCt; ++workPhase) {
|
||||
for (workPhase = 0; workPhase < ssa_gphaseCt; ++workPhase) {
|
||||
// int airnum = workPhase + 1;
|
||||
int workItemCt = h_ssa_gphaseLen[workPhase];
|
||||
|
||||
@ -713,27 +705,21 @@ void FixShardlowKokkos<DeviceType>::initial_integrate(int vflag)
|
||||
// memset(&(atom->uCond[nlocal]), 0, sizeof(double)*nghost);
|
||||
// memset(&(atom->uMech[nlocal]), 0, sizeof(double)*nghost);
|
||||
|
||||
// must capture local variables, not class variables
|
||||
auto l_uCond = uCond;
|
||||
auto l_uMech = uMech;
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<LMPDeviceType>(nlocal,nlocal+nghost), LAMMPS_LAMBDA (const int i) {
|
||||
uCond(i) = 0.0;
|
||||
uMech(i) = 0.0;
|
||||
l_uCond(i) = 0.0;
|
||||
l_uMech(i) = 0.0;
|
||||
});
|
||||
DeviceType::fence();
|
||||
}
|
||||
|
||||
// process neighbors in this AIR
|
||||
if(atom->ntypes > MAX_TYPES_STACKPARAMS) {
|
||||
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
|
||||
int ct = ssa_gitemLen(workPhase, workItem);
|
||||
int ii = ssa_gitemLoc(workPhase, workItem);
|
||||
ssa_update_dpde<false>(ii, ct, workItem);
|
||||
});
|
||||
} else {
|
||||
Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) {
|
||||
int ct = ssa_gitemLen(workPhase, workItem);
|
||||
int ii = ssa_gitemLoc(workPhase, workItem);
|
||||
ssa_update_dpde<true>(ii, ct, workItem);
|
||||
});
|
||||
}
|
||||
if(atom->ntypes > MAX_TYPES_STACKPARAMS)
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDEGhost<false> >(0,workItemCt),*this);
|
||||
else
|
||||
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagFixShardlowSSAUpdateDPDEGhost<true> >(0,workItemCt),*this);
|
||||
|
||||
// Communicate the ghost deltas to the atom owners
|
||||
comm->reverse_comm_fix(this);
|
||||
@ -755,6 +741,24 @@ fprintf(stdout, "\n%6d %6d,%6d %6d: "
|
||||
copymode = 0;
|
||||
}
|
||||
|
||||
template<class DeviceType>
|
||||
template<bool STACKPARAMS>
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void FixShardlowKokkos<DeviceType>::operator()(TagFixShardlowSSAUpdateDPDE<STACKPARAMS>, const int &workItem) const {
|
||||
const int ct = ssa_itemLen(workPhase, workItem);
|
||||
const int ii = ssa_itemLoc(workPhase, workItem);
|
||||
ssa_update_dpde<STACKPARAMS>(ii, ct, workItem);
|
||||
}
|
||||
|
||||
template<class DeviceType>
|
||||
template<bool STACKPARAMS>
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void FixShardlowKokkos<DeviceType>::operator()(TagFixShardlowSSAUpdateDPDEGhost<STACKPARAMS>, const int &workItem) const {
|
||||
const int ct = ssa_gitemLen(workPhase, workItem);
|
||||
const int ii = ssa_gitemLoc(workPhase, workItem);
|
||||
ssa_update_dpde<STACKPARAMS>(ii, ct, workItem);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
template<class DeviceType>
|
||||
|
||||
@ -30,6 +30,12 @@ FixStyle(shardlow/kk/host,FixShardlowKokkos<LMPHostType>)
|
||||
|
||||
namespace LAMMPS_NS {
|
||||
|
||||
template<bool STACKPARAMS>
|
||||
struct TagFixShardlowSSAUpdateDPDE{};
|
||||
|
||||
template<bool STACKPARAMS>
|
||||
struct TagFixShardlowSSAUpdateDPDEGhost{};
|
||||
|
||||
template<class DeviceType>
|
||||
class FixShardlowKokkos : public FixShardlow {
|
||||
public:
|
||||
@ -60,6 +66,14 @@ class FixShardlowKokkos : public FixShardlow {
|
||||
F_FLOAT cutinv,halfsigma,kappa,alpha;
|
||||
};
|
||||
|
||||
template<bool STACKPARAMS>
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void operator()(TagFixShardlowSSAUpdateDPDE<STACKPARAMS>, const int&) const;
|
||||
|
||||
template<bool STACKPARAMS>
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void operator()(TagFixShardlowSSAUpdateDPDEGhost<STACKPARAMS>, const int&) const;
|
||||
|
||||
#ifdef DEBUG_PAIR_CT
|
||||
typename AT::t_int_2d d_counters;
|
||||
typename HAT::t_int_2d h_counters;
|
||||
@ -68,6 +82,7 @@ class FixShardlowKokkos : public FixShardlow {
|
||||
#endif
|
||||
|
||||
protected:
|
||||
int workPhase;
|
||||
double boltz_inv,ftm2v,dt;
|
||||
|
||||
// class PairDPDfdt *pairDPD;
|
||||
@ -127,7 +142,7 @@ class FixShardlowKokkos : public FixShardlow {
|
||||
// void ssa_update_dpd(int, int); // Constant Temperature
|
||||
template<bool STACKPARAMS>
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void ssa_update_dpde(int, int, int); // Constant Energy
|
||||
void ssa_update_dpde(int, int, int) const; // Constant Energy
|
||||
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user