Simplified the constuctor for SNAKokkos by passing PairSNAPKokkos in directly by reference
This commit is contained in:
@ -30,9 +30,15 @@ PairStyle(snap/kk/host,PairSNAPKokkosDevice<LMPHostType>);
|
||||
#include "pair_snap.h"
|
||||
#include "kokkos_type.h"
|
||||
#include "neigh_list_kokkos.h"
|
||||
#include "sna_kokkos.h"
|
||||
#include "pair_kokkos.h"
|
||||
|
||||
namespace LAMMPS_NS {
|
||||
// pre-declare so sna_kokkos.h can refer to it
|
||||
template<class DeviceType, typename real_type_, int vector_length_> class PairSNAPKokkos;
|
||||
};
|
||||
|
||||
#include "sna_kokkos.h"
|
||||
|
||||
namespace LAMMPS_NS {
|
||||
|
||||
// Routines for both the CPU and GPU backend
|
||||
@ -262,7 +268,7 @@ class PairSNAPKokkos : public PairSNAP {
|
||||
|
||||
Kokkos::View<real_type*, DeviceType> d_radelem; // element radii
|
||||
Kokkos::View<real_type*, DeviceType> d_wjelem; // elements weights
|
||||
Kokkos::View<real_type**, Kokkos::LayoutRight, DeviceType> d_coeffelem; // element bispectrum coefficients
|
||||
typename SNAKokkos<DeviceType, real_type, vector_length>::t_sna_2d_lr d_coeffelem; // element bispectrum coefficients
|
||||
Kokkos::View<real_type*, DeviceType> d_sinnerelem; // element inner cutoff midpoint
|
||||
Kokkos::View<real_type*, DeviceType> d_dinnerelem; // element inner cutoff half-width
|
||||
Kokkos::View<T_INT*, DeviceType> d_map; // mapping from atom types to elements
|
||||
@ -302,6 +308,9 @@ class PairSNAPKokkos : public PairSNAP {
|
||||
template <typename scratch_type>
|
||||
int scratch_size_helper(int values_per_team);
|
||||
|
||||
// Make SNAKokkos a friend
|
||||
friend class SNAKokkos<DeviceType, real_type, vector_length>;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
@ -245,14 +245,14 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::compute(int eflag_in,
|
||||
// tile_size_compute_ck is defined in `pair_snap_kokkos.h`
|
||||
Snap3DRangePolicy<DeviceType, tile_size_compute_ck, TagPairSNAPComputeCayleyKlein>
|
||||
policy_compute_ck({0,0,0},{vector_length,max_neighs,chunk_size_div},{vector_length,tile_size_compute_ck,1});
|
||||
Kokkos::parallel_for("ComputeCayleyKlein",policy_compute_ck,*this);
|
||||
Kokkos::parallel_for("ComputeCayleyKlein", policy_compute_ck, *this);
|
||||
}
|
||||
|
||||
// PreUi; same CPU and GPU codepath
|
||||
{
|
||||
auto policy_pre_ui = snap_get_policy<DeviceType, tile_size_pre_ui, TagPairSNAPPreUi>(chunk_size_div, twojmax + 1);
|
||||
//typename Kokkos::RangePolicy<DeviceType,TagPairSNAPPreUi> policy_preui_cpu(0, chunk_size * (twojmax + 1));
|
||||
Kokkos::parallel_for("PreUi",policy_pre_ui,*this);
|
||||
Kokkos::parallel_for("PreUi", policy_pre_ui, *this);
|
||||
}
|
||||
|
||||
// ComputeUi; separate CPU, GPU codepaths
|
||||
@ -531,8 +531,8 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::coeff(int narg, char
|
||||
Kokkos::deep_copy(d_dinnerelem,h_dinnerelem);
|
||||
Kokkos::deep_copy(d_map,h_map);
|
||||
|
||||
snaKK = SNAKokkos<DeviceType, real_type, vector_length>(rfac0,twojmax,
|
||||
rmin0,switchflag,bzeroflag,chemflag,bnormflag,wselfallflag,nelements,switchinnerflag);
|
||||
snaKK = SNAKokkos<DeviceType, real_type, vector_length>(*this); //rfac0,twojmax,
|
||||
//rmin0,switchflag,bzeroflag,chemflag,bnormflag,wselfallflag,nelements,switchinnerflag);
|
||||
snaKK.grow_rij(0,0);
|
||||
snaKK.init();
|
||||
}
|
||||
|
||||
@ -143,6 +143,7 @@ class SNAKokkos {
|
||||
typedef Kokkos::View<int**, DeviceType> t_sna_2i;
|
||||
typedef Kokkos::View<real_type**, DeviceType> t_sna_2d;
|
||||
typedef Kokkos::View<real_type**, Kokkos::LayoutLeft, DeviceType> t_sna_2d_ll;
|
||||
typedef Kokkos::View<real_type**, Kokkos::LayoutRight, DeviceType> t_sna_2d_lr;
|
||||
typedef Kokkos::View<real_type***, DeviceType> t_sna_3d;
|
||||
typedef Kokkos::View<real_type***, Kokkos::LayoutLeft, DeviceType> t_sna_3d_ll;
|
||||
typedef Kokkos::View<real_type***[3], DeviceType> t_sna_4d;
|
||||
@ -170,7 +171,8 @@ class SNAKokkos {
|
||||
SNAKokkos(const SNAKokkos<DeviceType,real_type,vector_length>& sna, const typename Kokkos::TeamPolicy<DeviceType>::member_type& team);
|
||||
|
||||
inline
|
||||
SNAKokkos(real_type, int, real_type, int, int, int, int, int, int, int);
|
||||
//SNAKokkos(real_type, int, real_type, int, int, int, int, int, int, int);
|
||||
SNAKokkos(const PairSNAPKokkos<DeviceType, real_type, vector_length>&);
|
||||
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
~SNAKokkos();
|
||||
@ -282,7 +284,12 @@ class SNAKokkos {
|
||||
|
||||
int twojmax, diagonalstyle;
|
||||
|
||||
// Input beta coefficients; aliases the object in PairSnapKokkos
|
||||
t_sna_2d_lr d_coeffelem;
|
||||
|
||||
// Beta for all atoms in list; aliases the object in PairSnapKokkos
|
||||
// for qSNAP the quadratic terms get accumulated into it
|
||||
// in compute_bi
|
||||
t_sna_2d d_beta;
|
||||
|
||||
// Structures for both the CPU, GPU backend
|
||||
@ -379,6 +386,9 @@ class SNAKokkos {
|
||||
real_type wself;
|
||||
int wselfall_flag;
|
||||
|
||||
// quadratic flag
|
||||
int quadratic_flag;
|
||||
|
||||
int bzero_flag; // 1 if bzero subtracted from barray
|
||||
Kokkos::View<real_type*, DeviceType> bzero; // array of B values for isolated atoms
|
||||
};
|
||||
|
||||
@ -30,27 +30,18 @@ static const double MY_PI2 = 1.57079632679489661923; // pi/2
|
||||
|
||||
template<class DeviceType, typename real_type, int vector_length>
|
||||
inline
|
||||
SNAKokkos<DeviceType, real_type, vector_length>::SNAKokkos(real_type rfac0_in,
|
||||
int twojmax_in, real_type rmin0_in, int switch_flag_in, int bzero_flag_in,
|
||||
int chem_flag_in, int bnorm_flag_in, int wselfall_flag_in, int nelements_in, int switch_inner_flag_in)
|
||||
SNAKokkos<DeviceType, real_type, vector_length>::SNAKokkos(const PairSNAPKokkos<DeviceType, real_type, vector_length>& psk)
|
||||
: rfac0(psk.rfac0), rmin0(psk.rmin0), switch_flag(psk.switchflag),
|
||||
bzero_flag(psk.bzeroflag), chem_flag(psk.chemflag), bnorm_flag(psk.bnormflag),
|
||||
wselfall_flag(psk.wselfallflag), switch_inner_flag(psk.switchinnerflag),
|
||||
quadratic_flag(psk.quadraticflag), twojmax(psk.twojmax), d_coeffelem(psk.d_coeffelem)
|
||||
{
|
||||
wself = static_cast<real_type>(1.0);
|
||||
|
||||
rfac0 = rfac0_in;
|
||||
rmin0 = rmin0_in;
|
||||
switch_flag = switch_flag_in;
|
||||
switch_inner_flag = switch_inner_flag_in;
|
||||
bzero_flag = bzero_flag_in;
|
||||
|
||||
chem_flag = chem_flag_in;
|
||||
if (chem_flag)
|
||||
nelements = nelements_in;
|
||||
nelements = psk.nelements;
|
||||
else
|
||||
nelements = 1;
|
||||
bnorm_flag = bnorm_flag_in;
|
||||
wselfall_flag = wselfall_flag_in;
|
||||
|
||||
twojmax = twojmax_in;
|
||||
|
||||
ncoeff = compute_ncoeff();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user