Added templating over chemsnap for ComputeZi and ComputeBi

This commit is contained in:
Evan Weinberg
2024-11-21 14:06:10 -08:00
parent 67470f236e
commit 28e64fca94
4 changed files with 76 additions and 61 deletions

View File

@ -44,8 +44,8 @@ namespace LAMMPS_NS {
// Routines for both the CPU and GPU backend
struct TagPairSNAPPreUi{};
struct TagPairSNAPTransformUi{}; // re-order ulisttot from SoA to AoSoA, zero ylist
struct TagPairSNAPComputeZi{};
struct TagPairSNAPComputeBi{};
template <bool chemsnap> struct TagPairSNAPComputeZi{};
template <bool chemsnap> struct TagPairSNAPComputeBi{};
struct TagPairSNAPComputeBetaLinear{};
struct TagPairSNAPComputeBetaQuadratic{};
struct TagPairSNAPComputeYi{};
@ -222,23 +222,23 @@ class PairSNAPKokkos : public PairSNAP {
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPTransformUi, const int& iatom) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeZi, const int& iatom_mod, const int& idxz, const int& iatom_div) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeZi<chemsnap>, const int& iatom_mod, const int& idxz, const int& iatom_div) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeZi, const int& iatom, const int& idxz) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeZi<chemsnap>, const int& iatom, const int& idxz) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeZi, const int& iatom) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeZi<chemsnap>, const int& iatom) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBi, const int& iatom_mod, const int& idxb, const int& iatom_div) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBi<chemsnap>, const int& iatom_mod, const int& idxb, const int& iatom_div) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBi, const int& iatom, const int& idxb) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBi<chemsnap>, const int& iatom, const int& idxb) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBi, const int& iatom) const;
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBi<chemsnap>, const int& iatom) const;
KOKKOS_INLINE_FUNCTION
void operator() (TagPairSNAPComputeBetaLinear, const int& iatom_mod, const int& idxb, const int& iatom_div) const;

View File

@ -302,13 +302,21 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::compute(int eflag_in,
if (quadraticflag || eflag) {
// team_size_[compute_zi, compute_bi, transform_bi] are defined in `pair_snap_kokkos.h`
//ComputeZi
auto policy_compute_zi = snap_get_policy<DeviceType, tile_size_compute_zi, TagPairSNAPComputeZi, min_blocks_compute_zi>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeZi", policy_compute_zi, *this);
//ComputeZi and Bi
if (nelements > 1) {
auto policy_compute_zi = snap_get_policy<DeviceType, tile_size_compute_zi, TagPairSNAPComputeZi<true>, min_blocks_compute_zi>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeZiChemsnap", policy_compute_zi, *this);
auto policy_compute_bi = snap_get_policy<DeviceType, tile_size_compute_bi, TagPairSNAPComputeBi<true>>(chunk_size_div, snaKK.idxb_max);
Kokkos::parallel_for("ComputeBiChemsnap", policy_compute_bi, *this);
} else {
auto policy_compute_zi = snap_get_policy<DeviceType, tile_size_compute_zi, TagPairSNAPComputeZi<false>, min_blocks_compute_zi>(chunk_size_div, snaKK.idxz_max);
Kokkos::parallel_for("ComputeZi", policy_compute_zi, *this);
auto policy_compute_bi = snap_get_policy<DeviceType, tile_size_compute_bi, TagPairSNAPComputeBi<false>>(chunk_size_div, snaKK.idxb_max);
Kokkos::parallel_for("ComputeBi", policy_compute_bi, *this);
}
//ComputeBi
auto policy_compute_bi = snap_get_policy<DeviceType, tile_size_compute_bi, TagPairSNAPComputeBi>(chunk_size_div, snaKK.idxb_max);
Kokkos::parallel_for("ComputeBi", policy_compute_bi, *this);
}
{
@ -884,27 +892,27 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSN
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeZi, const int& iatom_mod, const int& jjz, const int& iatom_div) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeZi<chemsnap>, const int& iatom_mod, const int& jjz, const int& iatom_div) const {
const int iatom = iatom_mod + iatom_div * vector_length;
if (iatom >= chunk_size) return;
if (jjz >= snaKK.idxz_max) return;
snaKK.compute_zi(iatom, jjz);
snaKK.template compute_zi<chemsnap>(iatom, jjz);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeZi, const int& iatom, const int& jjz) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeZi<chemsnap>, const int& iatom, const int& jjz) const {
if (iatom >= chunk_size) return;
snaKK.compute_zi(iatom, jjz);
snaKK.template compute_zi<chemsnap>(iatom, jjz);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeZi, const int& iatom) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeZi<chemsnap>, const int& iatom) const {
if (iatom >= chunk_size) return;
for (int jjz = 0; jjz < snaKK.idxz_max; jjz++)
snaKK.compute_zi(iatom, jjz);
snaKK.template compute_zi<chemsnap>(iatom, jjz);
}
/* ----------------------------------------------------------------------
@ -913,27 +921,27 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSN
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeBi, const int& iatom_mod, const int& jjb, const int& iatom_div) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeBi<chemsnap>, const int& iatom_mod, const int& jjb, const int& iatom_div) const {
const int iatom = iatom_mod + iatom_div * vector_length;
if (iatom >= chunk_size) return;
if (jjb >= snaKK.idxb_max) return;
snaKK.compute_bi(iatom, jjb);
snaKK.template compute_bi<chemsnap>(iatom, jjb);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeBi, const int& iatom, const int& jjb) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeBi<chemsnap>, const int& iatom, const int& jjb) const {
if (iatom >= chunk_size) return;
snaKK.compute_bi(iatom, jjb);
snaKK.template compute_bi<chemsnap>(iatom, jjb);
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeBi, const int& iatom) const {
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSNAPComputeBi<chemsnap>, const int& iatom) const {
if (iatom >= chunk_size) return;
for (int jjb = 0; jjb < snaKK.idxb_max; jjb++)
snaKK.compute_bi(iatom, jjb);
snaKK.template compute_bi<chemsnap>(iatom, jjb);
}
/* ----------------------------------------------------------------------

View File

@ -204,13 +204,13 @@ class SNAKokkos {
KOKKOS_INLINE_FUNCTION
void transform_ui(const int&, const int&) const;
KOKKOS_INLINE_FUNCTION
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void compute_zi(const int&, const int&) const; // ForceSNAP
template <bool need_atomics> KOKKOS_INLINE_FUNCTION
void compute_yi(const int&, const int&) const; // ForceSNAP
template <bool need_atomics> KOKKOS_INLINE_FUNCTION
void compute_yi_with_zlist(const int&, const int&) const; // ForceSNAP
KOKKOS_INLINE_FUNCTION
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void compute_bi(const int&, const int&) const; // ForceSNAP
KOKKOS_INLINE_FUNCTION
void compute_beta_linear(const int&, const int&, const int&) const;

View File

@ -794,7 +794,7 @@ void SNAKokkos<DeviceType, real_type, vector_length>::transform_ui(const int& ia
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void SNAKokkos<DeviceType, real_type, vector_length>::compute_zi(const int& iatom, const int& jjz) const
{
int j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, idxcg;
@ -802,14 +802,17 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_zi(const int& iato
const real_type *cgblock = cglist.data() + idxcg;
int idouble = 0;
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
zlist(iatom, idouble, jjz) = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, elem1, elem2, cgblock);
idouble++;
} // end loop over elem2
} // end loop over elem1
if constexpr (chemsnap) {
int idouble = 0;
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
zlist(iatom, idouble, jjz) = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, elem1, elem2, cgblock);
idouble++;
} // end loop over elem2
} // end loop over elem1
} else {
zlist(iatom, 0, jjz) = evaluate_zi(j1, j2, j, ma1min, ma2max, mb1min, mb2max, na, nb, iatom, 0, 0, cgblock);
}
}
/* ----------------------------------------------------------------------
@ -873,7 +876,7 @@ typename SNAKokkos<DeviceType, real_type, vector_length>::complex SNAKokkos<Devi
------------------------------------------------------------------------- */
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
template <bool chemsnap> KOKKOS_INLINE_FUNCTION
void SNAKokkos<DeviceType, real_type, vector_length>::compute_bi(const int& iatom, const int& jjb) const
{
// for j1 = 0,...,twojmax
@ -892,17 +895,21 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_bi(const int& iato
const int jjz = idxz_block(j1,j2,j);
const int jju = idxu_block[j];
int itriple = 0;
int idouble = 0;
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
for (int elem3 = 0; elem3 < nelements; elem3++) {
blist(iatom, itriple, jjb) = evaluate_bi(j, jjz, jju, iatom, elem1, elem2, elem3);
itriple++;
} // end loop over elem3
idouble++;
} // end loop over elem2
} // end loop over elem1
if constexpr (chemsnap) {
int itriple = 0;
int idouble = 0;
for (int elem1 = 0; elem1 < nelements; elem1++) {
for (int elem2 = 0; elem2 < nelements; elem2++) {
for (int elem3 = 0; elem3 < nelements; elem3++) {
blist(iatom, itriple, jjb) = evaluate_bi(j, jjz, jju, iatom, elem1, elem2, elem3);
itriple++;
} // end loop over elem3
idouble++;
} // end loop over elem2
} // end loop over elem1
} else {
blist(iatom, 0, jjb) = evaluate_bi(j, jjz, jju, iatom, 0, 0, 0);
}
}
/* ----------------------------------------------------------------------