diff --git a/src/KOKKOS/pair_snap_kokkos_impl.h b/src/KOKKOS/pair_snap_kokkos_impl.h index ef01ec5ea3..e8a5c115c9 100644 --- a/src/KOKKOS/pair_snap_kokkos_impl.h +++ b/src/KOKKOS/pair_snap_kokkos_impl.h @@ -198,7 +198,7 @@ void PairSNAPKokkos::compute(int eflag_in, int vflag_in) d_ninside = Kokkos::View("PairSNAPKokkos:ninside",inum); } - int chunk_size = MIN(2000,inum); + chunk_size = MIN(chunk_size,inum); chunk_offset = 0; snaKK.grow_rij(chunk_size,max_neighs); @@ -221,7 +221,7 @@ void PairSNAPKokkos::compute(int eflag_in, int vflag_in) Kokkos::parallel_for("PreUi",policy_preui,*this); //ComputeUi - typename Kokkos::TeamPolicy policy_ui(((inum+team_size-1)/team_size)*max_neighs,team_size,vector_length); + typename Kokkos::TeamPolicy policy_ui(((chunk_size+team_size-1)/team_size)*max_neighs,team_size,vector_length); Kokkos::parallel_for("ComputeUi",policy_ui,*this); //Ulisttot transpose @@ -253,11 +253,11 @@ void PairSNAPKokkos::compute(int eflag_in, int vflag_in) Kokkos::parallel_for("ComputeYi",policy_yi,*this); //ComputeDuidrj - typename Kokkos::TeamPolicy policy_duidrj(((inum+team_size-1)/team_size)*max_neighs,team_size,vector_length); + typename Kokkos::TeamPolicy policy_duidrj(((chunk_size+team_size-1)/team_size)*max_neighs,team_size,vector_length); Kokkos::parallel_for("ComputeDuidrj",policy_duidrj,*this); //ComputeDeidrj - typename Kokkos::TeamPolicy policy_deidrj(((inum+team_size-1)/team_size)*max_neighs,team_size,vector_length); + typename Kokkos::TeamPolicy policy_deidrj(((chunk_size+team_size-1)/team_size)*max_neighs,team_size,vector_length); Kokkos::parallel_for("ComputeDeidrj",policy_deidrj,*this); //ComputeForce @@ -514,11 +514,12 @@ void PairSNAPKokkos::operator() (TagPairSNAPComputeUi,const typename SNAKokkos my_sna = snaKK; // Extract the atom number - int ii = team.team_rank() + team.team_size() * (team.league_rank() % ((inum+team.team_size()-1)/team.team_size())); - if (ii >= inum) return; + int ii = team.team_rank() + team.team_size() * (team.league_rank() % + ((chunk_size+team.team_size()-1)/team.team_size())); + if (ii >= chunk_size) return; // Extract the neighbor number - const int jj = team.league_rank() / ((inum+team.team_size()-1)/team.team_size()); + const int jj = team.league_rank() / ((chunk_size+team.team_size()-1)/team.team_size()); const int ninside = d_ninside(ii); if (jj >= ninside) return; @@ -560,11 +561,12 @@ void PairSNAPKokkos::operator() (TagPairSNAPComputeDuidrj,const type SNAKokkos my_sna = snaKK; // Extract the atom number - int ii = team.team_rank() + team.team_size() * (team.league_rank() % ((inum+team.team_size()-1)/team.team_size())); - if (ii >= inum) return; + int ii = team.team_rank() + team.team_size() * (team.league_rank() % + ((chunk_size+team.team_size()-1)/team.team_size())); + if (ii >= chunk_size) return; // Extract the neighbor number - const int jj = team.league_rank() / ((inum+team.team_size()-1)/team.team_size()); + const int jj = team.league_rank() / ((chunk_size+team.team_size()-1)/team.team_size()); const int ninside = d_ninside(ii); if (jj >= ninside) return; @@ -577,11 +579,12 @@ void PairSNAPKokkos::operator() (TagPairSNAPComputeDeidrj,const type SNAKokkos my_sna = snaKK; // Extract the atom number - int ii = team.team_rank() + team.team_size() * (team.league_rank() % ((inum+team.team_size()-1)/team.team_size())); - if (ii >= inum) return; + int ii = team.team_rank() + team.team_size() * (team.league_rank() % + ((chunk_size+team.team_size()-1)/team.team_size())); + if (ii >= chunk_size) return; // Extract the neighbor number - const int jj = team.league_rank() / ((inum+team.team_size()-1)/team.team_size()); + const int jj = team.league_rank() / ((chunk_size+team.team_size()-1)/team.team_size()); const int ninside = d_ninside(ii); if (jj >= ninside) return; diff --git a/src/SNAP/pair_snap.cpp b/src/SNAP/pair_snap.cpp index 6f7cf54659..78250c4718 100644 --- a/src/SNAP/pair_snap.cpp +++ b/src/SNAP/pair_snap.cpp @@ -635,6 +635,7 @@ void PairSNAP::read_files(char *coefffilename, char *paramfilename) switchflag = 1; bzeroflag = 1; quadraticflag = 0; + chunk_size = 2000; // open SNAP parameter file on proc 0 @@ -698,6 +699,8 @@ void PairSNAP::read_files(char *coefffilename, char *paramfilename) bzeroflag = atoi(keyval); else if (strcmp(keywd,"quadraticflag") == 0) quadraticflag = atoi(keyval); + else if (strcmp(keywd,"chunksize") == 0) + chunk_size = atoi(keyval); else error->all(FLERR,"Incorrect SNAP parameter file"); } diff --git a/src/SNAP/pair_snap.h b/src/SNAP/pair_snap.h index c64eaa5d4e..0c387a9073 100644 --- a/src/SNAP/pair_snap.h +++ b/src/SNAP/pair_snap.h @@ -59,6 +59,7 @@ protected: double** bispectrum; // bispectrum components for all atoms in list int *map; // mapping from atom types to elements int twojmax, switchflag, bzeroflag; + int chunk_size; double rfac0, rmin0, wj1, wj2; int rcutfacflag, twojmaxflag; // flags for required parameters int beta_max; // length of beta