factor out variable lcomp

This commit is contained in:
Megan J. McCarthy
2023-06-10 09:11:09 -06:00
parent 9eb32fc6b0
commit 38fd78b867
2 changed files with 23 additions and 27 deletions

View File

@ -59,7 +59,7 @@ ComputeLocalCompAtomKokkos<DeviceType>::~ComputeLocalCompAtomKokkos()
if (copymode) return; if (copymode) return;
memoryKK->destroy_kokkos(k_result,result); memoryKK->destroy_kokkos(k_result,result);
memoryKK->destroy_kokkos(k_lcomp,lcomp); // memoryKK->destroy_kokkos(k_lcomp,lcomp);
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
@ -94,15 +94,10 @@ void ComputeLocalCompAtomKokkos<DeviceType>::compute_peratom()
memoryKK->create_kokkos(k_result,result,nmax,size_peratom_cols,"local/comp/atom:result"); memoryKK->create_kokkos(k_result,result,nmax,size_peratom_cols,"local/comp/atom:result");
d_result = k_result.view<DeviceType>(); d_result = k_result.view<DeviceType>();
array_atom = result; array_atom = result;
memoryKK->destroy_kokkos(k_lcomp,lcomp);
nmax = atom->nmax;
memoryKK->create_kokkos(k_lcomp,lcomp,nmax,"local/comp/atom:result");
d_lcomp = k_lcomp.view<DeviceType>();
} }
memoryKK->create_kokkos(k_lcomp,lcomp,size_peratom_cols,"local/comp/atom:lcomp"); // memoryKK->create_kokkos(k_lcomp,lcomp,size_peratom_cols,"local/comp/atom:lcomp");
d_lcomp = k_lcomp.view<DeviceType>(); // d_lcomp = k_lcomp.view<DeviceType>();
// invoke full neighbor list (will copy or build if necessary) // invoke full neighbor list (will copy or build if necessary)
@ -122,9 +117,8 @@ void ComputeLocalCompAtomKokkos<DeviceType>::compute_peratom()
type = atomKK->k_type.view<DeviceType>(); type = atomKK->k_type.view<DeviceType>();
mask = atomKK->k_mask.view<DeviceType>(); mask = atomKK->k_mask.view<DeviceType>();
ntypes = atom->ntypes; ntypes = atom->ntypes;
Kokkos::deep_copy(d_result,0.0); Kokkos::deep_copy(d_result,0.0);
Kokkos::deep_copy(d_lcomp,0.0);
copymode = 1; copymode = 1;
typename Kokkos::RangePolicy<DeviceType, TagComputeLocalCompAtom> policy(0,inum); typename Kokkos::RangePolicy<DeviceType, TagComputeLocalCompAtom> policy(0,inum);
@ -134,8 +128,6 @@ void ComputeLocalCompAtomKokkos<DeviceType>::compute_peratom()
k_result.modify<DeviceType>(); k_result.modify<DeviceType>();
k_result.sync_host(); k_result.sync_host();
k_lcomp.modify<DeviceType>();
k_lcomp.sync_host();
} }
template<class DeviceType> template<class DeviceType>
@ -145,13 +137,13 @@ void ComputeLocalCompAtomKokkos<DeviceType>::operator()(TagComputeLocalCompAtom,
const int i = d_ilist[ii]; const int i = d_ilist[ii];
// initialize / reset lcomp
for (int m = 0; m < ntypes; m++) d_lcomp(m) = 0;
// for (int m = 0; m < ntypes; m++) d_result(i,m) = 0;
if (mask[i] & groupbit) { if (mask[i] & groupbit) {
// initialize / reset lcomp
// for (int m = 0; m < ntypes; m++) d_lcomp(m) = 0;
// for (int m = 0; m < size_peratom_cols; m++) d_result(i,m) = 0.0;
const X_FLOAT xtmp = x(i,0); const X_FLOAT xtmp = x(i,0);
const X_FLOAT ytmp = x(i,1); const X_FLOAT ytmp = x(i,1);
const X_FLOAT ztmp = x(i,2); const X_FLOAT ztmp = x(i,2);
@ -161,8 +153,10 @@ void ComputeLocalCompAtomKokkos<DeviceType>::operator()(TagComputeLocalCompAtom,
int count = 1; int count = 1;
int itype = type[i]; int itype = type[i];
d_lcomp(itype-1)++;
// d_result(i,itype-1)++; // d_lcomp(itype-1)++;
// d_result(i,itype-1) = d_result(i,itype-1) + 1;
d_result(i,itype)++;
for (int jj = 0; jj < jnum; jj++) { for (int jj = 0; jj < jnum; jj++) {
@ -177,8 +171,9 @@ void ComputeLocalCompAtomKokkos<DeviceType>::operator()(TagComputeLocalCompAtom,
const F_FLOAT rsq = delx*delx + dely*dely + delz*delz; const F_FLOAT rsq = delx*delx + dely*dely + delz*delz;
if (rsq < cutsq) { if (rsq < cutsq) {
count++; count++;
d_lcomp(jtype-1)++; // d_lcomp(jtype-1)++;
// d_result(i,jtype-1)++; // d_result(i,jtype) = d_result(i,jtype) + 1;
d_result(i,jtype)++;
} }
// total count of atoms found in sampled radius range // total count of atoms found in sampled radius range
@ -189,10 +184,11 @@ void ComputeLocalCompAtomKokkos<DeviceType>::operator()(TagComputeLocalCompAtom,
double lfac = 1.0 / count; double lfac = 1.0 / count;
for (int n = 0; n < ntypes; n++) { // for (int n = 1; n < size_peratom_cols; n++) {
d_result(i,n+1) = d_lcomp(n+1) * lfac; // // d_result(i,n+1) = d_lcomp(n+1) * lfac;
// d_result(i,n+1) = d_result(i,n+1) * lfac; // d_result(i,n) = d_result(i,n) * lfac;
} // // d_result(i,n+1) = 123.0;
// }
} }
} }

View File

@ -55,9 +55,9 @@ template <class DeviceType> class ComputeLocalCompAtomKokkos : public ComputeLoc
typename AT::t_int_1d d_ilist; typename AT::t_int_1d d_ilist;
typename AT::t_int_1d d_numneigh; typename AT::t_int_1d d_numneigh;
DAT::tdual_int_1d k_lcomp; // DAT::tdual_int_1d k_lcomp;
DAT::tdual_float_2d k_result; DAT::tdual_float_2d k_result;
typename AT::t_int_1d d_lcomp; // typename AT::t_int_1d d_lcomp;
typename AT::t_float_2d d_result; typename AT::t_float_2d d_result;
}; };