Tried two ways of doing parallel reduce for fsum

This commit is contained in:
Trung Nguyen
2023-08-02 06:59:24 -05:00
parent 6a991ff0a0
commit 34c398dd37
2 changed files with 69 additions and 7 deletions

View File

@ -17,7 +17,7 @@
#include "atom_kokkos.h"
#include "update.h"
#include "modify.h"
#include "domain.h"
#include "domain_kokkos.h"
#include "region.h"
#include "input.h"
#include "variable.h"
@ -77,10 +77,11 @@ void FixEfieldKokkos<DeviceType>::init()
template<class DeviceType>
void FixEfieldKokkos<DeviceType>::post_force(int /*vflag*/)
{
atomKK->sync(execution_space, F_MASK | Q_MASK | MASK_MASK);
atomKK->sync(execution_space, F_MASK | Q_MASK | IMAGE_MASK | MASK_MASK);
f = atomKK->k_f.view<DeviceType>();
q = atomKK->k_q.view<DeviceType>();
image = atomKK->k_image.view<DeviceType>();
mask = atomKK->k_mask.view<DeviceType>();
int nlocal = atom->nlocal;
@ -113,7 +114,50 @@ void FixEfieldKokkos<DeviceType>::post_force(int /*vflag*/)
if (varflag == CONSTANT) {
copymode = 1;
Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagFixEfieldConstant>(0,nlocal),*this,fsum_kk);
//Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagFixEfieldConstant>(0,nlocal),*this,fsum_kk);
{
  // Lambda form of the constant-field kernel: adds q*E to the force on every
  // selected atom and reduces four sums into fsum_kk.
  //
  // Everything the kernel reads is copied into a local first, so the
  // LAMMPS_LAMBDA closure captures the values themselves instead of reaching
  // them through `this` or through the `domain` pointer.
  const auto prd = Few<double,3>(domain->prd);
  const auto h = Few<double,6>(domain->h);
  const auto triclinic = domain->triclinic;
  const auto l_ex = ex;
  const auto l_ey = ey;
  const auto l_ez = ez;
  const auto l_groupbit = groupbit;
  auto l_x = x;
  auto l_q = q;
  auto l_f = f;
  auto l_mask = mask;
  auto l_image = image;

  // Region filter, matching what the TagFixEfieldConstant functor does.
  // Without it, atoms outside the fix's region would also receive the field
  // and be counted in the sums.
  // NOTE(review): assumes d_match has been filled for this step whenever
  // region is set, as the functor path already requires -- confirm.
  const bool l_use_region = static_cast<bool>(region);
  auto l_d_match = d_match;

  // NOTE(review): l_x is read below, but the atomKK->sync() call visible in
  // post_force does not list X_MASK -- confirm positions are up to date in
  // this execution space.
  Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType>(0,nlocal),
    LAMMPS_LAMBDA(const int i, double_4& fsum_local) {
      // skip atoms that are not in the fix group
      if (!(l_mask[i] & l_groupbit)) return;
      // d_match is only indexed when a region is in use
      if (l_use_region && !l_d_match[i]) return;

      Few<double,3> x_i;
      x_i[0] = l_x(i,0);
      x_i[1] = l_x(i,1);
      x_i[2] = l_x(i,2);
      // position returned by DomainKokkos::unmap for this atom's image flags
      const auto unwrap = DomainKokkos::unmap(prd,h,triclinic,x_i,l_image(i));

      const auto qtmp = l_q(i);
      const auto fx = qtmp * l_ex;
      const auto fy = qtmp * l_ey;
      const auto fz = qtmp * l_ez;
      l_f(i,0) += fx;
      l_f(i,1) += fy;
      l_f(i,2) += fz;

      // d0: minus the dot product of the added force with the unmapped
      // position; d1..d3: the added force components
      fsum_local.d0 -= fx * unwrap[0] + fy * unwrap[1] + fz * unwrap[2];
      fsum_local.d1 += fx;
      fsum_local.d2 += fy;
      fsum_local.d3 += fz;
    }, fsum_kk);
}
copymode = 0;
// variable force, wrap with clear/add
@ -159,6 +203,14 @@ KOKKOS_INLINE_FUNCTION
void FixEfieldKokkos<DeviceType>::operator()(TagFixEfieldConstant, const int &i, double_4& fsum_kk) const {
if (mask[i] & groupbit) {
if (region && !d_match[i]) return;
auto prd = Few<double,3>(domain->prd);
auto h = Few<double,6>(domain->h);
auto triclinic = domain->triclinic;
Few<double,3> x_i;
x_i[0] = x(i,0);
x_i[1] = x(i,1);
x_i[2] = x(i,2);
auto unwrap = DomainKokkos::unmap(prd,h,triclinic,x_i,image(i));
const F_FLOAT qtmp = q[i];
const F_FLOAT fx = qtmp * ex;
const F_FLOAT fy = qtmp * ey;
@ -166,7 +218,8 @@ void FixEfieldKokkos<DeviceType>::operator()(TagFixEfieldConstant, const int &i,
f(i,0) += fx;
f(i,1) += fy;
f(i,2) += fz;
//fsum_kk.d0 -= fx * unwrap[0] + fy * unwrap[1] + fz * unwrap[2];
fsum_kk.d0 -= fx * unwrap[0] + fy * unwrap[1] + fz * unwrap[2];
fsum_kk.d1 += fx;
fsum_kk.d2 += fy;
fsum_kk.d3 += fz;
@ -176,19 +229,27 @@ void FixEfieldKokkos<DeviceType>::operator()(TagFixEfieldConstant, const int &i,
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixEfieldKokkos<DeviceType>::operator()(TagFixEfieldNonConstant, const int &i, double_4& fsum_kk) const {
auto prd = Few<double,3>(domain->prd);
auto h = Few<double,6>(domain->h);
auto triclinic = domain->triclinic;
if (mask[i] & groupbit) {
if (region && !d_match[i]) return;
Few<double,3> x_i;
x_i[0] = x(i,0);
x_i[1] = x(i,1);
x_i[2] = x(i,2);
auto unwrap = DomainKokkos::unmap(prd,h,triclinic,x_i,image(i));
const F_FLOAT qtmp = q[i];
const F_FLOAT fx = qtmp * ex;
const F_FLOAT fy = qtmp * ey;
const F_FLOAT fz = qtmp * ez;
if (xstyle == ATOM) f(i,0) += d_efield(i,0);
else if (xstyle) f(i,0) += fx;
if (ystyle == ATOM) f(i,1) = d_efield(i,1);
if (ystyle == ATOM) f(i,1) += d_efield(i,1);
else if (ystyle) f(i,1) += fy;
if (zstyle == ATOM) f(i,2) = d_efield(i,2);
if (zstyle == ATOM) f(i,2) += d_efield(i,2);
else if (zstyle) f(i,2) += fz;
//fsum_kk.d0 -= fx * unwrap[0] + fy * unwrap[1] + fz * unwrap[2];
fsum_kk.d0 -= fx * unwrap[0] + fy * unwrap[1] + fz * unwrap[2];
fsum_kk.d1 += fx;
fsum_kk.d2 += fy;
fsum_kk.d3 += fz;

View File

@ -75,6 +75,7 @@ class FixEfieldKokkos : public FixEfield {
typename AT::t_x_array_randomread x;
typename AT::t_float_1d_randomread q;
typename AT::t_f_array f;
typename AT::t_imageint_1d_randomread image;
typename AT::t_int_1d_randomread mask;
};