Another refactor

This commit is contained in:
Stan Moore
2023-03-01 15:48:04 -07:00
parent 3667382067
commit 6d29e9209d
7 changed files with 71 additions and 70 deletions

View File

@ -80,7 +80,6 @@ CommKokkos::CommKokkos(LAMMPS *lmp) : CommBrick(lmp)
max_buf_fix = 0;
k_buf_send_fix = DAT::tdual_xfloat_1d("comm:k_buf_send_fix",1);
k_buf_recv_fix = DAT::tdual_xfloat_1d("comm:k_recv_send_fix",1);
}
/* ---------------------------------------------------------------------- */
@ -148,6 +147,44 @@ void CommKokkos::init()
if (ghost_velocity && atomKK->avecKK->no_comm_vel_flag) // not all Kokkos atom_vec styles have comm vel pack/unpack routines yet
forward_comm_classic = true;
if (!exchange_comm_classic) {
if (atom->nextra_grow) {
// check if all fixes with atom-based arrays support exchange on device
bool flag = true;
for (int iextra = 0; iextra < atom->nextra_grow; iextra++) {
auto fix_iextra = modify->fix[atom->extra_grow[iextra]];
if (!fix_iextra->exchange_comm_device) {
flag = false;
break;
}
if (!atomKK->avecKK->unpack_exchange_indices_flag || !flag) {
if (comm->me == 0) {
if (!atomKK->avecKK->unpack_exchange_indices_flag)
error->warning(FLERR,"Atom style not compatible with fix sending data in Kokkos communication, "
"switching to classic exchange/border communication");
else if (!flag)
error->warning(FLERR,"Fix with atom-based arrays not compatible with sending data in Kokkos communication, "
"switching to classic exchange/border communication");
}
exchange_comm_classic = true;
}
}
if (atom->nextra_border || mode != Comm::SINGLE || bordergroup ||
(ghost_velocity && atomKK->avecKK->no_border_vel_flag)) {
if (comm->me == 0) {
error->warning(FLERR,"Required border comm not yet implemented in Kokkos communication, "
"switching to classic exchange/border communication");
}
exchange_comm_classic = true;
}
}
}
}
/* ----------------------------------------------------------------------
@ -642,35 +679,6 @@ void CommKokkos::reverse_comm(Dump *dump)
void CommKokkos::exchange()
{
if (!exchange_comm_classic) {
if (atom->nextra_grow + atom->nextra_border) {
// check if all fixes with atom-based arrays derive from KokkosBase so we can enable exchange on device
// we are assuming that every fix with atom-based arrays need to send info during exchange
bool fix_flag = true;
for (int iextra = 0; iextra < atom->nextra_grow; iextra++) {
if (!dynamic_cast<KokkosBase*>(modify->fix[atom->extra_grow[iextra]])) {
fix_flag = false;
break;
}
}
if (!atomKK->avecKK->unpack_exchange_indices_flag || !fix_flag) {
static int print = 1;
if (print && comm->me == 0) {
if (!atomKK->avecKK->unpack_exchange_indices_flag)
error->warning(FLERR,"Atom style not compatible with fix sending data in Kokkos communication, "
"switching to classic exchange/border communication");
if (!fix_flag)
error->warning(FLERR,"Fix with atom-based arrays not compatible with sending data in Kokkos communication, "
"switching to classic exchange/border communication");
}
print = 0;
exchange_comm_classic = true;
}
}
}
if (!exchange_comm_classic) {
if (exchange_comm_on_host) exchange_device<LMPHostType>();
else exchange_device<LMPDeviceType>();
@ -812,7 +820,7 @@ void CommKokkos::exchange_device()
k_exchange_copylist.modify<LMPHostType>();
k_exchange_copylist.sync<DeviceType>();
nsend = k_count.h_view();
if (nsend > maxsend) grow_send_kokkos(nsend,1);
if (nsend > maxsend) grow_send_kokkos(nsend,0);
nsend =
atomKK->avecKK->pack_exchange_kokkos(k_count.h_view(),k_buf_send,
k_exchange_sendlist,k_exchange_copylist,
@ -827,7 +835,7 @@ void CommKokkos::exchange_device()
// if 2 procs in dimension, single send/recv
// if more than 2 procs in dimension, send/recv to both neighbors
const int data_size = atom->avec->size_border+atom->avec->size_velocity+2;
const int data_size = atomKK->avecKK->size_border+atomKK->avecKK->size_velocity+2;
DAT::tdual_int_1d k_indices;
if (procgrid[dim] == 1) nrecv = 0;
@ -872,9 +880,13 @@ void CommKokkos::exchange_device()
if (atom->nextra_grow) {
for (int iextra = 0; iextra < atom->nextra_grow; iextra++) {
KokkosBase *kkbase = dynamic_cast<KokkosBase*>(modify->fix[atom->extra_grow[iextra]]);
auto fix_iextra = modify->fix[atom->extra_grow[iextra]];
KokkosBase *kkbase = dynamic_cast<KokkosBase*>(fix_iextra);
int nextrasend = 0;
if (k_count.h_view()) {
nsend = k_count.h_view();
if (nsend) {
if (nsend*fix_iextra->maxexchange > maxsend)
grow_send_kokkos(nsend*fix_iextra->maxexchange,0);
nextrasend = kkbase->pack_exchange_kokkos(
k_count.h_view(),k_buf_send,k_exchange_sendlist,k_exchange_copylist,
ExecutionSpaceFromDevice<DeviceType>::space);
@ -949,20 +961,6 @@ void CommKokkos::exchange_device()
void CommKokkos::borders()
{
if (!exchange_comm_classic) {
static int print = 1;
if (mode != Comm::SINGLE || bordergroup ||
(ghost_velocity && atomKK->avecKK->no_border_vel_flag)) {
if (print && comm->me==0) {
error->warning(FLERR,"Required border comm not yet implemented in Kokkos communication, "
"switching to classic exchange/border communication");
}
print = 0;
exchange_comm_classic = true;
}
}
if (!exchange_comm_classic) {
if (exchange_comm_on_host) borders_device<LMPHostType>();
else borders_device<LMPDeviceType>();
@ -1354,8 +1352,9 @@ void CommKokkos::grow_recv(int n)
void CommKokkos::grow_send_kokkos(int n, int flag, ExecutionSpace space)
{
maxsend = static_cast<int> (BUFFACTOR * n);
int maxsend_border = (maxsend+BUFEXTRA+5)/atomKK->avecKK->size_border + 2;
int maxsend_border = (maxsend+BUFEXTRA)/atomKK->avecKK->size_border;
if (flag) {
if (space == Device)
k_buf_send.modify<LMPDeviceType>();
@ -1368,16 +1367,13 @@ void CommKokkos::grow_send_kokkos(int n, int flag, ExecutionSpace space)
else
k_buf_send.resize(maxsend_border,atomKK->avecKK->size_border);
buf_send = k_buf_send.view<LMPHostType>().data();
}
else {
} else {
if (ghost_velocity)
k_buf_send = DAT::
tdual_xfloat_2d("comm:k_buf_send",
maxsend_border,
MemoryKokkos::realloc_kokkos(k_buf_send,"comm:k_buf_send",maxsend_border,
atomKK->avecKK->size_border + atomKK->avecKK->size_velocity);
else
k_buf_send = DAT::
tdual_xfloat_2d("comm:k_buf_send",maxsend_border,atomKK->avecKK->size_border);
MemoryKokkos::realloc_kokkos(k_buf_send,"comm:k_buf_send",maxsend_border,
atomKK->avecKK->size_border);
buf_send = k_buf_send.view<LMPHostType>().data();
}
}
@ -1389,9 +1385,10 @@ void CommKokkos::grow_send_kokkos(int n, int flag, ExecutionSpace space)
// Resize the Kokkos receive buffer for a request of n values.
//   n     : requested size; maxrecv is set to BUFFACTOR * n truncated to int
//   space : unused (the parameter name is commented out in the signature)
// On return buf_recv points at the host-side data of k_buf_recv.
// NOTE(review): this span declares maxrecv_border twice and assigns k_buf_recv
// twice; it looks like a diff hunk whose +/- markers were stripped (older
// sizing first, newer sizing second) -- confirm which pair is current.
void CommKokkos::grow_recv_kokkos(int n, ExecutionSpace /*space*/)
{
maxrecv = static_cast<int> (BUFFACTOR * n);
// first sizing variant: pads the numerator by 5 and adds 2 extra rows
int maxrecv_border = (maxrecv+BUFEXTRA+5)/atomKK->avecKK->size_border + 2;
// constructs a new dual view of maxrecv_border x size_border and assigns it over the old one
k_buf_recv = DAT::
tdual_xfloat_2d("comm:k_buf_recv",maxrecv_border,atomKK->avecKK->size_border);
// second sizing variant: row count is (maxrecv+BUFEXTRA)/size_border with no extra padding
int maxrecv_border = (maxrecv+BUFEXTRA)/atomKK->avecKK->size_border;
// presumably reallocates k_buf_recv to maxrecv_border x size_border -- verify against MemoryKokkos::realloc_kokkos
MemoryKokkos::realloc_kokkos(k_buf_recv,"comm:k_buf_recv",maxrecv_border,
atomKK->avecKK->size_border);
// refresh the raw pointer so it refers to the host view of the resized buffer
buf_recv = k_buf_recv.view<LMPHostType>().data();
}

View File

@ -71,8 +71,6 @@ class CommKokkos : public CommBrick {
DAT::tdual_int_2d k_exchange_lists;
DAT::tdual_int_1d k_exchange_sendlist,k_exchange_copylist,k_sendflag;
DAT::tdual_int_scalar k_count;
//double *buf_send; // send buffer for all comm
//double *buf_recv; // recv buffer for all comm
DAT::tdual_int_2d k_swap;
DAT::tdual_int_2d k_swap2;

View File

@ -32,6 +32,7 @@ FixNeighHistoryKokkos<DeviceType>::FixNeighHistoryKokkos(LAMMPS *lmp, int narg,
FixNeighHistory(lmp, narg, arg)
{
kokkosable = 1;
exchange_comm_device = 1;
atomKK = (AtomKokkos *)atom;
execution_space = ExecutionSpaceFromDevice<DeviceType>::space;

View File

@ -58,7 +58,7 @@ FixQEqReaxFFKokkos(LAMMPS *lmp, int narg, char **arg) :
{
kokkosable = 1;
comm_forward = comm_reverse = 2; // fused
forward_comm_device = 2;
forward_comm_device = exchange_comm_device = 1;
atomKK = (AtomKokkos *) atom;
execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
@ -68,6 +68,7 @@ FixQEqReaxFFKokkos(LAMMPS *lmp, int narg, char **arg) :
nmax = m_cap = 0;
allocated_flag = 0;
nprev = 4;
maxexchange = nprev*2;
memory->destroy(s_hist);
memory->destroy(t_hist);
@ -1344,8 +1345,8 @@ KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqPackExchange, const int &mysend) const {
const int i = d_exchange_sendlist(mysend);
for (int m = 0; m < nprev; m++) d_exchange_buf(mysend,m) = d_s_hist(i,m);
for (int m = 0; m < nprev; m++) d_exchange_buf(mysend,nprev+m) = d_t_hist(i,m);
for (int m = 0; m < nprev; m++) d_buf(mysend*nprev*2 + m) = d_s_hist(i,m);
for (int m = 0; m < nprev; m++) d_buf(mysend*nprev*2 + nprev+m) = d_t_hist(i,m);
const int j = d_copylist(mysend);
@ -1367,7 +1368,9 @@ int FixQEqReaxFFKokkos<DeviceType>::pack_exchange_kokkos(
k_copylist.sync<DeviceType>();
k_exchange_sendlist.sync<DeviceType>();
d_exchange_buf = k_buf.view<DeviceType>();
d_buf = typename ArrayTypes<DeviceType>::t_xfloat_1d_um(
k_buf.template view<DeviceType>().data(),
k_buf.extent(0)*k_buf.extent(1));
d_copylist = k_copylist.view<DeviceType>();
d_exchange_sendlist = k_exchange_sendlist.view<DeviceType>();
this->nsend = nsend;
@ -1395,8 +1398,8 @@ void FixQEqReaxFFKokkos<DeviceType>::operator()(TagQEqUnpackExchange, const int
{
int index = d_indices(i);
if (index > 0) {
for (int m = 0; m < nprev; m++) d_s_hist(index,m) = d_exchange_buf(i,m);
for (int m = 0; m < nprev; m++) d_t_hist(index,m) = d_exchange_buf(i,nprev+m);
for (int m = 0; m < nprev; m++) d_s_hist(index,m) = d_buf(i*nprev*2 + m);
for (int m = 0; m < nprev; m++) d_t_hist(index,m) = d_buf(i*nprev*2 + nprev+m);
}
}
@ -1410,7 +1413,9 @@ void FixQEqReaxFFKokkos<DeviceType>::unpack_exchange_kokkos(
k_buf.sync<DeviceType>();
k_indices.sync<DeviceType>();
d_exchange_buf = k_buf.view<DeviceType>();
d_buf = typename ArrayTypes<DeviceType>::t_xfloat_1d_um(
k_buf.template view<DeviceType>().data(),
k_buf.extent(0)*k_buf.extent(1));
d_indices = k_indices.view<DeviceType>();
k_s_hist.template sync<DeviceType>();

View File

@ -260,7 +260,6 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
typename AT::t_xfloat_1d d_buf;
typename AT::t_int_1d d_copylist;
typename AT::t_int_1d d_indices;
typename AT::t_xfloat_2d d_exchange_buf;
typename AT::t_int_1d d_exchange_sendlist;
void init_shielding_k();

View File

@ -109,7 +109,7 @@ Fix::Fix(LAMMPS *lmp, int /*narg*/, char **arg) :
datamask_modify = ALL_MASK;
kokkosable = 0;
forward_comm_device = 0;
forward_comm_device = exchange_comm_device = 0;
copymode = 0;
}

View File

@ -131,6 +131,7 @@ class Fix : protected Pointers {
int kokkosable; // 1 if Kokkos fix
int forward_comm_device; // 1 if forward comm on Device
int exchange_comm_device; // 1 if exchange comm on Device
ExecutionSpace execution_space;
unsigned int datamask_read, datamask_modify;