Added templating over chemsnap for ComputeYi and ComputeYiWithZlist

This commit is contained in:
Evan Weinberg
2024-11-21 14:17:40 -08:00
parent 28e64fca94
commit 261abaa683
4 changed files with 118 additions and 82 deletions

View File

@ -48,8 +48,8 @@ template <bool chemsnap> struct TagPairSNAPComputeZi{};
template <bool chemsnap> struct TagPairSNAPComputeBi{};
struct TagPairSNAPComputeBetaLinear{};
struct TagPairSNAPComputeBetaQuadratic{};
struct TagPairSNAPComputeYi{};
struct TagPairSNAPComputeYiWithZlist{};
template <bool chemsnap> struct TagPairSNAPComputeYi{};
template <bool chemsnap> struct TagPairSNAPComputeYiWithZlist{};
template<int NEIGHFLAG, int EVFLAG>
struct TagPairSNAPComputeForce{};
@ -258,23 +258,23 @@ class PairSNAPKokkos : public PairSNAP {
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBetaQuadratic, const int& iatom) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYi, const int& iatom_mod, const int& idxz, const int& iatom_div) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYi<chemsnap>, const int& iatom_mod, const int& idxz, const int& iatom_div) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYi, const int& iatom, const int& idxz) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYi<chemsnap>, const int& iatom, const int& idxz) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYi, const int& iatom) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYi<chemsnap>, const int& iatom) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYiWithZlist, const int& iatom_mod, const int& idxz, const int& iatom_div) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYiWithZlist<chemsnap>, const int& iatom_mod, const int& idxz, const int& iatom_div) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYiWithZlist, const int& iatom, const int& idxz) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYiWithZlist<chemsnap>, const int& iatom, const int& idxz) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYiWithZlist, const int& iatom) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeYiWithZlist<chemsnap>, const int& iatom) const;
template<int dir>
KOKKOS_INLINE_FUNCTION

View File

@ -332,11 +332,21 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::compute(int eflag_in,
//Note zeroing `ylist` is fused into `TransformUi`.
if (quadraticflag || eflag) {
auto policy_compute_yi = snap_get_policy<DeviceType, tile_size_compute_yi, TagPairSNAPComputeYiWithZlist>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeYiWithZlist", policy_compute_yi, *this);
if (nelements > 1) {
auto policy_compute_yi = snap_get_policy<DeviceType, tile_size_compute_yi, TagPairSNAPComputeYiWithZlist<true>>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeYiWithZlistChemsnap", policy_compute_yi, *this);
} else {
auto policy_compute_yi = snap_get_policy<DeviceType, tile_size_compute_yi, TagPairSNAPComputeYiWithZlist<false>>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeYiWithZlist", policy_compute_yi, *this);
}
} else {
auto policy_compute_yi = snap_get_policy<DeviceType, tile_size_compute_yi, TagPairSNAPComputeYi, min_blocks_compute_yi>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeYi", policy_compute_yi, *this);
if (nelements > 1) {
auto policy_compute_yi = snap_get_policy<DeviceType, tile_size_compute_yi, TagPairSNAPComputeYi<true>, min_blocks_compute_yi>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeYiChemsnap", policy_compute_yi, *this);
} else {
auto policy_compute_yi = snap_get_policy<DeviceType, tile_size_compute_yi, TagPairSNAPComputeYi<false>, min_blocks_compute_yi>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeYi", policy_compute_yi, *this);
}
}
}
@ -1041,27 +1051,27 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSN
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYi, const int& iatom_mod, const int& jjz, const int& iatom_div) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYi<chemsnap>, const int& iatom_mod, const int& jjz, const int& iatom_div) const {
const int iatom = iatom_mod + iatom_div * vector_length;
if (iatom >= chunk_size) return;
if (jjz >= snaKK.idxz_max) return;
snaKK.template compute_yi<true>(iatom, jjz);
snaKK.template compute_yi<chemsnap, true>(iatom, jjz);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYi, const int& iatom, const int& jjz) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYi<chemsnap>, const int& iatom, const int& jjz) const {
if (iatom >= chunk_size) return;
snaKK.template compute_yi<true>(iatom, jjz);
snaKK.template compute_yi<chemsnap, true>(iatom, jjz);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYi, const int& iatom) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYi<chemsnap>, const int& iatom) const {
if (iatom >= chunk_size) return;
for (int jjz = 0; jjz < snaKK.idxz_max; jjz++)
snaKK.template compute_yi<false>(iatom, jjz);
snaKK.template compute_yi<chemsnap, false>(iatom, jjz);
}
/* ----------------------------------------------------------------------
@ -1070,27 +1080,27 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSN
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYiWithZlist, const int& iatom_mod, const int& jjz, const int& iatom_div) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYiWithZlist<chemsnap>, const int& iatom_mod, const int& jjz, const int& iatom_div) const {
const int iatom = iatom_mod + iatom_div * vector_length;
if (iatom >= chunk_size) return;
if (jjz >= snaKK.idxz_max) return;
snaKK.template compute_yi_with_zlist<true>(iatom, jjz);
snaKK.template compute_yi_with_zlist<chemsnap, true>(iatom, jjz);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYiWithZlist, const int& iatom, const int& jjz) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYiWithZlist<chemsnap>, const int& iatom, const int& jjz) const {
if (iatom >= chunk_size) return;
snaKK.template compute_yi_with_zlist<true>(iatom, jjz);
snaKK.template compute_yi_with_zlist<chemsnap, true>(iatom, jjz);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYiWithZlist, const int& iatom) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeYiWithZlist<chemsnap>, const int& iatom) const {
if (iatom >= chunk_size) return;
for (int jjz = 0; jjz < snaKK.idxz_max; jjz++)
snaKK.template compute_yi_with_zlist<false>(iatom, jjz);
snaKK.template compute_yi_with_zlist<chemsnap, false>(iatom, jjz);
}
/* ----------------------------------------------------------------------

View File

@ -206,9 +206,9 @@ class SNAKokkos {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void compute_zi(const int&, const int&) const; // ForceSNAP
template <bool need_atomics> KOKKOS_INLINE_FUNCTION
template <bool chemsnap, bool need_atomics> KOKKOS_INLINE_FUNCTION
void compute_yi(const int&, const int&) const; // ForceSNAP
template <bool need_atomics> KOKKOS_INLINE_FUNCTION
template <bool chemsnap, bool need_atomics> KOKKOS_INLINE_FUNCTION
void compute_yi_with_zlist(const int&, const int&) const; // ForceSNAP
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void compute_bi(const int&, const int&) const; // ForceSNAP

View File

@ -1067,7 +1067,7 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_beta_quadratic(con
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
template <bool need_atomics>
template <bool chemsnap, bool need_atomics>
KOKKOS_INLINE_FUNCTION
void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi(const int& iatom, const int& jjz) const
{
@ -1078,31 +1078,44 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi(const int& iato
//int mb = (2 * (mb1min+mb2max) - j1 - j2 + j) / 2;
//int ma = (2 * (ma1min+ma2max) - j1 - j2 + j) / 2;
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
if constexpr (chemsnap) {
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
const complex ztmp = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, elem1, elem2, cgblock);
const complex ztmp = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, elem1, elem2, cgblock);
// apply to z(j1,j2,j,ma,mb) to unique element of y(j)
// find right y_list[jju] and beta(iatom,jjb) entries
// multiply and divide by j+1 factors
// account for multiplicity of 1, 2, or 3
// apply to z(j1,j2,j,ma,mb) to unique element of y(j)
// find right y_list[jju] and beta(iatom,jjb) entries
// multiply and divide by j+1 factors
// account for multiplicity of 1, 2, or 3
// pick out right beta value
for (int elem3 = 0; elem3 < nelements; elem3++) {
// pick out right beta value
for (int elem3 = 0; elem3 < nelements; elem3++) {
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, elem3, jju_half)), betaj * ztmp.im);
} else {
ylist_re(iatom, elem3, jju_half) += betaj * ztmp.re;
ylist_im(iatom, elem3, jju_half) += betaj * ztmp.im;
}
} // end loop over elem3
} // end loop over elem2
} // end loop over elem1
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, elem3, jju_half)), betaj * ztmp.im);
} else {
ylist_re(iatom, elem3, jju_half) += betaj * ztmp.re;
ylist_im(iatom, elem3, jju_half) += betaj * ztmp.im;
}
} // end loop over elem3
} // end loop over elem2
} // end loop over elem1
} else {
const complex ztmp = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, 0, 0, cgblock);
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, 0, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, 0, jju_half)), betaj * ztmp.im);
} else {
ylist_re(iatom, 0, jju_half) += betaj * ztmp.re;
ylist_im(iatom, 0, jju_half) += betaj * ztmp.im;
}
}
}
/* ----------------------------------------------------------------------
@ -1110,37 +1123,50 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi(const int& iato
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
template <bool need_atomics>
template <bool chemsnap, bool need_atomics>
KOKKOS_INLINE_FUNCTION
void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_with_zlist(const int& iatom, const int& jjz) const
{
int j1, j2, j, jju_half;
idxz(jjz).get_yi_with_zlist(j1, j2, j, jju_half);
int idouble = 0;
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
const complex ztmp = zlist(iatom, idouble, jjz);
// apply to z(j1,j2,j,ma,mb) to unique element of y(j)
// find right y_list[jju] and beta(iatom,jjb) entries
// multiply and divide by j+1 factors
// account for multiplicity of 1, 2, or 3
// pick out right beta value
for (int elem3 = 0; elem3 < nelements; elem3++) {
if constexpr (chemsnap) {
int idouble = 0;
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
const complex ztmp = zlist(iatom, idouble, jjz);
// apply to z(j1,j2,j,ma,mb) to unique element of y(j)
// find right y_list[jju] and beta(iatom,jjb) entries
// multiply and divide by j+1 factors
// account for multiplicity of 1, 2, or 3
// pick out right beta value
for (int elem3 = 0; elem3 < nelements; elem3++) {
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, elem3, jju_half)), betaj * ztmp.im);
} else {
ylist_re(iatom, elem3, jju_half) += betaj * ztmp.re;
ylist_im(iatom, elem3, jju_half) += betaj * ztmp.im;
}
} // end loop over elem3
idouble++;
} // end loop over elem2
} // end loop over elem1
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, elem3, jju_half)), betaj * ztmp.im);
} else {
ylist_re(iatom, elem3, jju_half) += betaj * ztmp.re;
ylist_im(iatom, elem3, jju_half) += betaj * ztmp.im;
}
} // end loop over elem3
idouble++;
} // end loop over elem2
} // end loop over elem1
} else {
const complex ztmp = zlist(iatom, 0, jjz);
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, 0, 0, 0);
if constexpr (need_atomics) {
Kokkos::atomic_add(&(ylist_re(iatom, 0, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, 0, jju_half)), betaj * ztmp.im);
} else {
ylist_re(iatom, 0, jju_half) += betaj * ztmp.re;
ylist_im(iatom, 0, jju_half) += betaj * ztmp.im;
}
}
}
/* ----------------------------------------------------------------------