Tweaks to team_size/vector_length

This commit is contained in:
Stan Gerald Moore
2021-12-20 09:58:37 -07:00
parent 94aad92b44
commit 79844a3f34
2 changed files with 25 additions and 14 deletions

View File

@ -238,16 +238,8 @@ void FixQEqReaxFFKokkos<DeviceType>::pre_force(int /*vflag*/)
} else { // GPU, use teams
Kokkos::deep_copy(d_mfill_offset,0);
int atoms_per_team = 32;
int vector_length = 1;
#ifdef KOKKOS_ENABLE_CUDA
atoms_per_team = 4;
vector_length = 32;
#endif
#ifdef KOKKOS_ENABLE_HIP
atoms_per_team = 64;
vector_length = 8;
#endif
int atoms_per_team = FixQEqReaxFFKokkos<DeviceType>::compute_h_teamsize;
int vector_length = FixQEqReaxFFKokkos<DeviceType>::compute_h_vectorsize;
int num_teams = inum / atoms_per_team + (inum % atoms_per_team ? 1 : 0);
@ -798,7 +790,7 @@ int FixQEqReaxFFKokkos<DeviceType>::cg_solve_fused()
}
else {
#ifdef HIP_OPT_SPMV
teamsize = 16;
teamsize = FixQEqReaxFFKokkos<DeviceType>::spmv_teamsize;
vectorsize = FixQEqReaxFFKokkos<DeviceType>::vectorsize;
leaguesize = (inum + teamsize - 1) / (teamsize);
#else
@ -1100,7 +1092,7 @@ int FixQEqReaxFFKokkos<DeviceType>::cg_solve2()
}
else {
#ifdef HIP_OPT_SPMV
teamsize = 16;
teamsize = FixQEqReaxFFKokkos<DeviceType>::spmv_teamsize;
vectorsize = FixQEqReaxFFKokkos<DeviceType>::vectorsize;
leaguesize = (inum + teamsize - 1) / (teamsize);
#else

View File

@ -239,15 +239,34 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
// There should be a better way to do this for other backends
#if defined(KOKKOS_ENABLE_CUDA)
static constexpr int spmv_teamsize = 8;
// warp length
static constexpr int vectorsize = 32;
// team size for sparse mat-vec operations
static constexpr int spmv_teamsize = 8;
// custom values for FixQEqReaxFFKokkosComputeHFunctor
static constexpr int compute_h_vectorsize = vectorsize;
static constexpr int compute_h_teamsize = 4;
#elif defined(KOKKOS_ENABLE_HIP)
static constexpr int spmv_teamsize = 16;
// wavefront length
static constexpr int vectorsize = 64;
// team size for sparse mat-vec operations
static constexpr int spmv_teamsize = 16;
// custom values for FixQEqReaxFFKokkosComputeHFunctor
static constexpr int compute_h_vectorsize = 8; // not a typo, intentionally sub-wavefront
static constexpr int compute_h_teamsize = 64;
#else
// dummy values, to be updated
static constexpr int spmv_teamsize = 4;
static constexpr int vectorsize = 32;
static constexpr int compute_h_vectorsize = 1;
static constexpr int compute_h_teamsize = 32;
#endif
private: