From df882a9552b96e0724d9024d8241f1703076e310 Mon Sep 17 00:00:00 2001 From: jtclemm Date: Thu, 31 Oct 2024 15:27:22 -0600 Subject: [PATCH] Adding optional size arg for forward/reverse comm methods to Kokkos --- src/KOKKOS/comm_kokkos.cpp | 83 +++++++++++++++++++------------ src/KOKKOS/comm_kokkos.h | 26 +++++----- src/KOKKOS/comm_tiled_kokkos.cpp | 84 ++++++++++++++++++++++---------- src/KOKKOS/comm_tiled_kokkos.h | 22 ++++----- 4 files changed, 135 insertions(+), 80 deletions(-) diff --git a/src/KOKKOS/comm_kokkos.cpp b/src/KOKKOS/comm_kokkos.cpp index 8f821c3036..eea79248fe 100644 --- a/src/KOKKOS/comm_kokkos.cpp +++ b/src/KOKKOS/comm_kokkos.cpp @@ -455,10 +455,10 @@ void CommKokkos::forward_comm_device(Fix *fix, int size) /* ---------------------------------------------------------------------- reverse communication invoked by a Fix size/nsize used only to set recv buffer limit - size = 0 (default) -> use comm_forward from Fix + size = 0 (default) -> use comm_reverse from Fix size > 0 -> Fix passes max size per atom the latter is only useful if Fix does several comm modes, - some are smaller than max stored in its comm_forward + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ void CommKokkos::reverse_comm(Fix *fix, int size) @@ -482,72 +482,94 @@ void CommKokkos::reverse_comm_variable(Fix *fix) /* ---------------------------------------------------------------------- forward communication invoked by a Compute - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_forward from Compute + size > 0 -> Compute passes max size per atom + the latter is only useful if Compute does several comm modes, + some are smaller than max stored in its comm_forward ------------------------------------------------------------------------- */ -void CommKokkos::forward_comm(Compute *compute) +void CommKokkos::forward_comm(Compute *compute, int size) { k_sendlist.sync(); - CommBrick::forward_comm(compute); + CommBrick::forward_comm(compute, size); } /* ---------------------------------------------------------------------- forward communication invoked by a Bond - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_forward from Bond + size > 0 -> Bond passes max size per atom + the latter is only useful if Bond does several comm modes, + some are smaller than max stored in its comm_forward ------------------------------------------------------------------------- */ -void CommKokkos::forward_comm(Bond *bond) +void CommKokkos::forward_comm(Bond *bond, int size) { - CommBrick::forward_comm(bond); + CommBrick::forward_comm(bond, size); } /* ---------------------------------------------------------------------- reverse communication invoked by a Bond - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_reverse from Bond + size > 0 -> Bond passes max size per atom + the latter is only useful if Bond does several comm modes, + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ -void CommKokkos::reverse_comm(Bond *bond) +void CommKokkos::reverse_comm(Bond *bond, int size) { - CommBrick::reverse_comm(bond); + CommBrick::reverse_comm(bond, size); } /* ---------------------------------------------------------------------- reverse communication invoked by a Compute - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_reverse from Compute + size > 0 -> Compute passes max size per atom + the latter is only useful if Compute does several comm modes, + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ -void CommKokkos::reverse_comm(Compute *compute) +void CommKokkos::reverse_comm(Compute *compute, int size) { k_sendlist.sync(); - CommBrick::reverse_comm(compute); + CommBrick::reverse_comm(compute, size); } /* ---------------------------------------------------------------------- forward communication invoked by a Dump - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_forward from Dump + size > 0 -> Dump passes max size per atom + the latter is only useful if Dump does several comm modes, + some are smaller than max stored in its comm_forward ------------------------------------------------------------------------- */ -void CommKokkos::forward_comm(Pair *pair) +void CommKokkos::forward_comm(Pair *pair, int size) { if (pair->execution_space == Host || forward_pair_comm_classic) { k_sendlist.sync(); - CommBrick::forward_comm(pair); + CommBrick::forward_comm(pair, size); } else { k_sendlist.sync(); - forward_comm_device(pair); + forward_comm_device(pair, size); } } /* ---------------------------------------------------------------------- */ template -void CommKokkos::forward_comm_device(Pair *pair) +void CommKokkos::forward_comm_device(Pair *pair, int size) { - int iswap,n; + int iswap,n,nsize; MPI_Request request; DAT::tdual_xfloat_1d k_buf_tmp; - int nsize = pair->comm_forward; + if (size) nsize = size; + else nsize = pair->comm_forward; + KokkosBase* pairKKBase = dynamic_cast(pair); int nmax = max_buf_pair; @@ -623,21 +645,21 @@ void CommKokkos::grow_buf_fix(int n) { /* ---------------------------------------------------------------------- */ -void CommKokkos::reverse_comm(Pair *pair) +void CommKokkos::reverse_comm(Pair *pair, int size) { if (pair->execution_space == Host || !pair->reverse_comm_device || reverse_pair_comm_classic) { k_sendlist.sync(); - CommBrick::reverse_comm(pair); + CommBrick::reverse_comm(pair, size); } else { k_sendlist.sync(); - reverse_comm_device(pair); + reverse_comm_device(pair, size); } } /* ---------------------------------------------------------------------- */ template -void CommKokkos::reverse_comm_device(Pair *pair) +void CommKokkos::reverse_comm_device(Pair *pair, int size) { int iswap,n; MPI_Request request; @@ -645,7 +667,8 @@ void CommKokkos::reverse_comm_device(Pair *pair) KokkosBase* pairKKBase = dynamic_cast(pair); - int nsize = MAX(pair->comm_reverse,pair->comm_reverse_off); + if (size) nsize = size; + else nsize = MAX(pair->comm_reverse,pair->comm_reverse_off); int nmax = max_buf_pair; for (iswap = 0; iswap < nswap; iswap++) { @@ -702,18 +725,18 @@ void CommKokkos::reverse_comm_device(Pair *pair) /* ---------------------------------------------------------------------- */ -void CommKokkos::forward_comm(Dump *dump) +void CommKokkos::forward_comm(Dump *dump, int size) { k_sendlist.sync(); - CommBrick::forward_comm(dump); + CommBrick::forward_comm(dump, size); } /* ---------------------------------------------------------------------- */ -void CommKokkos::reverse_comm(Dump *dump) +void CommKokkos::reverse_comm(Dump *dump, int size) { k_sendlist.sync(); - CommBrick::reverse_comm(dump); + CommBrick::reverse_comm(dump, size); } /* ---------------------------------------------------------------------- diff --git a/src/KOKKOS/comm_kokkos.h b/src/KOKKOS/comm_kokkos.h index 4fb4dfbe29..42941ff517 100644 --- a/src/KOKKOS/comm_kokkos.h +++ b/src/KOKKOS/comm_kokkos.h @@ -45,24 +45,24 @@ class CommKokkos : public CommBrick { void exchange() override; // move atoms to new procs void borders() override; // setup list of atoms to comm - void forward_comm(class Pair *) override; // forward comm from a Pair - void reverse_comm(class Pair *) override; // reverse comm from a Pair - void forward_comm(class Bond *) override; // forward comm from a Bond - void reverse_comm(class Bond *) override; // reverse comm from a Bond - void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix - void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix - void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix - void forward_comm(class Compute *) override; // forward from a Compute - void reverse_comm(class Compute *) override; // reverse from a Compute - void forward_comm(class Dump *) override; // forward comm from a Dump - void reverse_comm(class Dump *) override; // reverse comm from a Dump + void forward_comm(class Pair *, int size = 0) override; // forward comm from a Pair + void reverse_comm(class Pair *, int size = 0) override; // reverse comm from a Pair + void forward_comm(class Bond *, int size = 0) override; // forward comm from a Bond + void reverse_comm(class Bond *, int size = 0) override; // reverse comm from a Bond + void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix + void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix + void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix + void forward_comm(class Compute *, int size = 0) override; // forward from a Compute + void reverse_comm(class Compute *, int size = 0) override; // reverse from a Compute + void forward_comm(class Dump *, int size = 0) override; // forward comm from a Dump + void reverse_comm(class Dump *, int size = 0) override; // reverse comm from a Dump void forward_comm_array(int, double **) override; // forward comm of array template void forward_comm_device(); template void reverse_comm_device(); - template void forward_comm_device(Pair *pair); - template void reverse_comm_device(Pair *pair); + template void forward_comm_device(Pair *pair, int size=0); + template void reverse_comm_device(Pair *pair, int size=0); template void forward_comm_device(Fix *fix, int size=0); template void exchange_device(); template void borders_device(); diff --git a/src/KOKKOS/comm_tiled_kokkos.cpp b/src/KOKKOS/comm_tiled_kokkos.cpp index 2e4ca30bed..69c5e8f847 100644 --- a/src/KOKKOS/comm_tiled_kokkos.cpp +++ b/src/KOKKOS/comm_tiled_kokkos.cpp @@ -417,7 +417,11 @@ void CommTiledKokkos::borders() /* ---------------------------------------------------------------------- forward communication invoked by a Pair - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_forward from Fix + size > 0 -> Fix passes max size per atom + the latter is only useful if Fix does several comm modes, + some are smaller than max stored in its comm_forward ------------------------------------------------------------------------- */ void CommTiledKokkos::forward_comm(Pair *pair) @@ -427,32 +431,44 @@ void CommTiledKokkos::forward_comm(Pair *pair) /* ---------------------------------------------------------------------- reverse communication invoked by a Pair - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_reverse from Pair + size > 0 -> Pair passes max size per atom + the latter is only useful if Pair does several comm modes, + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ -void CommTiledKokkos::reverse_comm(Pair *pair) +void CommTiledKokkos::reverse_comm(Pair *pair, int size) { - CommTiled::reverse_comm(pair); + CommTiled::reverse_comm(pair, size); } /* ---------------------------------------------------------------------- forward communication invoked by a Bond - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_forward from Bond + size > 0 -> Bond passes max size per atom + the latter is only useful if Bond does several comm modes, + some are smaller than max stored in its comm_forward ------------------------------------------------------------------------- */ -void CommTiledKokkos::forward_comm(Bond *bond) +void CommTiledKokkos::forward_comm(Bond *bond, int size) { - CommTiled::forward_comm(bond); + CommTiled::forward_comm(bond, size); } /* ---------------------------------------------------------------------- reverse communication invoked by a Bond - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_reverse from Bond + size > 0 -> Bond passes max size per atom + the latter is only useful if Bond does several comm modes, + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ -void CommTiledKokkos::reverse_comm(Bond *bond) +void CommTiledKokkos::reverse_comm(Bond *bond, int size) { - CommTiled::reverse_comm(bond); + CommTiled::reverse_comm(bond, size); } /* ---------------------------------------------------------------------- @@ -466,21 +482,21 @@ void CommTiledKokkos::reverse_comm(Bond *bond) void CommTiledKokkos::forward_comm(Fix *fix, int size) { - CommTiled::forward_comm(fix,size); + CommTiled::forward_comm(fix, size); } /* ---------------------------------------------------------------------- reverse communication invoked by a Fix size/nsize used only to set recv buffer limit - size = 0 (default) -> use comm_forward from Fix + size = 0 (default) -> use comm_reverse from Fix size > 0 -> Fix passes max size per atom the latter is only useful if Fix does several comm modes, - some are smaller than max stored in its comm_forward + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ void CommTiledKokkos::reverse_comm(Fix *fix, int size) { - CommTiled::reverse_comm(fix,size); + CommTiled::reverse_comm(fix, size); } /* ---------------------------------------------------------------------- @@ -497,42 +513,58 @@ void CommTiledKokkos::reverse_comm_variable(Fix *fix) /* ---------------------------------------------------------------------- forward communication invoked by a Compute - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_forward from Compute + size > 0 -> Compute passes max size per atom + the latter is only useful if Compute does several comm modes, + some are smaller than max stored in its comm_forward ------------------------------------------------------------------------- */ -void CommTiledKokkos::forward_comm(Compute *compute) +void CommTiledKokkos::forward_comm(Compute *compute, int size) { - CommTiled::forward_comm(compute); + CommTiled::forward_comm(compute, size); } /* ---------------------------------------------------------------------- reverse communication invoked by a Compute - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_reverse from Compute + size > 0 -> Compute passes max size per atom + the latter is only useful if Compute does several comm modes, + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ -void CommTiledKokkos::reverse_comm(Compute *compute) +void CommTiledKokkos::reverse_comm(Compute *compute, int size) { - CommTiled::reverse_comm(compute); + CommTiled::reverse_comm(compute, size); } /* ---------------------------------------------------------------------- forward communication invoked by a Dump - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_forward from Dump + size > 0 -> Dump passes max size per atom + the latter is only useful if Dump does several comm modes, + some are smaller than max stored in its comm_forward ------------------------------------------------------------------------- */ -void CommTiledKokkos::forward_comm(Dump *dump) +void CommTiledKokkos::forward_comm(Dump *dump, int size) { - CommTiled::forward_comm(dump); + CommTiled::forward_comm(dump, size); } /* ---------------------------------------------------------------------- reverse communication invoked by a Dump - nsize used only to set recv buffer limit + size/nsize used only to set recv buffer limit + size = 0 (default) -> use comm_reverse from Dump + size > 0 -> Dump passes max size per atom + the latter is only useful if Dump does several comm modes, + some are smaller than max stored in its comm_reverse ------------------------------------------------------------------------- */ -void CommTiledKokkos::reverse_comm(Dump *dump) +void CommTiledKokkos::reverse_comm(Dump *dump, int size) { - CommTiled::reverse_comm(dump); + CommTiled::reverse_comm(dump, size); } /* ---------------------------------------------------------------------- diff --git a/src/KOKKOS/comm_tiled_kokkos.h b/src/KOKKOS/comm_tiled_kokkos.h index 9033714796..ef226489c8 100644 --- a/src/KOKKOS/comm_tiled_kokkos.h +++ b/src/KOKKOS/comm_tiled_kokkos.h @@ -46,17 +46,17 @@ class CommTiledKokkos : public CommTiled { void exchange() override; // move atoms to new procs void borders() override; // setup list of atoms to comm - void forward_comm(class Pair *) override; // forward comm from a Pair - void reverse_comm(class Pair *) override; // reverse comm from a Pair - void forward_comm(class Bond *) override; // forward comm from a Bond - void reverse_comm(class Bond *) override; // reverse comm from a Bond - void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix - void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix - void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix - void forward_comm(class Compute *) override; // forward from a Compute - void reverse_comm(class Compute *) override; // reverse from a Compute - void forward_comm(class Dump *) override; // forward comm from a Dump - void reverse_comm(class Dump *) override; // reverse comm from a Dump + void forward_comm(class Pair *, int size = 0) override; // forward comm from a Pair + void reverse_comm(class Pair *, int size = 0) override; // reverse comm from a Pair + void forward_comm(class Bond *, int size = 0) override; // forward comm from a Bond + void reverse_comm(class Bond *, int size = 0) override; // reverse comm from a Bond + void forward_comm(class Fix *, int size = 0) override; // forward comm from a Fix + void reverse_comm(class Fix *, int size = 0) override; // reverse comm from a Fix + void reverse_comm_variable(class Fix *) override; // variable size reverse comm from a Fix + void forward_comm(class Compute *, int size = 0) override; // forward from a Compute + void reverse_comm(class Compute *, int size = 0) override; // reverse from a Compute + void forward_comm(class Dump *, int size = 0) override; // forward comm from a Dump + void reverse_comm(class Dump *, int size = 0) override; // reverse comm from a Dump void forward_comm_array(int, double **) override; // forward comm of array