diff --git a/src/KOKKOS/pair_snap_kokkos.h b/src/KOKKOS/pair_snap_kokkos.h index ece1384a49..611c79b56d 100644 --- a/src/KOKKOS/pair_snap_kokkos.h +++ b/src/KOKKOS/pair_snap_kokkos.h @@ -30,9 +30,15 @@ PairStyle(snap/kk/host,PairSNAPKokkosDevice); #include "pair_snap.h" #include "kokkos_type.h" #include "neigh_list_kokkos.h" -#include "sna_kokkos.h" #include "pair_kokkos.h" +namespace LAMMPS_NS { +// pre-declare so sna_kokkos.h can refer to it +template class PairSNAPKokkos; +}; + +#include "sna_kokkos.h" + namespace LAMMPS_NS { // Routines for both the CPU and GPU backend @@ -262,7 +268,7 @@ class PairSNAPKokkos : public PairSNAP { Kokkos::View d_radelem; // element radii Kokkos::View d_wjelem; // elements weights - Kokkos::View d_coeffelem; // element bispectrum coefficients + typename SNAKokkos::t_sna_2d_lr d_coeffelem; // element bispectrum coefficients Kokkos::View d_sinnerelem; // element inner cutoff midpoint Kokkos::View d_dinnerelem; // element inner cutoff half-width Kokkos::View d_map; // mapping from atom types to elements @@ -302,6 +308,9 @@ class PairSNAPKokkos : public PairSNAP { template int scratch_size_helper(int values_per_team); + // Make SNAKokkos a friend + friend class SNAKokkos; + }; diff --git a/src/KOKKOS/pair_snap_kokkos_impl.h b/src/KOKKOS/pair_snap_kokkos_impl.h index c365d03c90..721ce5352d 100644 --- a/src/KOKKOS/pair_snap_kokkos_impl.h +++ b/src/KOKKOS/pair_snap_kokkos_impl.h @@ -245,14 +245,14 @@ void PairSNAPKokkos::compute(int eflag_in, // tile_size_compute_ck is defined in `pair_snap_kokkos.h` Snap3DRangePolicy policy_compute_ck({0,0,0},{vector_length,max_neighs,chunk_size_div},{vector_length,tile_size_compute_ck,1}); - Kokkos::parallel_for("ComputeCayleyKlein",policy_compute_ck,*this); + Kokkos::parallel_for("ComputeCayleyKlein", policy_compute_ck, *this); } // PreUi; same CPU and GPU codepath { auto policy_pre_ui = snap_get_policy(chunk_size_div, twojmax + 1); //typename Kokkos::RangePolicy policy_preui_cpu(0, chunk_size * (twojmax + 1)); - Kokkos::parallel_for("PreUi",policy_pre_ui,*this); + Kokkos::parallel_for("PreUi", policy_pre_ui, *this); } // ComputeUi; separate CPU, GPU codepaths @@ -531,8 +531,8 @@ void PairSNAPKokkos::coeff(int narg, char Kokkos::deep_copy(d_dinnerelem,h_dinnerelem); Kokkos::deep_copy(d_map,h_map); - snaKK = SNAKokkos(rfac0,twojmax, - rmin0,switchflag,bzeroflag,chemflag,bnormflag,wselfallflag,nelements,switchinnerflag); + snaKK = SNAKokkos(*this); //rfac0,twojmax, + //rmin0,switchflag,bzeroflag,chemflag,bnormflag,wselfallflag,nelements,switchinnerflag); snaKK.grow_rij(0,0); snaKK.init(); } diff --git a/src/KOKKOS/sna_kokkos.h b/src/KOKKOS/sna_kokkos.h index 92f413ed17..24dfa1f4ac 100644 --- a/src/KOKKOS/sna_kokkos.h +++ b/src/KOKKOS/sna_kokkos.h @@ -143,6 +143,7 @@ class SNAKokkos { typedef Kokkos::View t_sna_2i; typedef Kokkos::View t_sna_2d; typedef Kokkos::View t_sna_2d_ll; + typedef Kokkos::View t_sna_2d_lr; typedef Kokkos::View t_sna_3d; typedef Kokkos::View t_sna_3d_ll; typedef Kokkos::View t_sna_4d; @@ -170,7 +171,8 @@ class SNAKokkos { SNAKokkos(const SNAKokkos& sna, const typename Kokkos::TeamPolicy::member_type& team); inline - SNAKokkos(real_type, int, real_type, int, int, int, int, int, int, int); + //SNAKokkos(real_type, int, real_type, int, int, int, int, int, int, int); + SNAKokkos(const PairSNAPKokkos&); KOKKOS_INLINE_FUNCTION ~SNAKokkos(); @@ -282,7 +284,12 @@ class SNAKokkos { int twojmax, diagonalstyle; + // Input beta coefficients; aliases the object in PairSnapKokkos + t_sna_2d_lr d_coeffelem; + // Beta for all atoms in list; aliases the object in PairSnapKokkos + // for qSNAP the quadratic terms get accumulated into it + // in compute_bi t_sna_2d d_beta; // Structures for both the CPU, GPU backend @@ -379,6 +386,9 @@ class SNAKokkos { real_type wself; int wselfall_flag; + // quadratic flag + int quadratic_flag; + int bzero_flag; // 1 if bzero subtracted from barray Kokkos::View bzero; // array of B values for isolated atoms }; diff --git a/src/KOKKOS/sna_kokkos_impl.h b/src/KOKKOS/sna_kokkos_impl.h index 4cdd37d1f5..90232d1333 100644 --- a/src/KOKKOS/sna_kokkos_impl.h +++ b/src/KOKKOS/sna_kokkos_impl.h @@ -30,27 +30,18 @@ static const double MY_PI2 = 1.57079632679489661923; // pi/2 template inline -SNAKokkos::SNAKokkos(real_type rfac0_in, - int twojmax_in, real_type rmin0_in, int switch_flag_in, int bzero_flag_in, - int chem_flag_in, int bnorm_flag_in, int wselfall_flag_in, int nelements_in, int switch_inner_flag_in) +SNAKokkos::SNAKokkos(const PairSNAPKokkos& psk) + : rfac0(psk.rfac0), rmin0(psk.rmin0), switch_flag(psk.switchflag), + bzero_flag(psk.bzeroflag), chem_flag(psk.chemflag), bnorm_flag(psk.bnormflag), + wselfall_flag(psk.wselfallflag), switch_inner_flag(psk.switchinnerflag), + quadratic_flag(psk.quadraticflag), twojmax(psk.twojmax), d_coeffelem(psk.d_coeffelem) { wself = static_cast(1.0); - rfac0 = rfac0_in; - rmin0 = rmin0_in; - switch_flag = switch_flag_in; - switch_inner_flag = switch_inner_flag_in; - bzero_flag = bzero_flag_in; - - chem_flag = chem_flag_in; if (chem_flag) - nelements = nelements_in; + nelements = psk.nelements; else nelements = 1; - bnorm_flag = bnorm_flag_in; - wselfall_flag = wselfall_flag_in; - - twojmax = twojmax_in; ncoeff = compute_ncoeff();