diff --git a/src/KOKKOS/npair_ssa_kokkos.cpp b/src/KOKKOS/npair_ssa_kokkos.cpp index b73e54e33f..9f447bda1a 100644 --- a/src/KOKKOS/npair_ssa_kokkos.cpp +++ b/src/KOKKOS/npair_ssa_kokkos.cpp @@ -70,6 +70,7 @@ void NPairSSAKokkos::copy_neighbor_info() k_ex2_bit = neighborKK->k_ex2_bit; k_ex_mol_group = neighborKK->k_ex_mol_group; k_ex_mol_bit = neighborKK->k_ex_mol_bit; + k_ex_mol_intra = neighborKK->k_ex_mol_intra; } /* ---------------------------------------------------------------------- @@ -217,8 +218,12 @@ int NPairSSAKokkosExecute::exclusion(const int &i,const int &j, if (nex_mol) { for (m = 0; m < nex_mol; m++) - if (mask(i) & ex_mol_bit(m) && mask(j) & ex_mol_bit(m) && - molecule(i) == molecule(j)) return 1; + if (ex_mol_intra[m]) { // intra-chain: exclude i-j pair if on same molecule + if (mask[i] & ex_mol_bit[m] && mask[j] & ex_mol_bit[m] && + molecule[i] == molecule[j]) return 1; + } else // exclude i-j pair if on different molecules + if (mask[i] & ex_mol_bit[m] && mask[j] & ex_mol_bit[m] && + molecule[i] != molecule[j]) return 1; } return 0; @@ -418,6 +423,7 @@ fprintf(stdout, "tota%03d total %3d could use %6d inums, expected %6d inums. inu nex_mol, k_ex_mol_group.view(), k_ex_mol_bit.view(), + k_ex_mol_intra.view(), bboxhi,bboxlo, domain->xperiodic,domain->yperiodic,domain->zperiodic, domain->xprd_half,domain->yprd_half,domain->zprd_half); @@ -432,6 +438,7 @@ fprintf(stdout, "tota%03d total %3d could use %6d inums, expected %6d inums. inu k_ex2_bit.sync(); k_ex_mol_group.sync(); k_ex_mol_bit.sync(); + k_ex_mol_intra.sync(); k_bincount.sync(); k_bins.sync(); k_gbincount.sync(); diff --git a/src/KOKKOS/npair_ssa_kokkos.h b/src/KOKKOS/npair_ssa_kokkos.h index 98046feba8..17a23b2811 100644 --- a/src/KOKKOS/npair_ssa_kokkos.h +++ b/src/KOKKOS/npair_ssa_kokkos.h @@ -76,6 +76,7 @@ class NPairSSAKokkos : public NPair { DAT::tdual_int_1d k_ex1_bit,k_ex2_bit; DAT::tdual_int_1d k_ex_mol_group; DAT::tdual_int_1d k_ex_mol_bit; + DAT::tdual_int_1d k_ex_mol_intra; // data from NBinSSA class @@ -123,6 +124,7 @@ class NPairSSAKokkosExecute const int nex_mol; const typename AT::t_int_1d_const ex_mol_group; const typename AT::t_int_1d_const ex_mol_bit; + const typename AT::t_int_1d_const ex_mol_intra; // data from NBinSSA class @@ -233,6 +235,7 @@ class NPairSSAKokkosExecute const int & _nex_mol, const typename AT::t_int_1d_const & _ex_mol_group, const typename AT::t_int_1d_const & _ex_mol_bit, + const typename AT::t_int_1d_const & _ex_mol_intra, const X_FLOAT *_bboxhi, const X_FLOAT* _bboxlo, const int & _xperiodic, const int & _yperiodic, const int & _zperiodic, const int & _xprd_half, const int & _yprd_half, const int & _zprd_half): @@ -266,6 +269,7 @@ class NPairSSAKokkosExecute ex1_group(_ex1_group),ex2_group(_ex2_group), ex1_bit(_ex1_bit),ex2_bit(_ex2_bit),nex_mol(_nex_mol), ex_mol_group(_ex_mol_group),ex_mol_bit(_ex_mol_bit), + ex_mol_intra(_ex_mol_intra), xperiodic(_xperiodic),yperiodic(_yperiodic),zperiodic(_zperiodic), xprd_half(_xprd_half),yprd_half(_yprd_half),zprd_half(_zprd_half) {