diff --git a/src/KOKKOS/sna_kokkos.h b/src/KOKKOS/sna_kokkos.h index 4247a79504..a438ccd25e 100644 --- a/src/KOKKOS/sna_kokkos.h +++ b/src/KOKKOS/sna_kokkos.h @@ -241,7 +241,7 @@ class SNAKokkos { real_type evaluate_bi(const int&, const int&, const int&, const int&, const int&, const int&, const int&) const; // plugged into compute_yi, compute_yi_with_zlist - KOKKOS_FORCEINLINE_FUNCTION + template KOKKOS_FORCEINLINE_FUNCTION real_type evaluate_beta_scaled(const int&, const int&, const int&, const int&, const int&, const int&, const int&) const; // plugged into compute_fused_deidrj_small, compute_fused_deidrj_large KOKKOS_FORCEINLINE_FUNCTION diff --git a/src/KOKKOS/sna_kokkos_impl.h b/src/KOKKOS/sna_kokkos_impl.h index 4c11b1213e..9a97f229b5 100644 --- a/src/KOKKOS/sna_kokkos_impl.h +++ b/src/KOKKOS/sna_kokkos_impl.h @@ -1092,7 +1092,7 @@ void SNAKokkos::compute_yi(const int& iato // pick out right beta value for (int elem3 = 0; elem3 < nelements; elem3++) { - const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3); + const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3); if constexpr (need_atomics) { Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re); @@ -1106,7 +1106,7 @@ void SNAKokkos::compute_yi(const int& iato } // end loop over elem1 } else { const complex ztmp = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, 0, 0, cgblock); - const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0); + const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0); if constexpr (need_atomics) { Kokkos::atomic_add(&(ylist_re(iatom, 0, jju_half)), betaj * ztmp.re); @@ -1142,7 +1142,7 @@ void SNAKokkos::compute_yi_with_zlist(cons // pick out right beta value for (int elem3 = 0; elem3 < nelements; elem3++) { - const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3); + const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3); if constexpr (need_atomics) { Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re); @@ -1157,7 +1157,7 @@ void SNAKokkos::compute_yi_with_zlist(cons } // end loop over elem1 } else { const complex ztmp = zlist(iatom, 0, jjz); - const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0); + const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0); if constexpr (need_atomics) { Kokkos::atomic_add(&(ylist_re(iatom, 0, jju_half)), betaj * ztmp.re); @@ -1175,30 +1175,47 @@ void SNAKokkos::compute_yi_with_zlist(cons ------------------------------------------------------------------------- */ template -KOKKOS_FORCEINLINE_FUNCTION +template KOKKOS_FORCEINLINE_FUNCTION typename SNAKokkos::real_type SNAKokkos::evaluate_beta_scaled(const int& j1, const int& j2, const int& j, const int& iatom, const int& elem1, const int& elem2, const int& elem3) const { - real_type betaj = 0; + int itriple_jjb = 0; + real_type factor = 0; - if (j >= j1) { - const int jjb = idxb_block(j1, j2, j); - const int itriple = ((elem1 * nelements + elem2) * nelements + elem3) * idxb_max + jjb; - if (j1 == j) { - if (j2 == j) betaj = static_cast(3) * d_beta(iatom, itriple); - else betaj = static_cast(2) * d_beta(iatom, itriple); - } else betaj = d_beta(iatom, itriple); - } else if (j >= j2) { - const int jjb = idxb_block(j, j2, j1); - const int itriple = ((elem3 * nelements + elem2) * nelements + elem1) * idxb_max + jjb; - if (j2 == j) betaj = static_cast(2) * d_beta(iatom, itriple); - else betaj = d_beta(iatom, itriple); + if constexpr (chemsnap) { + if (j >= j1) { + itriple_jjb = ((elem1 * nelements + elem2) * nelements + elem3) * idxb_max + idxb_block(j1, j2, j); + if (j1 == j) { + if (j2 == j) factor = 3; + else factor = 2; + } else factor = 1; + } else if (j >= j2) { + itriple_jjb = ((elem3 * nelements + elem2) * nelements + elem1) * idxb_max + idxb_block(j, j2, j1); + if (j2 == j) factor = 2; + else factor = 1; + } else { + itriple_jjb = ((elem2 * nelements + elem3) * nelements + elem1) * idxb_max + idxb_block(j2, j, j1); + factor = 1; + } } else { - const int jjb = idxb_block(j2, j, j1); - const int itriple = ((elem2 * nelements + elem3) * nelements + elem1) * idxb_max + jjb; - betaj = d_beta(iatom, itriple); + if (j >= j1) { + itriple_jjb = idxb_block(j1, j2, j); + if (j1 == j) { + if (j2 == j) factor = 3; + else factor = 2; + } else factor = 1; + } else if (j >= j2) { + itriple_jjb = idxb_block(j, j2, j1); + if (j2 == j) factor = 2; + else factor = 1; + } else { + itriple_jjb = idxb_block(j2, j, j1); + factor = 1; + } } + real_type betaj = factor * d_beta(iatom, itriple_jjb); + if (!bnorm_flag && j1 > j) { const real_type scale = static_cast(j1 + 1) / static_cast(j + 1); betaj *= scale;