Adding optional size arg for forward/reverse comm methods to Kokkos

This commit is contained in:
jtclemm
2024-10-31 15:27:22 -06:00
parent 3fd4f9b7f3
commit df882a9552
4 changed files with 135 additions and 80 deletions

View File

@ -455,10 +455,10 @@ void CommKokkos::forward_comm_device(Fix *fix, int size)
/* ----------------------------------------------------------------------
reverse communication invoked by a Fix
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix
size = 0 (default) -> use comm_reverse from Fix
size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Fix *fix, int size)
@ -482,72 +482,94 @@ void CommKokkos::reverse_comm_variable(Fix *fix)
/* ----------------------------------------------------------------------
forward communication invoked by a Compute
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommKokkos::forward_comm(Compute *compute)
void CommKokkos::forward_comm(Compute *compute, int size)
{
k_sendlist.sync<LMPHostType>();
CommBrick::forward_comm(compute);
CommBrick::forward_comm(compute, size);
}
/* ----------------------------------------------------------------------
forward communication invoked by a Bond
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommKokkos::forward_comm(Bond *bond)
void CommKokkos::forward_comm(Bond *bond, int size)
{
CommBrick::forward_comm(bond);
CommBrick::forward_comm(bond, size);
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Bond
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Bond *bond)
void CommKokkos::reverse_comm(Bond *bond, int size)
{
CommBrick::reverse_comm(bond);
CommBrick::reverse_comm(bond, size);
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Compute
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Compute *compute)
void CommKokkos::reverse_comm(Compute *compute, int size)
{
k_sendlist.sync<LMPHostType>();
CommBrick::reverse_comm(compute);
CommBrick::reverse_comm(compute, size);
}
/* ----------------------------------------------------------------------
forward communication invoked by a Dump
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Dump
size > 0 -> Dump passes max size per atom
the latter is only useful if Dump does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommKokkos::forward_comm(Pair *pair)
void CommKokkos::forward_comm(Pair *pair, int size)
{
if (pair->execution_space == Host || forward_pair_comm_classic) {
k_sendlist.sync<LMPHostType>();
CommBrick::forward_comm(pair);
CommBrick::forward_comm(pair, size);
} else {
k_sendlist.sync<LMPDeviceType>();
forward_comm_device<LMPDeviceType>(pair);
forward_comm_device<LMPDeviceType>(pair, size);
}
}
/* ---------------------------------------------------------------------- */
template<class DeviceType>
void CommKokkos::forward_comm_device(Pair *pair)
void CommKokkos::forward_comm_device(Pair *pair, int size)
{
int iswap,n;
int iswap,n,nsize;
MPI_Request request;
DAT::tdual_xfloat_1d k_buf_tmp;
int nsize = pair->comm_forward;
if (size) nsize = size;
else nsize = pair->comm_forward;
KokkosBase* pairKKBase = dynamic_cast<KokkosBase*>(pair);
int nmax = max_buf_pair;
@ -623,21 +645,21 @@ void CommKokkos::grow_buf_fix(int n) {
/* ---------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Pair *pair)
void CommKokkos::reverse_comm(Pair *pair, int size)
{
if (pair->execution_space == Host || !pair->reverse_comm_device || reverse_pair_comm_classic) {
k_sendlist.sync<LMPHostType>();
CommBrick::reverse_comm(pair);
CommBrick::reverse_comm(pair, size);
} else {
k_sendlist.sync<LMPDeviceType>();
reverse_comm_device<LMPDeviceType>(pair);
reverse_comm_device<LMPDeviceType>(pair, size);
}
}
/* ---------------------------------------------------------------------- */
template<class DeviceType>
void CommKokkos::reverse_comm_device(Pair *pair)
void CommKokkos::reverse_comm_device(Pair *pair, int size)
{
int iswap,n;
MPI_Request request;
@ -645,7 +667,8 @@ void CommKokkos::reverse_comm_device(Pair *pair)
KokkosBase* pairKKBase = dynamic_cast<KokkosBase*>(pair);
int nsize = MAX(pair->comm_reverse,pair->comm_reverse_off);
if (size) nsize = size;
else nsize = MAX(pair->comm_reverse,pair->comm_reverse_off);
int nmax = max_buf_pair;
for (iswap = 0; iswap < nswap; iswap++) {
@ -702,18 +725,18 @@ void CommKokkos::reverse_comm_device(Pair *pair)
/* ---------------------------------------------------------------------- */
void CommKokkos::forward_comm(Dump *dump)
void CommKokkos::forward_comm(Dump *dump, int size)
{
k_sendlist.sync<LMPHostType>();
CommBrick::forward_comm(dump);
CommBrick::forward_comm(dump, size);
}
/* ---------------------------------------------------------------------- */
void CommKokkos::reverse_comm(Dump *dump)
void CommKokkos::reverse_comm(Dump *dump, int size)
{
k_sendlist.sync<LMPHostType>();
CommBrick::reverse_comm(dump);
CommBrick::reverse_comm(dump, size);
}
/* ----------------------------------------------------------------------

View File

@ -45,24 +45,24 @@ class CommKokkos : public CommBrick {
void exchange() override; // move atoms to new procs
void borders() override; // setup list of atoms to comm
void forward_comm(class Pair *) override; // forward comm from a Pair
void reverse_comm(class Pair *) override; // reverse comm from a Pair
void forward_comm(class Bond *) override; // forward comm from a Bond
void reverse_comm(class Bond *) override; // reverse comm from a Bond
void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix
void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix
void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix
void forward_comm(class Compute *) override; // forward from a Compute
void reverse_comm(class Compute *) override; // reverse from a Compute
void forward_comm(class Dump *) override; // forward comm from a Dump
void reverse_comm(class Dump *) override; // reverse comm from a Dump
void forward_comm(class Pair *, int size = 0) override; // forward comm from a Pair
void reverse_comm(class Pair *, int size = 0) override; // reverse comm from a Pair
void forward_comm(class Bond *, int size = 0) override; // forward comm from a Bond
void reverse_comm(class Bond *, int size = 0) override; // reverse comm from a Bond
void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix
void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix
void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix
void forward_comm(class Compute *, int size = 0) override; // forward from a Compute
void reverse_comm(class Compute *, int size = 0) override; // reverse from a Compute
void forward_comm(class Dump *, int size = 0) override; // forward comm from a Dump
void reverse_comm(class Dump *, int size = 0) override; // reverse comm from a Dump
void forward_comm_array(int, double **) override; // forward comm of array
template<class DeviceType> void forward_comm_device();
template<class DeviceType> void reverse_comm_device();
template<class DeviceType> void forward_comm_device(Pair *pair);
template<class DeviceType> void reverse_comm_device(Pair *pair);
template<class DeviceType> void forward_comm_device(Pair *pair, int size=0);
template<class DeviceType> void reverse_comm_device(Pair *pair, int size=0);
template<class DeviceType> void forward_comm_device(Fix *fix, int size=0);
template<class DeviceType> void exchange_device();
template<class DeviceType> void borders_device();

View File

@ -417,7 +417,11 @@ void CommTiledKokkos::borders()
/* ----------------------------------------------------------------------
forward communication invoked by a Pair
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix
size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Pair *pair)
@ -427,32 +431,44 @@ void CommTiledKokkos::forward_comm(Pair *pair)
/* ----------------------------------------------------------------------
reverse communication invoked by a Pair
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Pair
size > 0 -> Pair passes max size per atom
the latter is only useful if Pair does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Pair *pair)
void CommTiledKokkos::reverse_comm(Pair *pair, int size)
{
CommTiled::reverse_comm(pair);
CommTiled::reverse_comm(pair, size);
}
/* ----------------------------------------------------------------------
forward communication invoked by a Bond
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Bond *bond)
void CommTiledKokkos::forward_comm(Bond *bond, int size)
{
CommTiled::forward_comm(bond);
CommTiled::forward_comm(bond, size);
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Bond
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Bond
size > 0 -> Bond passes max size per atom
the latter is only useful if Bond does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Bond *bond)
void CommTiledKokkos::reverse_comm(Bond *bond, int size)
{
CommTiled::reverse_comm(bond);
CommTiled::reverse_comm(bond, size);
}
/* ----------------------------------------------------------------------
@ -466,21 +482,21 @@ void CommTiledKokkos::reverse_comm(Bond *bond)
void CommTiledKokkos::forward_comm(Fix *fix, int size)
{
CommTiled::forward_comm(fix,size);
CommTiled::forward_comm(fix, size);
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Fix
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Fix
size = 0 (default) -> use comm_reverse from Fix
size > 0 -> Fix passes max size per atom
the latter is only useful if Fix does several comm modes,
some are smaller than max stored in its comm_forward
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Fix *fix, int size)
{
CommTiled::reverse_comm(fix,size);
CommTiled::reverse_comm(fix, size);
}
/* ----------------------------------------------------------------------
@ -497,42 +513,58 @@ void CommTiledKokkos::reverse_comm_variable(Fix *fix)
/* ----------------------------------------------------------------------
forward communication invoked by a Compute
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Compute *compute)
void CommTiledKokkos::forward_comm(Compute *compute, int size)
{
CommTiled::forward_comm(compute);
CommTiled::forward_comm(compute, size);
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Compute
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Compute
size > 0 -> Compute passes max size per atom
the latter is only useful if Compute does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Compute *compute)
void CommTiledKokkos::reverse_comm(Compute *compute, int size)
{
CommTiled::reverse_comm(compute);
CommTiled::reverse_comm(compute, size);
}
/* ----------------------------------------------------------------------
forward communication invoked by a Dump
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_forward from Dump
size > 0 -> Dump passes max size per atom
the latter is only useful if Dump does several comm modes,
some are smaller than max stored in its comm_forward
------------------------------------------------------------------------- */
void CommTiledKokkos::forward_comm(Dump *dump)
void CommTiledKokkos::forward_comm(Dump *dump, int size)
{
CommTiled::forward_comm(dump);
CommTiled::forward_comm(dump, size);
}
/* ----------------------------------------------------------------------
reverse communication invoked by a Dump
nsize used only to set recv buffer limit
size/nsize used only to set recv buffer limit
size = 0 (default) -> use comm_reverse from Dump
size > 0 -> Dump passes max size per atom
the latter is only useful if Dump does several comm modes,
some are smaller than max stored in its comm_reverse
------------------------------------------------------------------------- */
void CommTiledKokkos::reverse_comm(Dump *dump)
void CommTiledKokkos::reverse_comm(Dump *dump, int size)
{
CommTiled::reverse_comm(dump);
CommTiled::reverse_comm(dump, size);
}
/* ----------------------------------------------------------------------

View File

@ -46,17 +46,17 @@ class CommTiledKokkos : public CommTiled {
void exchange() override; // move atoms to new procs
void borders() override; // setup list of atoms to comm
void forward_comm(class Pair *) override; // forward comm from a Pair
void reverse_comm(class Pair *) override; // reverse comm from a Pair
void forward_comm(class Bond *) override; // forward comm from a Bond
void reverse_comm(class Bond *) override; // reverse comm from a Bond
void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix
void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix
void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix
void forward_comm(class Compute *) override; // forward from a Compute
void reverse_comm(class Compute *) override; // reverse from a Compute
void forward_comm(class Dump *) override; // forward comm from a Dump
void reverse_comm(class Dump *) override; // reverse comm from a Dump
void forward_comm(class Pair *, int size = 0) override; // forward comm from a Pair
void reverse_comm(class Pair *, int size = 0) override; // reverse comm from a Pair
void forward_comm(class Bond *, int size = 0) override; // forward comm from a Bond
void reverse_comm(class Bond *, int size = 0) override; // reverse comm from a Bond
void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix
void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix
void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix
void forward_comm(class Compute *, int size = 0) override; // forward from a Compute
void reverse_comm(class Compute *, int size = 0) override; // reverse from a Compute
void forward_comm(class Dump *, int size = 0) override; // forward comm from a Dump
void reverse_comm(class Dump *, int size = 0) override; // reverse comm from a Dump
void forward_comm_array(int, double **) override; // forward comm of array