From 28e64fca94cb8abe3f0b2c483dccf5652df8235b Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Thu, 21 Nov 2024 14:06:10 -0800 Subject: [PATCH] Added templating over chemsnap for ComputeZi and ComputeBi --- src/KOKKOS/pair_snap_kokkos.h | 28 +++++++-------- src/KOKKOS/pair_snap_kokkos_impl.h | 56 +++++++++++++++++------------- src/KOKKOS/sna_kokkos.h | 4 +-- src/KOKKOS/sna_kokkos_impl.h | 49 +++++++++++++++----------- 4 files changed, 76 insertions(+), 61 deletions(-) diff --git a/src/KOKKOS/pair_snap_kokkos.h b/src/KOKKOS/pair_snap_kokkos.h index 13e838356b..c42d7cdb88 100644 --- a/src/KOKKOS/pair_snap_kokkos.h +++ b/src/KOKKOS/pair_snap_kokkos.h @@ -44,8 +44,8 @@ namespace LAMMPS_NS { // Routines for both the CPU and GPU backend struct TagPairSNAPPreUi{}; struct TagPairSNAPTransformUi{}; // re-order ulisttot from SoA to AoSoA, zero ylist -struct TagPairSNAPComputeZi{}; -struct TagPairSNAPComputeBi{}; +template struct TagPairSNAPComputeZi{}; +template struct TagPairSNAPComputeBi{}; struct TagPairSNAPComputeBetaLinear{}; struct TagPairSNAPComputeBetaQuadratic{}; struct TagPairSNAPComputeYi{}; @@ -222,23 +222,23 @@ class PairSNAPKokkos : public PairSNAP { KOKKOS_INLINE_FUNCTION void operator() (TagPairSNAPTransformUi, const int& iatom) const; - KOKKOS_INLINE_FUNCTION - void operator() (TagPairSNAPComputeZi, const int& iatom_mod, const int& idxz, const int& iatom_div) const; + template KOKKOS_INLINE_FUNCTION + void operator() (TagPairSNAPComputeZi, const int& iatom_mod, const int& idxz, const int& iatom_div) const; - KOKKOS_INLINE_FUNCTION - void operator() (TagPairSNAPComputeZi, const int& iatom, const int& idxz) const; + template KOKKOS_INLINE_FUNCTION + void operator() (TagPairSNAPComputeZi, const int& iatom, const int& idxz) const; - KOKKOS_INLINE_FUNCTION - void operator() (TagPairSNAPComputeZi, const int& iatom) const; + template KOKKOS_INLINE_FUNCTION + void operator() (TagPairSNAPComputeZi, const int& iatom) const; - KOKKOS_INLINE_FUNCTION - void operator() (TagPairSNAPComputeBi, const int& iatom_mod, const int& idxb, const int& iatom_div) const; + template KOKKOS_INLINE_FUNCTION + void operator() (TagPairSNAPComputeBi, const int& iatom_mod, const int& idxb, const int& iatom_div) const; - KOKKOS_INLINE_FUNCTION - void operator() (TagPairSNAPComputeBi, const int& iatom, const int& idxb) const; + template KOKKOS_INLINE_FUNCTION + void operator() (TagPairSNAPComputeBi, const int& iatom, const int& idxb) const; - KOKKOS_INLINE_FUNCTION - void operator() (TagPairSNAPComputeBi, const int& iatom) const; + template KOKKOS_INLINE_FUNCTION + void operator() (TagPairSNAPComputeBi, const int& iatom) const; KOKKOS_INLINE_FUNCTION void operator() (TagPairSNAPComputeBetaLinear, const int& iatom_mod, const int& idxb, const int& iatom_div) const; diff --git a/src/KOKKOS/pair_snap_kokkos_impl.h b/src/KOKKOS/pair_snap_kokkos_impl.h index c2e546912e..dfdee2e1c0 100644 --- a/src/KOKKOS/pair_snap_kokkos_impl.h +++ b/src/KOKKOS/pair_snap_kokkos_impl.h @@ -302,13 +302,21 @@ void PairSNAPKokkos::compute(int eflag_in, if (quadraticflag || eflag) { // team_size_[compute_zi, compute_bi, transform_bi] are defined in `pair_snap_kokkos.h` - //ComputeZi - auto policy_compute_zi = snap_get_policy(chunk_size_div, snaKK.idxz_max); - Kokkos::parallel_for("ComputeZi", policy_compute_zi, *this); + //ComputeZi and Bi + if (nelements > 1) { + auto policy_compute_zi = snap_get_policy, min_blocks_compute_zi>(chunk_size_div, snaKK.idxz_max); + Kokkos::parallel_for("ComputeZiChemsnap", policy_compute_zi, *this); + + auto policy_compute_bi = snap_get_policy>(chunk_size_div, snaKK.idxb_max); + Kokkos::parallel_for("ComputeBiChemsnap", policy_compute_bi, *this); + } else { + auto policy_compute_zi = snap_get_policy, min_blocks_compute_zi>(chunk_size_div, snaKK.idxz_max); + Kokkos::parallel_for("ComputeZi", policy_compute_zi, *this); + + auto policy_compute_bi = snap_get_policy>(chunk_size_div, snaKK.idxb_max); + Kokkos::parallel_for("ComputeBi", policy_compute_bi, *this); + } - //ComputeBi - auto policy_compute_bi = snap_get_policy(chunk_size_div, snaKK.idxb_max); - Kokkos::parallel_for("ComputeBi", policy_compute_bi, *this); } { @@ -884,27 +892,27 @@ void PairSNAPKokkos::operator() (TagPairSN ------------------------------------------------------------------------- */ template -KOKKOS_INLINE_FUNCTION -void PairSNAPKokkos::operator() (TagPairSNAPComputeZi, const int& iatom_mod, const int& jjz, const int& iatom_div) const { +template KOKKOS_INLINE_FUNCTION +void PairSNAPKokkos::operator() (TagPairSNAPComputeZi, const int& iatom_mod, const int& jjz, const int& iatom_div) const { const int iatom = iatom_mod + iatom_div * vector_length; if (iatom >= chunk_size) return; if (jjz >= snaKK.idxz_max) return; - snaKK.compute_zi(iatom, jjz); + snaKK.template compute_zi(iatom, jjz); } template -KOKKOS_INLINE_FUNCTION -void PairSNAPKokkos::operator() (TagPairSNAPComputeZi, const int& iatom, const int& jjz) const { +template KOKKOS_INLINE_FUNCTION +void PairSNAPKokkos::operator() (TagPairSNAPComputeZi, const int& iatom, const int& jjz) const { if (iatom >= chunk_size) return; - snaKK.compute_zi(iatom, jjz); + snaKK.template compute_zi(iatom, jjz); } template -KOKKOS_INLINE_FUNCTION -void PairSNAPKokkos::operator() (TagPairSNAPComputeZi, const int& iatom) const { +template KOKKOS_INLINE_FUNCTION +void PairSNAPKokkos::operator() (TagPairSNAPComputeZi, const int& iatom) const { if (iatom >= chunk_size) return; for (int jjz = 0; jjz < snaKK.idxz_max; jjz++) - snaKK.compute_zi(iatom, jjz); + snaKK.template compute_zi(iatom, jjz); } /* ---------------------------------------------------------------------- @@ -913,27 +921,27 @@ void PairSNAPKokkos::operator() (TagPairSN ------------------------------------------------------------------------- */ template -KOKKOS_INLINE_FUNCTION -void PairSNAPKokkos::operator() (TagPairSNAPComputeBi, const int& iatom_mod, const int& jjb, const int& iatom_div) const { +template KOKKOS_INLINE_FUNCTION +void PairSNAPKokkos::operator() (TagPairSNAPComputeBi, const int& iatom_mod, const int& jjb, const int& iatom_div) const { const int iatom = iatom_mod + iatom_div * vector_length; if (iatom >= chunk_size) return; if (jjb >= snaKK.idxb_max) return; - snaKK.compute_bi(iatom, jjb); + snaKK.template compute_bi(iatom, jjb); } template -KOKKOS_INLINE_FUNCTION -void PairSNAPKokkos::operator() (TagPairSNAPComputeBi, const int& iatom, const int& jjb) const { +template KOKKOS_INLINE_FUNCTION +void PairSNAPKokkos::operator() (TagPairSNAPComputeBi, const int& iatom, const int& jjb) const { if (iatom >= chunk_size) return; - snaKK.compute_bi(iatom, jjb); + snaKK.template compute_bi(iatom, jjb); } template -KOKKOS_INLINE_FUNCTION -void PairSNAPKokkos::operator() (TagPairSNAPComputeBi, const int& iatom) const { +template KOKKOS_INLINE_FUNCTION +void PairSNAPKokkos::operator() (TagPairSNAPComputeBi, const int& iatom) const { if (iatom >= chunk_size) return; for (int jjb = 0; jjb < snaKK.idxb_max; jjb++) - snaKK.compute_bi(iatom, jjb); + snaKK.template compute_bi(iatom, jjb); } /* ---------------------------------------------------------------------- diff --git a/src/KOKKOS/sna_kokkos.h b/src/KOKKOS/sna_kokkos.h index 922aa22351..ee1eb263dc 100644 --- a/src/KOKKOS/sna_kokkos.h +++ b/src/KOKKOS/sna_kokkos.h @@ -204,13 +204,13 @@ class SNAKokkos { KOKKOS_INLINE_FUNCTION void transform_ui(const int&, const int&) const; - KOKKOS_INLINE_FUNCTION + template KOKKOS_INLINE_FUNCTION void compute_zi(const int&, const int&) const; // ForceSNAP template KOKKOS_INLINE_FUNCTION void compute_yi(const int&, const int&) const; // ForceSNAP template KOKKOS_INLINE_FUNCTION void compute_yi_with_zlist(const int&, const int&) const; // ForceSNAP - KOKKOS_INLINE_FUNCTION + template KOKKOS_INLINE_FUNCTION void compute_bi(const int&, const int&) const; // ForceSNAP KOKKOS_INLINE_FUNCTION void compute_beta_linear(const int&, const int&, const int&) const; diff --git a/src/KOKKOS/sna_kokkos_impl.h b/src/KOKKOS/sna_kokkos_impl.h index f0d4881f8d..2d567759ea 100644 --- a/src/KOKKOS/sna_kokkos_impl.h +++ b/src/KOKKOS/sna_kokkos_impl.h @@ -794,7 +794,7 @@ void SNAKokkos::transform_ui(const int& ia ------------------------------------------------------------------------- */ template -KOKKOS_INLINE_FUNCTION +template KOKKOS_INLINE_FUNCTION void SNAKokkos::compute_zi(const int& iatom, const int& jjz) const { int j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, idxcg; @@ -802,14 +802,17 @@ void SNAKokkos::compute_zi(const int& iato const real_type *cgblock = cglist.data() + idxcg; - int idouble = 0; - - for (int elem1 = 0; elem1 < nelements; elem1++) { - for (int elem2 = 0; elem2 < nelements; elem2++) { - zlist(iatom, idouble, jjz) = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, elem1, elem2, cgblock); - idouble++; - } // end loop over elem2 - } // end loop over elem1 + if constexpr (chemsnap) { + int idouble = 0; + for (int elem1 = 0; elem1 < nelements; elem1++) { + for (int elem2 = 0; elem2 < nelements; elem2++) { + zlist(iatom, idouble, jjz) = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, elem1, elem2, cgblock); + idouble++; + } // end loop over elem2 + } // end loop over elem1 + } else { + zlist(iatom, 0, jjz) = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, 0, 0, cgblock); + } } /* ---------------------------------------------------------------------- @@ -873,7 +876,7 @@ typename SNAKokkos::complex SNAKokkos -KOKKOS_INLINE_FUNCTION +template KOKKOS_INLINE_FUNCTION void SNAKokkos::compute_bi(const int& iatom, const int& jjb) const { // for j1 = 0,...,twojmax @@ -892,17 +895,21 @@ void SNAKokkos::compute_bi(const int& iato const int jjz = idxz_block(j1,j2,j); const int jju = idxu_block[j]; - int itriple = 0; - int idouble = 0; - for (int elem1 = 0; elem1 < nelements; elem1++) { - for (int elem2 = 0; elem2 < nelements; elem2++) { - for (int elem3 = 0; elem3 < nelements; elem3++) { - blist(iatom, itriple, jjb) = evaluate_bi(j, jjz, jju, iatom, elem1, elem2, elem3); - itriple++; - } // end loop over elem3 - idouble++; - } // end loop over elem2 - } // end loop over elem1 + if constexpr (chemsnap) { + int itriple = 0; + int idouble = 0; + for (int elem1 = 0; elem1 < nelements; elem1++) { + for (int elem2 = 0; elem2 < nelements; elem2++) { + for (int elem3 = 0; elem3 < nelements; elem3++) { + blist(iatom, itriple, jjb) = evaluate_bi(j, jjz, jju, iatom, elem1, elem2, elem3); + itriple++; + } // end loop over elem3 + idouble++; + } // end loop over elem2 + } // end loop over elem1 + } else { + blist(iatom, 0, jjb) = evaluate_bi(j, jjz, jju, iatom, 0, 0, 0); + } } /* ----------------------------------------------------------------------