Added templating over chemsnap for evaluate_beta_scaled

This commit is contained in:
Evan Weinberg
2024-11-21 14:33:30 -08:00
parent 261abaa683
commit ce6e0dbe68
2 changed files with 39 additions and 22 deletions

View File

@ -241,7 +241,7 @@ class SNAKokkos {
real_type evaluate_bi(const int&, const int&, const int&, const int&,
const int&, const int&, const int&) const;
// plugged into compute_yi, compute_yi_with_zlist
KOKKOS_FORCEINLINE_FUNCTION
template <bool chemsnap> KOKKOS_FORCEINLINE_FUNCTION
real_type evaluate_beta_scaled(const int&, const int&, const int&, const int&, const int&, const int&, const int&) const;
// plugged into compute_fused_deidrj_small, compute_fused_deidrj_large
KOKKOS_FORCEINLINE_FUNCTION

View File

@ -1092,7 +1092,7 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi(const int& iato
// pick out right beta value
for (int elem3 = 0; elem3 < nelements; elem3++) {
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
const real_type betaj = evaluate_beta_scaled<true>(j1, j2, j, iatom, elem1, elem2, elem3);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
@ -1106,7 +1106,7 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi(const int& iato
} // end loop over elem1
} else {
const complex ztmp = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, 0, 0, cgblock);
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0);
const real_type betaj = evaluate_beta_scaled<false>(j1, j2, j, iatom, 0, 0, 0);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, 0, jju_half)), betaj * ztmp.re);
@ -1142,7 +1142,7 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_with_zlist(cons
// pick out right beta value
for (int elem3 = 0; elem3 < nelements; elem3++) {
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
const real_type betaj = evaluate_beta_scaled<true>(j1, j2, j, iatom, elem1, elem2, elem3);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
@ -1157,7 +1157,7 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_with_zlist(cons
} // end loop over elem1
} else {
const complex ztmp = zlist(iatom, 0, jjz);
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0);
const real_type betaj = evaluate_beta_scaled<false>(j1, j2, j, iatom, 0, 0, 0);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, 0, jju_half)), betaj * ztmp.re);
@ -1175,30 +1175,47 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_with_zlist(cons
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_FORCEINLINE_FUNCTION
template <bool chemsnap> KOKKOS_FORCEINLINE_FUNCTION
typename SNAKokkos<DeviceType, real_type, vector_length>::real_type SNAKokkos<DeviceType, real_type, vector_length>::evaluate_beta_scaled(const int& j1, const int& j2, const int& j,
const int& iatom, const int& elem1, const int& elem2, const int& elem3) const {
real_type betaj = 0;
int itriple_jjb = 0;
real_type factor = 0;
if (j >= j1) {
const int jjb = idxb_block(j1, j2, j);
const int itriple = ((elem1 * nelements + elem2) * nelements + elem3) * idxb_max + jjb;
if (j1 == j) {
if (j2 == j) betaj = static_cast<real_type>(3) * d_beta(iatom, itriple);
else betaj = static_cast<real_type>(2) * d_beta(iatom, itriple);
} else betaj = d_beta(iatom, itriple);
} else if (j >= j2) {
const int jjb = idxb_block(j, j2, j1);
const int itriple = ((elem3 * nelements + elem2) * nelements + elem1) * idxb_max + jjb;
if (j2 == j) betaj = static_cast<real_type>(2) * d_beta(iatom, itriple);
else betaj = d_beta(iatom, itriple);
if constexpr (chemsnap) {
if (j >= j1) {
itriple_jjb = ((elem1 * nelements + elem2) * nelements + elem3) * idxb_max + idxb_block(j1, j2, j);
if (j1 == j) {
if (j2 == j) factor = 3;
else factor = 2;
} else factor = 1;
} else if (j >= j2) {
itriple_jjb = ((elem3 * nelements + elem2) * nelements + elem1) * idxb_max + idxb_block(j, j2, j1);
if (j2 == j) factor = 2;
else factor = 1;
} else {
itriple_jjb = ((elem2 * nelements + elem3) * nelements + elem1) * idxb_max + idxb_block(j2, j, j1);
factor = 1;
}
} else {
const int jjb = idxb_block(j2, j, j1);
const int itriple = ((elem2 * nelements + elem3) * nelements + elem1) * idxb_max + jjb;
betaj = d_beta(iatom, itriple);
if (j >= j1) {
itriple_jjb = idxb_block(j1, j2, j);
if (j1 == j) {
if (j2 == j) factor = 3;
else factor = 2;
} else factor = 1;
} else if (j >= j2) {
itriple_jjb = idxb_block(j, j2, j1);
if (j2 == j) factor = 2;
else factor = 1;
} else {
itriple_jjb = idxb_block(j2, j, j1);
factor = 1;
}
}
real_type betaj = factor * d_beta(iatom, itriple_jjb);
if (!bnorm_flag && j1 > j) {
const real_type scale = static_cast<real_type>(j1 + 1) / static_cast<real_type>(j + 1);
betaj *= scale;