diff --git a/src/KOKKOS/fix_shardlow_kokkos.cpp b/src/KOKKOS/fix_shardlow_kokkos.cpp index b3d4e86244..d2fb937a57 100644 --- a/src/KOKKOS/fix_shardlow_kokkos.cpp +++ b/src/KOKKOS/fix_shardlow_kokkos.cpp @@ -436,7 +436,7 @@ template KOKKOS_INLINE_FUNCTION void FixShardlowKokkos::ssa_update_dpde( int start_ii, int count, int id -) +) const { #ifdef DPD_USE_RAN_MARS class RanMars *pRNG = pp_random[id]; @@ -682,26 +682,18 @@ void FixShardlowKokkos::initial_integrate(int vflag) dt = update->dt; // process neighbors in the local AIR - for (int workPhase = 0; workPhase < ssa_phaseCt; ++workPhase) { + for (workPhase = 0; workPhase < ssa_phaseCt; ++workPhase) { int workItemCt = h_ssa_phaseLen[workPhase]; - if(atom->ntypes > MAX_TYPES_STACKPARAMS) { - Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) { - int ct = ssa_itemLen(workPhase, workItem); - int ii = ssa_itemLoc(workPhase, workItem); - ssa_update_dpde(ii, ct, workItem); - }); - } else { - Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) { - int ct = ssa_itemLen(workPhase, workItem); - int ii = ssa_itemLoc(workPhase, workItem); - ssa_update_dpde(ii, ct, workItem); - }); - } + + if(atom->ntypes > MAX_TYPES_STACKPARAMS) + Kokkos::parallel_for(Kokkos::RangePolicy >(0,workItemCt),*this); + else + Kokkos::parallel_for(Kokkos::RangePolicy >(0,workItemCt),*this); } //Loop over all 13 outward directions (7 stages) - for (int workPhase = 0; workPhase < ssa_gphaseCt; ++workPhase) { + for (workPhase = 0; workPhase < ssa_gphaseCt; ++workPhase) { // int airnum = workPhase + 1; int workItemCt = h_ssa_gphaseLen[workPhase]; @@ -713,27 +705,21 @@ void FixShardlowKokkos::initial_integrate(int vflag) // memset(&(atom->uCond[nlocal]), 0, sizeof(double)*nghost); // memset(&(atom->uMech[nlocal]), 0, sizeof(double)*nghost); + // must capture local variables, not class variables + auto l_uCond = uCond; + auto l_uMech = uMech; Kokkos::parallel_for(Kokkos::RangePolicy(nlocal,nlocal+nghost), LAMMPS_LAMBDA (const int i) { - uCond(i) = 0.0; - uMech(i) = 0.0; + l_uCond(i) = 0.0; + l_uMech(i) = 0.0; }); DeviceType::fence(); } // process neighbors in this AIR - if(atom->ntypes > MAX_TYPES_STACKPARAMS) { - Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) { - int ct = ssa_gitemLen(workPhase, workItem); - int ii = ssa_gitemLoc(workPhase, workItem); - ssa_update_dpde(ii, ct, workItem); - }); - } else { - Kokkos::parallel_for(workItemCt, LAMMPS_LAMBDA (const int workItem ) { - int ct = ssa_gitemLen(workPhase, workItem); - int ii = ssa_gitemLoc(workPhase, workItem); - ssa_update_dpde(ii, ct, workItem); - }); - } + if(atom->ntypes > MAX_TYPES_STACKPARAMS) + Kokkos::parallel_for(Kokkos::RangePolicy >(0,workItemCt),*this); + else + Kokkos::parallel_for(Kokkos::RangePolicy >(0,workItemCt),*this); // Communicate the ghost deltas to the atom owners comm->reverse_comm_fix(this); @@ -755,6 +741,24 @@ fprintf(stdout, "\n%6d %6d,%6d %6d: " copymode = 0; } +template +template +KOKKOS_INLINE_FUNCTION +void FixShardlowKokkos::operator()(TagFixShardlowSSAUpdateDPDE, const int &workItem) const { + const int ct = ssa_itemLen(workPhase, workItem); + const int ii = ssa_itemLoc(workPhase, workItem); + ssa_update_dpde(ii, ct, workItem); +} + +template +template +KOKKOS_INLINE_FUNCTION +void FixShardlowKokkos::operator()(TagFixShardlowSSAUpdateDPDEGhost, const int &workItem) const { + const int ct = ssa_gitemLen(workPhase, workItem); + const int ii = ssa_gitemLoc(workPhase, workItem); + ssa_update_dpde(ii, ct, workItem); +} + /* ---------------------------------------------------------------------- */ template diff --git a/src/KOKKOS/fix_shardlow_kokkos.h b/src/KOKKOS/fix_shardlow_kokkos.h index df8849d80b..91a2fdbc97 100644 --- a/src/KOKKOS/fix_shardlow_kokkos.h +++ b/src/KOKKOS/fix_shardlow_kokkos.h @@ -30,6 +30,12 @@ FixStyle(shardlow/kk/host,FixShardlowKokkos) namespace LAMMPS_NS { +template +struct TagFixShardlowSSAUpdateDPDE{}; + +template +struct TagFixShardlowSSAUpdateDPDEGhost{}; + template class FixShardlowKokkos : public FixShardlow { public: @@ -60,6 +66,14 @@ class FixShardlowKokkos : public FixShardlow { F_FLOAT cutinv,halfsigma,kappa,alpha; }; + template + KOKKOS_INLINE_FUNCTION + void operator()(TagFixShardlowSSAUpdateDPDE, const int&) const; + + template + KOKKOS_INLINE_FUNCTION + void operator()(TagFixShardlowSSAUpdateDPDEGhost, const int&) const; + #ifdef DEBUG_PAIR_CT typename AT::t_int_2d d_counters; typename HAT::t_int_2d h_counters; @@ -68,6 +82,7 @@ class FixShardlowKokkos : public FixShardlow { #endif protected: + int workPhase; double boltz_inv,ftm2v,dt; // class PairDPDfdt *pairDPD; @@ -127,7 +142,7 @@ class FixShardlowKokkos : public FixShardlow { // void ssa_update_dpd(int, int); // Constant Temperature template KOKKOS_INLINE_FUNCTION - void ssa_update_dpde(int, int, int); // Constant Energy + void ssa_update_dpde(int, int, int) const; // Constant Energy };