diff --git a/src/KOKKOS/pair_table_rx_kokkos.cpp b/src/KOKKOS/pair_table_rx_kokkos.cpp index cc0a416ad9..26e335fcff 100644 --- a/src/KOKKOS/pair_table_rx_kokkos.cpp +++ b/src/KOKKOS/pair_table_rx_kokkos.cpp @@ -92,14 +92,14 @@ void PairTableRXKokkos::compute(int eflag_in, int vflag_in) template template -PairTableRXKokkos::Full::Functor( +PairTableRXKokkos::Functor::Functor( PairTableRXKokkos* c_ptr, NeighListKokkos* list_ptr): c(*c_ptr),f(c.f),list(*list_ptr) {} template template -PairTableRXKokkos::Full::~Functor() { +PairTableRXKokkos::Functor::~Functor() { c.cleanup_copy(); list.clean_copy(); } @@ -110,7 +110,7 @@ template KOKKOS_INLINE_FUNCTION EV_FLOAT PairTableRXKokkos::Functor:: -compute_item(const int& ii) { +compute_item(const int& ii) const { EV_FLOAT ev; const int i = list.d_ilist[ii]; const X_FLOAT xtmp = c.x(i,0); @@ -125,10 +125,10 @@ compute_item(const int& ii) { double uCGnew_i = 0.0; double fx_i = 0.0, fy_i = 0.0, fz_i = 0.0; - double mixWtSite1old_i = mixWtSite1old(i); - double mixWtSite2old_i = mixWtSite2old(i); - double mixWtSite1_i = mixWtSite1(i); - double mixWtSite2_i = mixWtSite2(i); + double mixWtSite1old_i = c.mixWtSite1old_(i); + double mixWtSite2old_i = c.mixWtSite2old_(i); + double mixWtSite1_i = c.mixWtSite1_(i); + double mixWtSite2_i = c.mixWtSite2_(i); for (int jj = 0; jj < jnum; jj++) { int j = jlist(jj); @@ -142,12 +142,12 @@ compute_item(const int& ii) { const int jtype = c.type(j); if(rsq < (STACKPARAMS?c.m_cutsq[itype][jtype]:c.d_cutsq(itype,jtype))) { - double mixWtSite1old_j = mixWtSite1old[j]; - double mixWtSite2old_j = mixWtSite2old[j]; - double mixWtSite1_j = mixWtSite1[j]; - double mixWtSite2_j = mixWtSite2[j]; + double mixWtSite1old_j = c.mixWtSite1old_(j); + double mixWtSite2old_j = c.mixWtSite2old_(j); + double mixWtSite1_j = c.mixWtSite1_(j); + double mixWtSite2_j = c.mixWtSite2_(j); - const F_FLOAT fpair = factor_lj*c.template compute_fpair(rsq,i,j,itype,jtype); + const F_FLOAT fpair = factor_lj*c.template compute_fpair(rsq,i,j,itype,jtype); fx_i += delx*fpair; fy_i += dely*fpair; @@ -164,7 +164,7 @@ compute_item(const int& ii) { auto evdwl = c.template compute_evdwl(rsq,i,j,itype,jtype); double evdwlOld; - if (isite1 == isite2) { + if (c.isite1 == c.isite2) { evdwlOld = sqrt(mixWtSite1old_i*mixWtSite2old_j)*evdwl; evdwl = sqrt(mixWtSite1_i*mixWtSite2_j)*evdwl; } else { @@ -324,48 +324,42 @@ void PairTableRXKokkos::compute_style(int eflag_in, int vflag_in) mixWtSite1_(i), mixWtSite2_(i)); }); + if (neighflag == N2) error->all(FLERR,"pair table/rx/kk can't handle N2 yet\n"); + EV_FLOAT ev; if(atom->ntypes > MAX_TYPES_STACKPARAMS) { if (neighflag == FULL) { - PairComputeFunctor,FULL,false,S_TableCompute > - ff(this,(NeighListKokkos*) list); + Functor ff(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(list->inum,ff,ev); else Kokkos::parallel_for(list->inum,ff); } else if (neighflag == HALFTHREAD) { - PairComputeFunctor,HALFTHREAD,false,S_TableCompute > - ff(this,(NeighListKokkos*) list); + Functor ff(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(list->inum,ff,ev); else Kokkos::parallel_for(list->inum,ff); } else if (neighflag == HALF) { - PairComputeFunctor,HALF,false,S_TableCompute > - f(this,(NeighListKokkos*) list); + Functor f(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(list->inum,f,ev); else Kokkos::parallel_for(list->inum,f); } else if (neighflag == N2) { - PairComputeFunctor,N2,false,S_TableCompute > - f(this,(NeighListKokkos*) list); + Functor f(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(nlocal,f,ev); else Kokkos::parallel_for(nlocal,f); } } else { if (neighflag == FULL) { - PairComputeFunctor,FULL,true,S_TableCompute > - f(this,(NeighListKokkos*) list); + Functor f(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(list->inum,f,ev); else Kokkos::parallel_for(list->inum,f); } else if (neighflag == HALFTHREAD) { - PairComputeFunctor,HALFTHREAD,true,S_TableCompute > - f(this,(NeighListKokkos*) list); + Functor f(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(list->inum,f,ev); else Kokkos::parallel_for(list->inum,f); } else if (neighflag == HALF) { - PairComputeFunctor,HALF,true,S_TableCompute > - f(this,(NeighListKokkos*) list); + Functor f(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(list->inum,f,ev); else Kokkos::parallel_for(list->inum,f); } else if (neighflag == N2) { - PairComputeFunctor,N2,true,S_TableCompute > - f(this,(NeighListKokkos*) list); + Functor f(this,(NeighListKokkos*) list); if (eflag || vflag) Kokkos::parallel_reduce(nlocal,f,ev); else Kokkos::parallel_for(nlocal,f); } diff --git a/src/KOKKOS/pair_table_rx_kokkos.h b/src/KOKKOS/pair_table_rx_kokkos.h index f717dc3f8a..c468461263 100644 --- a/src/KOKKOS/pair_table_rx_kokkos.h +++ b/src/KOKKOS/pair_table_rx_kokkos.h @@ -102,11 +102,11 @@ class PairTableRXKokkos : public PairTable { void create_kokkos_tables(); void cleanup_copy(); - template + template KOKKOS_INLINE_FUNCTION F_FLOAT compute_fpair(const F_FLOAT& rsq, const int& i, const int&j, const int& itype, const int& jtype) const; - template + template KOKKOS_INLINE_FUNCTION F_FLOAT compute_evdwl(const F_FLOAT& rsq, const int& i, const int&j, const int& itype, const int& jtype) const; @@ -150,12 +150,12 @@ class PairTableRXKokkos : public PairTable { } template KOKKOS_INLINE_FUNCTION - EV_FLOAT compute_item(const int&, - const NeighListKokkos&, const NoCoulTag&) const; + EV_FLOAT compute_item(const int&) const; KOKKOS_INLINE_FUNCTION + void ev_tally(EV_FLOAT &ev, const int &i, const int &j, const F_FLOAT &epair, const F_FLOAT &fpair, const F_FLOAT &delx, - const F_FLOAT &dely, const F_FLOAT &delz) const + const F_FLOAT &dely, const F_FLOAT &delz) const; KOKKOS_INLINE_FUNCTION void operator()(const int) const; KOKKOS_INLINE_FUNCTION