Unified ylist CPU and GPU structures

This commit is contained in:
Evan Weinberg
2024-11-19 11:02:32 -08:00
parent 3c4a42ba72
commit abbcd86174
3 changed files with 37 additions and 68 deletions

View File

@ -847,8 +847,8 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSN
snaKK.ulisttot_gpu(iatom, idxu, ielem) = { utot_re, utot_im };
if (mapper.flip_sign == 0) {
snaKK.ylist_re_gpu(iatom, mapper.idxu_half, ielem) = 0.;
snaKK.ylist_im_gpu(iatom, mapper.idxu_half, ielem) = 0.;
snaKK.ylist_re(iatom, ielem, mapper.idxu_half) = 0.;
snaKK.ylist_im(iatom, ielem, mapper.idxu_half) = 0.;
}
}
}
@ -1118,7 +1118,8 @@ void PairSNAPKokkos<DeviceType, real_type, vector_length>::operator() (TagPairSN
snaKK.ulisttot_full(idxu, ielem, iatom) = utot;
// Zero Yi
snaKK.ylist(idxu_half, ielem, iatom) = {0., 0.};
snaKK.ylist_re(iatom, ielem, idxu_half) = 0;
snaKK.ylist_im(iatom, ielem, idxu_half) = 0;
// Symmetric term
const int sign_factor = (((ma+mb)%2==0)?1:-1);

View File

@ -289,13 +289,17 @@ class SNAKokkos {
// Beta for all atoms in list; aliases the object in PairSnapKokkos
t_sna_2d d_beta;
// Structures for both the CPU, GPU backend
t_sna_3d ylist_re;
t_sna_3d ylist_im;
// Structures for the CPU backend only
t_sna_3d blist;
t_sna_3c_ll ulisttot;
t_sna_3c_ll ulisttot_full; // un-folded ulisttot, cpu only
t_sna_3c_ll zlist;
t_sna_3c_ll ulist;
t_sna_3c_ll ylist;
// derivatives of data
t_sna_4c3_ll dulist;
@ -312,8 +316,6 @@ class SNAKokkos {
t_sna_3c ulisttot_gpu; // packed and de-symmetrized
t_sna_3c zlist_gpu;
t_sna_3d blist_gpu;
t_sna_3d ylist_re_gpu; // split real,
t_sna_3d ylist_im_gpu; // imag
int idxcg_max, idxu_max, idxu_half_max, idxu_cache_max, idxz_max, idxb_max;
@ -409,4 +411,3 @@ class SNAKokkos {
#include "sna_kokkos_impl.h"
#endif

View File

@ -313,8 +313,10 @@ void SNAKokkos<DeviceType, real_type, vector_length>::grow_rij(int newnatom, int
MemKK::realloc_kokkos(element,"sna:element",natom_pad,nmax);
MemKK::realloc_kokkos(dedr,"sna:dedr",natom_pad,nmax,3);
if constexpr (!host_flag) {
MemKK::realloc_kokkos(ylist_re,"sna:ylist_re", natom_pad, nelements, idxu_half_max);
MemKK::realloc_kokkos(ylist_im,"sna:ylist_im", natom_pad, nelements, idxu_half_max);
if constexpr (!host_flag) {
MemKK::realloc_kokkos(a_gpu,"sna:a_gpu",natom_pad,nmax);
MemKK::realloc_kokkos(b_gpu,"sna:b_gpu",natom_pad,nmax);
MemKK::realloc_kokkos(da_gpu,"sna:da_gpu",natom_pad,nmax,3);
@ -330,9 +332,6 @@ void SNAKokkos<DeviceType, real_type, vector_length>::grow_rij(int newnatom, int
MemKK::realloc_kokkos(zlist_gpu,"sna:zlist_gpu",natom_pad,idxz_max,ndoubles);
MemKK::realloc_kokkos(blist,"sna:blist",natom_pad,ntriples,idxb_max);
MemKK::realloc_kokkos(blist_gpu,"sna:blist_gpu",natom_pad,idxb_max,ntriples);
MemKK::realloc_kokkos(ylist,"sna:ylist",1,1,1);
MemKK::realloc_kokkos(ylist_re_gpu,"sna:ylist_re_gpu",natom_pad,idxu_half_max,nelements);
MemKK::realloc_kokkos(ylist_im_gpu,"sna:ylist_im_gpu",natom_pad,idxu_half_max,nelements);
MemKK::realloc_kokkos(dulist,"sna:dulist",1,1,1);
} else {
MemKK::realloc_kokkos(a_gpu,"sna:a_gpu",1,1);
@ -350,9 +349,6 @@ void SNAKokkos<DeviceType, real_type, vector_length>::grow_rij(int newnatom, int
MemKK::realloc_kokkos(zlist_gpu,"sna:zlist_gpu",1,1,1);
MemKK::realloc_kokkos(blist,"sna:blist",natom_pad,ntriples,idxb_max);
MemKK::realloc_kokkos(blist_gpu,"sna:blist_gpu",1,1,1);
MemKK::realloc_kokkos(ylist,"sna:ylist",idxu_half_max,nelements,natom_pad);
MemKK::realloc_kokkos(ylist_re_gpu,"sna:ylist_pack_re",1,1,1);
MemKK::realloc_kokkos(ylist_im_gpu,"sna:ylist_pack_im",1,1,1);
MemKK::realloc_kokkos(dulist,"sna:dulist",idxu_cache_max,natom_pad,nmax);
}
@ -806,8 +802,8 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi(const int& iato
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
Kokkos::atomic_add(&(ylist_re_gpu(iatom, jju_half, elem3)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im_gpu(iatom, jju_half, elem3)), betaj * ztmp.im);
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, elem3, jju_half)), betaj * ztmp.im);
} // end loop over elem3
} // end loop over elem2
} // end loop over elem1
@ -839,8 +835,8 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_with_zlist(cons
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
Kokkos::atomic_add(&(ylist_re_gpu(iatom, jju_half, elem3)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im_gpu(iatom, jju_half, elem3)), betaj * ztmp.im);
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj * ztmp.re);
Kokkos::atomic_add(&(ylist_im(iatom, elem3, jju_half)), betaj * ztmp.im);
} // end loop over elem3
idouble++;
} // end loop over elem2
@ -1096,7 +1092,7 @@ typename SNAKokkos<DeviceType, real_type, vector_length>::real_type SNAKokkos<De
// grab y_local early
// this will never be the last element of a row, no need to rescale.
complex y_local = complex(ylist_re_gpu(iatom, jjup + ma, jelem), ylist_im_gpu(iatom, jjup+ma, jelem));
complex y_local = complex(ylist_re(iatom, jelem, jjup + ma), ylist_im(iatom, jelem, jjup+ma));
// grab the cached value
const complex ulist_prev = ulist_wrapper.get(ma);
@ -1142,7 +1138,7 @@ typename SNAKokkos<DeviceType, real_type, vector_length>::real_type SNAKokkos<De
for (int ma = 0; ma < j; ma++) {
// grab y_local early
complex y_local = complex(ylist_re_gpu(iatom, jjup + ma, jelem), ylist_im_gpu(iatom, jjup+ma, jelem));
complex y_local = complex(ylist_re(iatom, jelem, jjup + ma), ylist_im(iatom, jelem, jjup+ma));
if (j % 2 == 1 && 2*(mb-1) == j-1) { // double check me...
if (ma == (mb-1)) { y_local = static_cast<real_type>(0.5)*y_local; }
else if (ma > (mb-1)) { y_local.re = static_cast<real_type>(0.); y_local.im = static_cast<real_type>(0.); } // can probably avoid this outright
@ -1459,30 +1455,10 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_yi_cpu(int iter) c
// pick out right beta value
for (int elem3 = 0; elem3 < nelements; elem3++) {
const real_type betaj = evaluate_beta_scaled(j1, j2, j, iatom, elem1, elem2, elem3);
if (j >= j1) {
const int jjb = idxb_block(j1, j2, j);
const int itriple = ((elem1 * nelements + elem2) * nelements + elem3) * idxb_max + jjb;
if (j1 == j) {
if (j2 == j) betaj = 3 * d_beta(iatom, itriple);
else betaj = 2 * d_beta(iatom, itriple);
} else betaj = d_beta(iatom, itriple);
} else if (j >= j2) {
const int jjb = idxb_block(j, j2, j1);
const int itriple = ((elem3 * nelements + elem2) * nelements + elem1) * idxb_max + jjb;
if (j2 == j) betaj = 2 * d_beta(iatom, itriple);
else betaj = d_beta(iatom, itriple);
} else {
const int jjb = idxb_block(j2, j, j1);
const int itriple = ((elem2 * nelements + elem3) * nelements + elem1) * idxb_max + jjb;
betaj = d_beta(iatom, itriple);
}
if (!bnorm_flag && j1 > j)
betaj *= static_cast<real_type>(j1 + 1) / static_cast<real_type>(j + 1);
Kokkos::atomic_add(&(ylist(jju_half, elem3, iatom).re), betaj*ztmp_r);
Kokkos::atomic_add(&(ylist(jju_half, elem3, iatom).im), betaj*ztmp_i);
Kokkos::atomic_add(&(ylist_re(iatom, elem3, jju_half)), betaj*ztmp_r);
Kokkos::atomic_add(&(ylist_im(iatom, elem3, jju_half)), betaj*ztmp_i);
} // end loop over elem3
} // end loop over elem2
} // end loop over elem1
@ -1541,12 +1517,10 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_deidrj_cpu(const t
for (int mb = 0; 2*mb < j; mb++)
for (int ma = 0; ma <= j; ma++) {
sum_tmp.x += dulist(jju_cache,iatom,jnbor,0).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,0).im * ylist(jju_half,jelem,iatom).im;
sum_tmp.y += dulist(jju_cache,iatom,jnbor,1).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,1).im * ylist(jju_half,jelem,iatom).im;
sum_tmp.z += dulist(jju_cache,iatom,jnbor,2).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,2).im * ylist(jju_half,jelem,iatom).im;
const complex y_val = { ylist_re(iatom, jelem, jju_half), ylist_im(iatom, jelem, jju_half) };
sum_tmp.x += dulist(jju_cache,iatom,jnbor,0).re * y_val.re + dulist(jju_cache,iatom,jnbor,0).im * y_val.im;
sum_tmp.y += dulist(jju_cache,iatom,jnbor,1).re * y_val.re + dulist(jju_cache,iatom,jnbor,1).im * y_val.im;
sum_tmp.z += dulist(jju_cache,iatom,jnbor,2).re * y_val.re + dulist(jju_cache,iatom,jnbor,2).im * y_val.im;
jju_half++; jju_cache++;
} //end loop over ma mb
@ -1556,22 +1530,19 @@ void SNAKokkos<DeviceType, real_type, vector_length>::compute_deidrj_cpu(const t
int mb = j/2;
for (int ma = 0; ma < mb; ma++) {
sum_tmp.x += dulist(jju_cache,iatom,jnbor,0).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,0).im * ylist(jju_half,jelem,iatom).im;
sum_tmp.y += dulist(jju_cache,iatom,jnbor,1).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,1).im * ylist(jju_half,jelem,iatom).im;
sum_tmp.z += dulist(jju_cache,iatom,jnbor,2).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,2).im * ylist(jju_half,jelem,iatom).im;
const complex y_val = { ylist_re(iatom, jelem, jju_half), ylist_im(iatom, jelem, jju_half) };
sum_tmp.x += dulist(jju_cache,iatom,jnbor,0).re * y_val.re + dulist(jju_cache,iatom,jnbor,0).im * y_val.im;
sum_tmp.y += dulist(jju_cache,iatom,jnbor,1).re * y_val.re + dulist(jju_cache,iatom,jnbor,1).im * y_val.im;
sum_tmp.z += dulist(jju_cache,iatom,jnbor,2).re * y_val.re + dulist(jju_cache,iatom,jnbor,2).im * y_val.im;
jju_half++; jju_cache++;
}
//int ma = mb;
sum_tmp.x += (dulist(jju_cache,iatom,jnbor,0).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,0).im * ylist(jju_half,jelem,iatom).im)*0.5;
sum_tmp.y += (dulist(jju_cache,iatom,jnbor,1).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,1).im * ylist(jju_half,jelem,iatom).im)*0.5;
sum_tmp.z += (dulist(jju_cache,iatom,jnbor,2).re * ylist(jju_half,jelem,iatom).re +
dulist(jju_cache,iatom,jnbor,2).im * ylist(jju_half,jelem,iatom).im)*0.5;
// 0.5 is meant to avoid double-counting
const complex y_val = { 0.5 * ylist_re(iatom, jelem, jju_half), 0.5 * ylist_im(iatom, jelem, jju_half) };
sum_tmp.x += dulist(jju_cache,iatom,jnbor,0).re * y_val.re + dulist(jju_cache,iatom,jnbor,0).im * y_val.im;
sum_tmp.y += dulist(jju_cache,iatom,jnbor,1).re * y_val.re + dulist(jju_cache,iatom,jnbor,1).im * y_val.im;
sum_tmp.z += dulist(jju_cache,iatom,jnbor,2).re * y_val.re + dulist(jju_cache,iatom,jnbor,2).im * y_val.im;
} // end if jeven
},final_sum); // end loop over j
@ -2328,8 +2299,10 @@ double SNAKokkos<DeviceType, real_type, vector_length>::memory_usage()
bytes += MemKK::memory_usage(rootpqarray);
bytes += MemKK::memory_usage(cglist);
if constexpr (!host_flag) {
bytes += MemKK::memory_usage(ylist_re);
bytes += MemKK::memory_usage(ylist_im);
if constexpr (!host_flag) {
bytes += MemKK::memory_usage(a_gpu);
bytes += MemKK::memory_usage(b_gpu);
bytes += MemKK::memory_usage(da_gpu);
@ -2343,11 +2316,7 @@ double SNAKokkos<DeviceType, real_type, vector_length>::memory_usage()
bytes += MemKK::memory_usage(zlist_gpu);
bytes += MemKK::memory_usage(blist_gpu);
bytes += MemKK::memory_usage(ylist_re_gpu);
bytes += MemKK::memory_usage(ylist_im_gpu);
} else {
bytes += MemKK::memory_usage(ulist);
bytes += MemKK::memory_usage(ulisttot);
bytes += MemKK::memory_usage(ulisttot_full);
@ -2355,8 +2324,6 @@ double SNAKokkos<DeviceType, real_type, vector_length>::memory_usage()
bytes += MemKK::memory_usage(zlist);
bytes += MemKK::memory_usage(blist);
bytes += MemKK::memory_usage(ylist);
bytes += MemKK::memory_usage(dulist);
}