diff --git a/src/OpenFOAM/db/IOstreams/Pstreams/UPstream.H b/src/OpenFOAM/db/IOstreams/Pstreams/UPstream.H index 395ac02160..8fcd0ba7dc 100644 --- a/src/OpenFOAM/db/IOstreams/Pstreams/UPstream.H +++ b/src/OpenFOAM/db/IOstreams/Pstreams/UPstream.H @@ -900,7 +900,7 @@ public: ( \ const Native* sendData, \ int sendCount, \ - char* recvData, \ + Native* recvData, \ int recvCount, \ const label communicator = worldComm \ ); \ @@ -910,10 +910,19 @@ public: ( \ const Native* sendData, \ int sendCount, \ - char* recvData, \ + Native* recvData, \ int recvCount, \ const label communicator = worldComm \ ); \ + \ + /*! \brief Gather/scatter identically-sized \c Native data */ \ + /*! Send data from proc slot, receive into all slots */ \ + static void mpiAllGather \ + ( \ + Native* allData, \ + int count, \ + const label communicator = worldComm \ + ); Pstream_CommonRoutines(char); diff --git a/src/Pstream/dummy/UPstream.C b/src/Pstream/dummy/UPstream.C index 2368e9e899..c9f74eca1d 100644 --- a/src/Pstream/dummy/UPstream.C +++ b/src/Pstream/dummy/UPstream.C @@ -97,7 +97,7 @@ Foam::UPstream::probeMessage const UPstream::commsTypes commsType, const int fromProcNo, const int tag, - const label comm + const label communicator ) { return std::pair(-1, 0); diff --git a/src/Pstream/dummy/UPstreamGatherScatter.C b/src/Pstream/dummy/UPstreamGatherScatter.C index 3a5f54e901..7d2e2d63f9 100644 --- a/src/Pstream/dummy/UPstreamGatherScatter.C +++ b/src/Pstream/dummy/UPstreamGatherScatter.C @@ -5,7 +5,7 @@ \\ / A nd | www.openfoam.com \\/ M anipulation | ------------------------------------------------------------------------------- - Copyright (C) 2022 OpenCFD Ltd. + Copyright (C) 2022-2023 OpenCFD Ltd. ------------------------------------------------------------------------------- License This file is part of OpenFOAM. @@ -32,6 +32,7 @@ License #undef Pstream_CommonRoutines #define Pstream_CommonRoutines(Native) \ + \ void Foam::UPstream::mpiGather \ ( \ const Native* sendData, \ @@ -57,7 +58,16 @@ void Foam::UPstream::mpiScatter \ ) \ { \ std::memmove(recvData, sendData, recvCount*sizeof(Native)); \ -} +} \ + \ + \ +void Foam::UPstream::mpiAllGather \ +( \ + Native* allData, \ + int count, \ + const label comm \ +) \ +{} Pstream_CommonRoutines(char); diff --git a/src/Pstream/mpi/UPstream.C b/src/Pstream/mpi/UPstream.C index 19587ccd90..2be5913e71 100644 --- a/src/Pstream/mpi/UPstream.C +++ b/src/Pstream/mpi/UPstream.C @@ -748,7 +748,7 @@ Foam::UPstream::probeMessage const UPstream::commsTypes commsType, const int fromProcNo, const int tag, - const label comm + const label communicator ) { std::pair result(-1, 0); @@ -775,7 +775,7 @@ Foam::UPstream::probeMessage ( source, tag, - PstreamGlobals::MPICommunicators_[comm], + PstreamGlobals::MPICommunicators_[communicator], &status ) ) @@ -799,7 +799,7 @@ Foam::UPstream::probeMessage ( source, tag, - PstreamGlobals::MPICommunicators_[comm], + PstreamGlobals::MPICommunicators_[communicator], &flag, &status ) diff --git a/src/Pstream/mpi/UPstreamGatherScatter.C b/src/Pstream/mpi/UPstreamGatherScatter.C index 4937be2e1d..180a98b9a3 100644 --- a/src/Pstream/mpi/UPstreamGatherScatter.C +++ b/src/Pstream/mpi/UPstreamGatherScatter.C @@ -5,7 +5,7 @@ \\ / A nd | www.openfoam.com \\/ M anipulation | ------------------------------------------------------------------------------- - Copyright (C) 2022 OpenCFD Ltd. + Copyright (C) 2022-2023 OpenCFD Ltd. ------------------------------------------------------------------------------- License This file is part of OpenFOAM. @@ -34,6 +34,7 @@ License #undef Pstream_CommonRoutines #define Pstream_CommonRoutines(Native, TaggedType) \ + \ void Foam::UPstream::mpiGather \ ( \ const Native* sendData, \ @@ -67,6 +68,21 @@ void Foam::UPstream::mpiScatter \ sendData, sendCount, recvData, recvCount, \ TaggedType, comm \ ); \ +} \ + \ + \ +void Foam::UPstream::mpiAllGather \ +( \ + Native* allData, \ + int count, \ + const label comm \ +) \ +{ \ + PstreamDetail::allGather \ + ( \ + allData, count, \ + TaggedType, comm \ + ); \ } Pstream_CommonRoutines(char, MPI_BYTE); diff --git a/src/Pstream/mpi/UPstreamWrapping.H b/src/Pstream/mpi/UPstreamWrapping.H index 125da4defd..92791503e0 100644 --- a/src/Pstream/mpi/UPstreamWrapping.H +++ b/src/Pstream/mpi/UPstreamWrapping.H @@ -230,6 +230,20 @@ void scatterv ); +// MPI_Allgather or MPI_Iallgather +template +void allGather +( + Type* allData, // The send/recv data + int count, // The send/recv count per element + + MPI_Datatype datatype, // The send/recv data type + const label comm, // Communicator + UPstream::Request* req = nullptr, // Non-null for non-blocking + label* requestID = nullptr // (alternative to UPstream::Request) +); + + // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * // } // End namespace PstreamDetail diff --git a/src/Pstream/mpi/UPstreamWrappingTemplates.C b/src/Pstream/mpi/UPstreamWrappingTemplates.C index f19507e857..fba6933a17 100644 --- a/src/Pstream/mpi/UPstreamWrappingTemplates.C +++ b/src/Pstream/mpi/UPstreamWrappingTemplates.C @@ -1366,4 +1366,103 @@ void Foam::PstreamDetail::scatterv } +template +void Foam::PstreamDetail::allGather +( + Type* allData, + int count, + + MPI_Datatype datatype, + const label comm, + + UPstream::Request* req, + label* requestID +) +{ + PstreamGlobals::reset_request(req, requestID); + + const bool immediate = (req || requestID); + + if (!UPstream::parRun() || UPstream::nProcs(comm) < 2) + { + // Nothing to do - ignore + return; + } + + const label numProc = UPstream::nProcs(comm); + + if (UPstream::warnComm >= 0 && comm != UPstream::warnComm) + { + if (immediate) + { + Pout<< "** MPI_Iallgather (non-blocking):"; + } + else + { + Pout<< "** MPI_Allgather (blocking):"; + } + Pout<< " numProc:" << numProc + << " with comm:" << comm + << " warnComm:" << UPstream::warnComm + << endl; + error::printStack(Pout); + } + + bool handled(false); + +#if defined(MPI_VERSION) && (MPI_VERSION >= 3) + // MPI-3 : eg, openmpi-1.7 (2013) and later + if (immediate) + { + profilingPstream::beginTiming(); + + handled = true; + MPI_Request request; + + if + ( + MPI_Iallgather + ( + MPI_IN_PLACE, count, MPI_BYTE, + allData, count, MPI_BYTE, + PstreamGlobals::MPICommunicators_[comm], + &request + ) + ) + { + FatalErrorInFunction + << "MPI_Iallgather [comm: " << comm << "] failed." + << Foam::abort(FatalError); + } + + PstreamGlobals::push_request(request, req, requestID); + profilingPstream::addRequestTime(); + } +#endif + + if (!handled) + { + profilingPstream::beginTiming(); + + if + ( + MPI_Allgather + ( + MPI_IN_PLACE, count, MPI_BYTE, + allData, count, MPI_BYTE, + PstreamGlobals::MPICommunicators_[comm] + ) + ) + { + FatalErrorInFunction + << "MPI_Allgather [comm: " << comm << "] failed." + << Foam::abort(FatalError); + } + + // Is actually gather/scatter but we can't split it apart + profilingPstream::addGatherTime(); + } +} + + // ************************************************************************* //