Adding optional size arg for forward/reverse comm methods to Kokkos

This commit is contained in:
jtclemm
2024-10-31 15:27:22 -06:00
parent 3fd4f9b7f3
commit df882a9552
4 changed files with 135 additions and 80 deletions

View File

@ -455,10 +455,10 @@ void CommKokkos::forward_comm_device(Fix *fix, int size)
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Fix reverse communication invoked by a Fix
size/nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix size = 0 (default) -> use comm_reverse from Fix
size > 0 -> Fix passes max size per atom size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes, the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Fix *fix, int size) void CommKokkos::reverse_comm(Fix *fix, int size)
@ -482,72 +482,94 @@ void CommKokkos::reverse_comm_variable(Fix *fix)
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
forward communication invoked by a Compute forward communication invoked by a Compute
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommKokkos::forward_comm(Compute *compute) void CommKokkos::forward_comm(Compute *compute, int size)
{ {
k_sendlist.sync<LMPHostType>(); k_sendlist.sync<LMPHostType>();
CommBrick::forward_comm(compute); CommBrick::forward_comm(compute, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
forward communication invoked by a Bond forward communication invoked by a Bond
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommKokkos::forward_comm(Bond *bond) void CommKokkos::forward_comm(Bond *bond, int size)
{ {
CommBrick::forward_comm(bond); CommBrick::forward_comm(bond, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Bond reverse communication invoked by a Bond
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Bond *bond) void CommKokkos::reverse_comm(Bond *bond, int size)
{ {
CommBrick::reverse_comm(bond); CommBrick::reverse_comm(bond, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Compute reverse communication invoked by a Compute
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Compute *compute) void CommKokkos::reverse_comm(Compute *compute, int size)
{ {
k_sendlist.sync<LMPHostType>(); k_sendlist.sync<LMPHostType>();
CommBrick::reverse_comm(compute); CommBrick::reverse_comm(compute, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
forward communication invoked by a Dump forward communication invoked by a Dump
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Dump
size > 0 -> Dump passes max size per atom
the latter is only useful if Dump does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommKokkos::forward_comm(Pair *pair) void CommKokkos::forward_comm(Pair *pair, int size)
{ {
if (pair->execution_space == Host || forward_pair_comm_classic) { if (pair->execution_space == Host || forward_pair_comm_classic) {
k_sendlist.sync<LMPHostType>(); k_sendlist.sync<LMPHostType>();
CommBrick::forward_comm(pair); CommBrick::forward_comm(pair, size);
} else { } else {
k_sendlist.sync<LMPDeviceType>(); k_sendlist.sync<LMPDeviceType>();
forward_comm_device<LMPDeviceType>(pair); forward_comm_device<LMPDeviceType>(pair, size);
} }
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
template<class DeviceType> template<class DeviceType>
void CommKokkos::forward_comm_device(Pair *pair) void CommKokkos::forward_comm_device(Pair *pair, int size)
{ {
int iswap,n; int iswap,n,nsize;
MPI_Request request; MPI_Request request;
DAT::tdual_xfloat_1d k_buf_tmp; DAT::tdual_xfloat_1d k_buf_tmp;
int nsize = pair->comm_forward; if (size) nsize = size;
else nsize = pair->comm_forward;
KokkosBase* pairKKBase = dynamic_cast<KokkosBase*>(pair); KokkosBase* pairKKBase = dynamic_cast<KokkosBase*>(pair);
int nmax = max_buf_pair; int nmax = max_buf_pair;
@ -623,21 +645,21 @@ void CommKokkos::grow_buf_fix(int n) {
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Pair *pair) void CommKokkos::reverse_comm(Pair *pair, int size)
{ {
if (pair->execution_space == Host || !pair->reverse_comm_device || reverse_pair_comm_classic) { if (pair->execution_space == Host || !pair->reverse_comm_device || reverse_pair_comm_classic) {
k_sendlist.sync<LMPHostType>(); k_sendlist.sync<LMPHostType>();
CommBrick::reverse_comm(pair); CommBrick::reverse_comm(pair, size);
} else { } else {
k_sendlist.sync<LMPDeviceType>(); k_sendlist.sync<LMPDeviceType>();
reverse_comm_device<LMPDeviceType>(pair); reverse_comm_device<LMPDeviceType>(pair, size);
} }
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
template<class DeviceType> template<class DeviceType>
void CommKokkos::reverse_comm_device(Pair *pair) void CommKokkos::reverse_comm_device(Pair *pair, int size)
{ {
int iswap,n; int iswap,n;
MPI_Request request; MPI_Request request;
@ -645,7 +667,8 @@ void CommKokkos::reverse_comm_device(Pair *pair)
KokkosBase* pairKKBase = dynamic_cast<KokkosBase*>(pair); KokkosBase* pairKKBase = dynamic_cast<KokkosBase*>(pair);
int nsize = MAX(pair->comm_reverse,pair->comm_reverse_off); if (size) nsize = size;
else nsize = MAX(pair->comm_reverse,pair->comm_reverse_off);
int nmax = max_buf_pair; int nmax = max_buf_pair;
for (iswap = 0; iswap < nswap; iswap++) { for (iswap = 0; iswap < nswap; iswap++) {
@ -702,18 +725,18 @@ void CommKokkos::reverse_comm_device(Pair *pair)
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
void CommKokkos::forward_comm(Dump *dump) void CommKokkos::forward_comm(Dump *dump, int size)
{ {
k_sendlist.sync<LMPHostType>(); k_sendlist.sync<LMPHostType>();
CommBrick::forward_comm(dump); CommBrick::forward_comm(dump, size);
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Dump *dump) void CommKokkos::reverse_comm(Dump *dump, int size)
{ {
k_sendlist.sync<LMPHostType>(); k_sendlist.sync<LMPHostType>();
CommBrick::reverse_comm(dump); CommBrick::reverse_comm(dump, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------

View File

@ -45,24 +45,24 @@ class CommKokkos : public CommBrick {
void exchange() override; // move atoms to new procs void exchange() override; // move atoms to new procs
void borders() override; // setup list of atoms to comm void borders() override; // setup list of atoms to comm
void forward_comm(class Pair *) override; // forward comm from a Pair void forward_comm(class Pair *, int size = 0) override; // forward comm from a Pair
void reverse_comm(class Pair *) override; // reverse comm from a Pair void reverse_comm(class Pair *, int size = 0) override; // reverse comm from a Pair
void forward_comm(class Bond *) override; // forward comm from a Bond void forward_comm(class Bond *, int size = 0) override; // forward comm from a Bond
void reverse_comm(class Bond *) override; // reverse comm from a Bond void reverse_comm(class Bond *, int size = 0) override; // reverse comm from a Bond
void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix
void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix
void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix
void forward_comm(class Compute *) override; // forward from a Compute void forward_comm(class Compute *, int size = 0) override; // forward from a Compute
void reverse_comm(class Compute *) override; // reverse from a Compute void reverse_comm(class Compute *, int size = 0) override; // reverse from a Compute
void forward_comm(class Dump *) override; // forward comm from a Dump void forward_comm(class Dump *, int size = 0) override; // forward comm from a Dump
void reverse_comm(class Dump *) override; // reverse comm from a Dump void reverse_comm(class Dump *, int size = 0) override; // reverse comm from a Dump
void forward_comm_array(int, double **) override; // forward comm of array void forward_comm_array(int, double **) override; // forward comm of array
template<class DeviceType> void forward_comm_device(); template<class DeviceType> void forward_comm_device();
template<class DeviceType> void reverse_comm_device(); template<class DeviceType> void reverse_comm_device();
template<class DeviceType> void forward_comm_device(Pair *pair); template<class DeviceType> void forward_comm_device(Pair *pair, int size=0);
template<class DeviceType> void reverse_comm_device(Pair *pair); template<class DeviceType> void reverse_comm_device(Pair *pair, int size=0);
template<class DeviceType> void forward_comm_device(Fix *fix, int size=0); template<class DeviceType> void forward_comm_device(Fix *fix, int size=0);
template<class DeviceType> void exchange_device(); template<class DeviceType> void exchange_device();
template<class DeviceType> void borders_device(); template<class DeviceType> void borders_device();

View File

@ -417,7 +417,11 @@ void CommTiledKokkos::borders()
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
forward communication invoked by a Pair forward communication invoked by a Pair
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix
size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Pair *pair) void CommTiledKokkos::forward_comm(Pair *pair)
@ -427,32 +431,44 @@ void CommTiledKokkos::forward_comm(Pair *pair)
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Pair reverse communication invoked by a Pair
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Pair
size > 0 -> Pair passes max size per atom
the latter is only useful if Pair does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Pair *pair) void CommTiledKokkos::reverse_comm(Pair *pair, int size)
{ {
CommTiled::reverse_comm(pair); CommTiled::reverse_comm(pair, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
forward communication invoked by a Bond forward communication invoked by a Bond
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Bond *bond) void CommTiledKokkos::forward_comm(Bond *bond, int size)
{ {
CommTiled::forward_comm(bond); CommTiled::forward_comm(bond, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Bond reverse communication invoked by a Bond
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Bond *bond) void CommTiledKokkos::reverse_comm(Bond *bond, int size)
{ {
CommTiled::reverse_comm(bond); CommTiled::reverse_comm(bond, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
@ -466,21 +482,21 @@ void CommTiledKokkos::reverse_comm(Bond *bond)
void CommTiledKokkos::forward_comm(Fix *fix, int size) void CommTiledKokkos::forward_comm(Fix *fix, int size)
{ {
CommTiled::forward_comm(fix,size); CommTiled::forward_comm(fix, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Fix reverse communication invoked by a Fix
size/nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix size = 0 (default) -> use comm_reverse from Fix
size > 0 -> Fix passes max size per atom size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes, the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Fix *fix, int size) void CommTiledKokkos::reverse_comm(Fix *fix, int size)
{ {
CommTiled::reverse_comm(fix,size); CommTiled::reverse_comm(fix, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
@ -497,42 +513,58 @@ void CommTiledKokkos::reverse_comm_variable(Fix *fix)
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
forward communication invoked by a Compute forward communication invoked by a Compute
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Compute *compute) void CommTiledKokkos::forward_comm(Compute *compute, int size)
{ {
CommTiled::forward_comm(compute); CommTiled::forward_comm(compute, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Compute reverse communication invoked by a Compute
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Compute *compute) void CommTiledKokkos::reverse_comm(Compute *compute, int size)
{ {
CommTiled::reverse_comm(compute); CommTiled::reverse_comm(compute, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
forward communication invoked by a Dump forward communication invoked by a Dump
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Dump
size > 0 -> Dump passes max size per atom
the latter is only useful if Dump does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Dump *dump) void CommTiledKokkos::forward_comm(Dump *dump, int size)
{ {
CommTiled::forward_comm(dump); CommTiled::forward_comm(dump, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
reverse communication invoked by a Dump reverse communication invoked by a Dump
nsize used only to set recv buffer limit size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Dump
size > 0 -> Dump passes max size per atom
the latter is only useful if Dump does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Dump *dump) void CommTiledKokkos::reverse_comm(Dump *dump, int size)
{ {
CommTiled::reverse_comm(dump); CommTiled::reverse_comm(dump, size);
} }
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------

View File

@ -46,17 +46,17 @@ class CommTiledKokkos : public CommTiled {
void exchange() override; // move atoms to new procs void exchange() override; // move atoms to new procs
void borders() override; // setup list of atoms to comm void borders() override; // setup list of atoms to comm
void forward_comm(class Pair *) override; // forward comm from a Pair void forward_comm(class Pair *, int size = 0) override; // forward comm from a Pair
void reverse_comm(class Pair *) override; // reverse comm from a Pair void reverse_comm(class Pair *, int size = 0) override; // reverse comm from a Pair
void forward_comm(class Bond *) override; // forward comm from a Bond void forward_comm(class Bond *, int size = 0) override; // forward comm from a Bond
void reverse_comm(class Bond *) override; // reverse comm from a Bond void reverse_comm(class Bond *, int size = 0) override; // reverse comm from a Bond
void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix
void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix
void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix
void forward_comm(class Compute *) override; // forward from a Compute void forward_comm(class Compute *, int size = 0) override; // forward from a Compute
void reverse_comm(class Compute *) override; // reverse from a Compute void reverse_comm(class Compute *, int size = 0) override; // reverse from a Compute
void forward_comm(class Dump *) override; // forward comm from a Dump void forward_comm(class Dump *, int size = 0) override; // forward comm from a Dump
void reverse_comm(class Dump *) override; // reverse comm from a Dump void reverse_comm(class Dump *, int size = 0) override; // reverse comm from a Dump
void forward_comm_array(int, double **) override; // forward comm of array void forward_comm_array(int, double **) override; // forward comm of array