Switch to vector parallelism for half list in QEq

This commit is contained in:
Stan Gerald Moore
2022-03-03 08:09:25 -07:00
parent b6b7846c50
commit ee2b9f28cb
2 changed files with 49 additions and 45 deletions

View File

@ -858,20 +858,6 @@ void FixQEqReaxFFKokkos<DeviceType>::sparse_matvec_kokkos(typename AT::t_ffloat2
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagQEqSparseMatvec1>(0,nn),*this); Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagQEqSparseMatvec1>(0,nn),*this);
if (neighflag != FULL) {
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagQEqZeroQGhosts>(nn,NN),*this);
if (need_dup)
dup_o.reset_except(d_o);
if (neighflag == HALF)
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagQEqSparseMatvec2_Half<HALF> >(0,nn),*this);
else if (neighflag == HALFTHREAD)
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagQEqSparseMatvec2_Half<HALFTHREAD> >(0,nn),*this);
if (need_dup)
Kokkos::Experimental::contribute(d_o, dup_o);
} else { // FULL
int teamsize; int teamsize;
int vectorsize; int vectorsize;
int leaguesize; int leaguesize;
@ -885,8 +871,21 @@ void FixQEqReaxFFKokkos<DeviceType>::sparse_matvec_kokkos(typename AT::t_ffloat2
leaguesize = (nn + teamsize - 1) / (teamsize); leaguesize = (nn + teamsize - 1) / (teamsize);
} }
if (neighflag != FULL) {
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType,TagQEqZeroQGhosts>(nn,NN),*this);
if (need_dup)
dup_o.reset_except(d_o);
if (neighflag == HALF)
Kokkos::parallel_for(Kokkos::TeamPolicy<DeviceType, TagQEqSparseMatvec2_Half<HALF>>(leaguesize, teamsize, vectorsize), *this);
else if (neighflag == HALFTHREAD)
Kokkos::parallel_for(Kokkos::TeamPolicy<DeviceType, TagQEqSparseMatvec2_Half<HALFTHREAD>>(leaguesize, teamsize, vectorsize), *this);
if (need_dup)
Kokkos::Experimental::contribute(d_o, dup_o);
} else // FULL
Kokkos::parallel_for(Kokkos::TeamPolicy <DeviceType, TagQEqSparseMatvec2_Full>(leaguesize, teamsize, vectorsize), *this); Kokkos::parallel_for(Kokkos::TeamPolicy <DeviceType, TagQEqSparseMatvec2_Full>(leaguesize, teamsize, vectorsize), *this);
}
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
@ -925,19 +924,21 @@ void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqZeroQGhosts, const int &i)
template<class DeviceType> template<class DeviceType>
template<int NEIGHFLAG> template<int NEIGHFLAG>
KOKKOS_INLINE_FUNCTION KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqSparseMatvec2_Half<NEIGHFLAG>, const int &ii) const void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqSparseMatvec2_Half<NEIGHFLAG>, const typename Kokkos::TeamPolicy<DeviceType, TagQEqSparseMatvec2_Half<NEIGHFLAG>>::member_type &team) const
{ {
int k = team.league_rank() * team.team_size() + team.team_rank();
if (k < nn) {
// The q array is duplicated for OpenMP, atomic for CUDA, and neither for Serial // The q array is duplicated for OpenMP, atomic for CUDA, and neither for Serial
auto v_o = ScatterViewHelper<NeedDup_v<NEIGHFLAG,DeviceType>,decltype(dup_o),decltype(ndup_o)>::get(dup_o,ndup_o); auto v_o = ScatterViewHelper<NeedDup_v<NEIGHFLAG,DeviceType>,decltype(dup_o),decltype(ndup_o)>::get(dup_o,ndup_o);
auto a_o = v_o.template access<AtomicDup_v<NEIGHFLAG,DeviceType>>(); auto a_o = v_o.template access<AtomicDup_v<NEIGHFLAG,DeviceType>>();
const int i = d_ilist[ii]; const int i = d_ilist[k];
if (mask[i] & groupbit) { if (mask[i] & groupbit) {
F_FLOAT2 tmp; F_FLOAT2 tmp;
const auto d_xx_i0 = d_xx(i,0); const double d_xx_i0 = d_xx(i,0);
const auto d_xx_i1 = d_xx(i,1); const double d_xx_i1 = d_xx(i,1);
for (int jj = d_firstnbr[i]; jj < d_firstnbr[i] + d_numnbrs[i]; jj++) { Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, d_firstnbr[i], d_firstnbr[i] + d_numnbrs[i]), [&] (const int &jj, F_FLOAT2& tmp) {
const int j = d_jlist(jj); const int j = d_jlist(jj);
const auto d_val_jj = d_val(jj); const auto d_val_jj = d_val(jj);
if (!(converged & 1)) { if (!(converged & 1)) {
@ -948,11 +949,14 @@ void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqSparseMatvec2_Half<NEIGHFL
tmp.v[1] += d_val_jj * d_xx(j,1); tmp.v[1] += d_val_jj * d_xx(j,1);
a_o(j,1) += d_val_jj * d_xx_i1; a_o(j,1) += d_val_jj * d_xx_i1;
} }
} }, tmp);
Kokkos::single(Kokkos::PerThread(team), [&] () {
if (!(converged & 1)) if (!(converged & 1))
a_o(i,0) += tmp.v[0]; a_o(i,0) += tmp.v[0];
if (!(converged & 2)) if (!(converged & 2))
a_o(i,1) += tmp.v[1]; a_o(i,1) += tmp.v[1];
});
}
} }
} }

View File

@ -92,7 +92,7 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
template<int NEIGHFLAG> template<int NEIGHFLAG>
KOKKOS_INLINE_FUNCTION KOKKOS_INLINE_FUNCTION
void operator()(TagQEqSparseMatvec2_Half<NEIGHFLAG>, const int&) const; void operator()(TagQEqSparseMatvec2_Half<NEIGHFLAG>, const typename Kokkos::TeamPolicy<DeviceType, TagQEqSparseMatvec2_Half<NEIGHFLAG>>::member_type &team) const;
typedef typename Kokkos::TeamPolicy<DeviceType, TagQEqSparseMatvec2_Full>::member_type membertype_vec; typedef typename Kokkos::TeamPolicy<DeviceType, TagQEqSparseMatvec2_Full>::member_type membertype_vec;
KOKKOS_INLINE_FUNCTION KOKKOS_INLINE_FUNCTION