ENH: improved UPstream gather/scatter functionality

- added UPstream::allGatherValues() with a direct call to MPI_Allgather.
  This opens the possibility of benefiting from the variety of
  MPI-internal algorithms and simplifies the caller:

    Old:
        labelList nPerProc
        (
            UPstream::listGatherValues<label>(patch_.size(), myComm)
        );
        Pstream::broadcast(nPerProc, myComm);

    New:

        const labelList nPerProc
        (
            UPstream::allGatherValues<label>(patch_.size(), myComm)
        );

- Pstream::allGatherList now uses MPI_Allgather for contiguous values
  instead of the hand-rolled tree walking involved in
  gatherList/scatterList. For example:
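
  A minimal usage sketch (default communicator; the 'sizes' list and
  the patch_.size() value are purely illustrative). With a contiguous
  type such as label this now takes the MPI_Allgather path:

        // One slot per rank, fill the local slot, then all-gather
        List<label> sizes(UPstream::nProcs());
        sizes[UPstream::myProcNo()] = patch_.size();
        Pstream::allGatherList(sizes);
        // 'sizes' is now identical on all ranks (size == nProcs)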

- simplified the calling parameters for mpiGather/mpiScatter.

  Since the send/recv data types are identical, the send/recv counts
  are also always identical. This eliminates the possibility of any
  discrepancy between them.

  Since this is a low-level call, little code is affected:
  currently just Foam::profilingPstream and a UPstream internal.
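
  For illustration, a before/after view of one call site (the SHA1
  digest gather shown in the diff below):

    Old:
        UPstream::mpiGather
        (
            myDigest.cdata_bytes(),  // Send
            SHA1Digest::max_size(),  // Num send per proc
            digests.data_bytes(),    // Recv
            SHA1Digest::max_size(),  // Num recv per proc
            UPstream::commGlobal()
        );

    New:
        UPstream::mpiGather
        (
            myDigest.cdata_bytes(),  // Send
            digests.data_bytes(),    // Recv
            SHA1Digest::max_size(),  // Num send/recv per rank
            UPstream::commGlobal()
        );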

BUG: call to MPI_Allgather had hard-coded MPI_BYTE (not the data type)

- a latent bug, since only char data is currently passed to it anyhow
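
  The fix simply forwards the caller's datatype. Sketch of the
  corrected in-place call (as in the UPstreamWrapping change below):

        MPI_Allgather
        (
            MPI_IN_PLACE, count, datatype,  // send args ignored with MPI_IN_PLACE
            allData,      count, datatype,  // was hard-coded MPI_BYTE
            PstreamGlobals::MPICommunicators_[comm]
        );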

Author:  Mark Olesen
Date:    2023-05-17 14:26:05 +02:00
Parent:  b687c4927c
Commit:  5eebc75845
15 changed files with 140 additions and 145 deletions


@ -266,9 +266,8 @@ int main(int argc, char *argv[])
UPstream::mpiGather
(
myDigest.cdata_bytes(), // Send
SHA1Digest::max_size(), // Num send per proc
digests.data_bytes(), // Recv
SHA1Digest::max_size(), // Num recv per proc
SHA1Digest::max_size(), // Num send/recv per rank
UPstream::commGlobal()
);
}


@ -62,10 +62,10 @@ void writeProcStats
)
{
// Determine surface bounding boxes, faces, points
List<treeBoundBox> surfBb(Pstream::nProcs());
surfBb[Pstream::myProcNo()] = treeBoundBox(s.points());
Pstream::gatherList(surfBb);
List<treeBoundBox> surfBb
(
UPstream::listGatherValues<treeBoundBox>(treeBoundBox(s.points()))
);
labelList nPoints(UPstream::listGatherValues<label>(s.points().size()));
labelList nFaces(UPstream::listGatherValues<label>(s.size()));


@ -350,20 +350,19 @@ void Foam::Pstream::allGatherList
{
if (UPstream::is_parallel(comm))
{
// TBD
// if (std::is_arithmetic<T>::value) // or is_contiguous ?
// {
// if (values.size() != UPstream::nProcs(comm))
// {
// FatalErrorInFunction
// << "Size of list:" << values.size()
// << " != number of processors:"
// << UPstream::nProcs(comm)
// << Foam::abort(FatalError);
// }
// UPstream::mpiAllGather(values.data_bytes(), sizeof(T), comm);
// return;
// }
if (is_contiguous<T>::value)
{
if (values.size() < UPstream::nProcs(comm))
{
FatalErrorInFunction
<< "List of values is too small:" << values.size()
<< " vs numProcs:" << UPstream::nProcs(comm) << nl
<< Foam::abort(FatalError);
}
UPstream::mpiAllGather(values.data_bytes(), sizeof(T), comm);
return;
}
const auto& comms = UPstream::whichCommunication(comm);


@ -91,9 +91,8 @@ static List<int> getHostGroupIds(const label parentCommunicator)
UPstream::mpiGather
(
myDigest.cdata_bytes(), // Send
SHA1Digest::max_size(), // Num send per proc
digests.data_bytes(), // Recv
SHA1Digest::max_size(), // Num recv per proc
SHA1Digest::max_size(), // Num send/recv data per rank
parentCommunicator
);


@ -1011,10 +1011,10 @@ public:
( \
/*! On rank: individual value to send */ \
const Native* sendData, \
int sendCount, \
/*! On master: receive buffer with all values */ \
Native* recvData, \
int recvCount, \
/*! Number of send/recv data per rank. Globally consistent! */ \
int count, \
const label communicator = worldComm \
); \
\
@ -1023,10 +1023,10 @@ public:
( \
/*! On master: send buffer with all values */ \
const Native* sendData, \
int sendCount, \
/*! On rank: individual value to receive */ \
Native* recvData, \
int recvCount, \
/*! Number of send/recv data per rank. Globally consistent! */ \
int count, \
const label communicator = worldComm \
); \
\
@ -1034,7 +1034,9 @@ public:
/*! Send data from proc slot, receive into all slots */ \
static void mpiAllGather \
( \
/*! On all ranks: the base of the data locations */ \
Native* allData, \
/*! Number of send/recv data per rank. Globally consistent! */ \
int count, \
const label communicator = worldComm \
);
@ -1080,6 +1082,15 @@ public:
// Gather single, contiguous value(s)
//- Allgather individual values into list locations.
// The returned list has size nProcs, identical on all ranks.
template<class T>
static List<T> allGatherValues
(
const T& localValue,
const label communicator = worldComm
);
//- Gather individual values into list locations.
// On master list length == nProcs, otherwise zero length.
// \n


@ -27,6 +27,42 @@ License
// * * * * * * * * * * * * * * * Member Functions * * * * * * * * * * * * * //
template<class T>
Foam::List<T> Foam::UPstream::allGatherValues
(
const T& localValue,
const label comm
)
{
if (!is_contiguous<T>::value)
{
FatalErrorInFunction
<< "Cannot all-gather values for non-contiguous types" << endl
<< Foam::abort(FatalError);
}
List<T> allValues;
if (UPstream::is_parallel(comm))
{
allValues.resize(UPstream::nProcs(comm));
allValues[UPstream::myProcNo(comm)] = localValue;
UPstream::mpiAllGather(allValues.data_bytes(), sizeof(T), comm);
}
else
{
// non-parallel: return own value
// TBD: only when UPstream::is_rank(comm) as well?
allValues.resize(1);
allValues[0] = localValue;
}
return allValues;
}
template<class T>
Foam::List<T> Foam::UPstream::listGatherValues
(
@ -44,30 +80,25 @@ Foam::List<T> Foam::UPstream::listGatherValues
List<T> allValues;
const label nproc =
(
UPstream::is_parallel(comm) ? UPstream::nProcs(comm) : 1
);
if (nproc > 1)
if (UPstream::is_parallel(comm))
{
if (UPstream::master(comm))
{
allValues.resize(nproc);
allValues.resize(UPstream::nProcs(comm));
}
UPstream::mpiGather
(
reinterpret_cast<const char*>(&localValue),
sizeof(T),
allValues.data_bytes(),
sizeof(T),
sizeof(T), // The send/recv size per rank
comm
);
}
else
{
// non-parallel: return own value
// TBD: only when UPstream::is_rank(comm) as well?
allValues.resize(1);
allValues[0] = localValue;
}
@ -91,15 +122,12 @@ T Foam::UPstream::listScatterValues
}
const label nproc =
(
UPstream::is_parallel(comm) ? UPstream::nProcs(comm) : 1
);
T localValue;
if (nproc > 1)
if (UPstream::is_parallel(comm))
{
const label nproc = UPstream::nProcs(comm);
if (UPstream::master(comm) && allValues.size() < nproc)
{
FatalErrorInFunction
@ -111,9 +139,8 @@ T Foam::UPstream::listScatterValues
UPstream::mpiScatter
(
allValues.cdata_bytes(),
sizeof(T),
reinterpret_cast<char*>(&localValue),
sizeof(T),
sizeof(T), // The send/recv size per rank
comm
);
}


@ -229,9 +229,8 @@ void Foam::profilingPstream::report(const int reportLevel)
UPstream::mpiGather
(
procValues.cdata_bytes(), // Send
procValues.size_bytes(), // Num send per proc
allTimes.data_bytes(), // Recv
procValues.size_bytes(), // Num recv per proc
procValues.size_bytes(), // Num send/recv data per rank
UPstream::commWorld()
);
}
@ -249,9 +248,8 @@ void Foam::profilingPstream::report(const int reportLevel)
UPstream::mpiGather
(
procValues.cdata_bytes(), // Send
procValues.size_bytes(), // Num send per proc
allCounts.data_bytes(), // Recv
procValues.size_bytes(), // Num recv per proc
procValues.size_bytes(), // Num send/recv data per rank
UPstream::commWorld()
);
}


@ -6,7 +6,7 @@
\\/ M anipulation |
-------------------------------------------------------------------------------
Copyright (C) 2011-2016 OpenFOAM Foundation
Copyright (C) 2018-2022 OpenCFD Ltd.
Copyright (C) 2018-2023 OpenCFD Ltd.
-------------------------------------------------------------------------------
License
This file is part of OpenFOAM.
@ -27,7 +27,6 @@ License
\*---------------------------------------------------------------------------*/
#include "globalIndex.H"
#include "labelRange.H"
// * * * * * * * * * * * * * Static Member Functions * * * * * * * * * * * * //
@ -200,14 +199,13 @@ void Foam::globalIndex::reset
{
labelList localLens;
const label len = Pstream::nProcs(comm);
const label len = UPstream::nProcs(comm);
if (len)
{
if (parallel && UPstream::parRun())
if (parallel && UPstream::parRun()) // or UPstream::is_parallel()
{
localLens = UPstream::listGatherValues(localSize, comm);
Pstream::broadcast(localLens, comm);
localLens = UPstream::allGatherValues(localSize, comm);
}
else
{


@ -55,7 +55,6 @@ namespace Foam
// Forward Declarations
class globalIndex;
class labelRange;
Istream& operator>>(Istream& is, globalIndex& gi);
Ostream& operator<<(Ostream& os, const globalIndex& gi);


@ -27,7 +27,6 @@ License
\*---------------------------------------------------------------------------*/
#include "ListOps.H"
#include "labelRange.H"
// * * * * * * * * * * * * * * * * Constructors * * * * * * * * * * * * * * //
@ -121,7 +120,7 @@ inline Foam::globalIndex::globalIndex
const label comm
)
{
// one-sided: non-master only
// one-sided: non-master sizes only
reset
(
UPstream::listGatherValues
@ -401,7 +400,7 @@ inline void Foam::globalIndex::reset
const label comm
)
{
// one-sided: non-master only
// one-sided: non-master sizes only
reset
(
UPstream::listGatherValues


@ -36,28 +36,24 @@ License
void Foam::UPstream::mpiGather \
( \
const Native* sendData, \
int sendCount, \
\
Native* recvData, \
int recvCount, \
int count, \
const label comm \
) \
{ \
std::memmove(recvData, sendData, recvCount*sizeof(Native)); \
std::memmove(recvData, sendData, count*sizeof(Native)); \
} \
\
\
void Foam::UPstream::mpiScatter \
( \
const Native* sendData, \
int sendCount, \
\
Native* recvData, \
int recvCount, \
int count, \
const label comm \
) \
{ \
std::memmove(recvData, sendData, recvCount*sizeof(Native)); \
std::memmove(recvData, sendData, count*sizeof(Native)); \
} \
\
\


@ -38,16 +38,14 @@ License
void Foam::UPstream::mpiGather \
( \
const Native* sendData, \
int sendCount, \
\
Native* recvData, \
int recvCount, \
int count, \
const label comm \
) \
{ \
PstreamDetail::gather \
( \
sendData, sendCount, recvData, recvCount, \
sendData, recvData, count, \
TaggedType, comm \
); \
} \
@ -56,16 +54,14 @@ void Foam::UPstream::mpiGather \
void Foam::UPstream::mpiScatter \
( \
const Native* sendData, \
int sendCount, \
\
Native* recvData, \
int recvCount, \
int count, \
const label comm \
) \
{ \
PstreamDetail::scatter \
( \
sendData, sendCount, recvData, recvCount, \
sendData, recvData, count, \
TaggedType, comm \
); \
} \


@ -164,12 +164,9 @@ void allToAllConsensus
template<class Type>
void gather
(
const Type* sendData,
int sendCount,
Type* recvData, // Ignored on non-root rank
int recvCount, // Ignored on non-root rank
const Type* sendData, // Local send value
Type* recvData, // On master: recv buffer. Ignored elsewhere
int count, // Per rank send/recv count. Globally consistent!
MPI_Datatype datatype, // The send/recv data type
const label comm, // Communicator
UPstream::Request* req = nullptr, // Non-null for non-blocking
@ -181,12 +178,9 @@ void gather
template<class Type>
void scatter
(
const Type* sendData, // Ignored on non-root rank
int sendCount, // Ignored on non-root rank
Type* recvData,
int recvCount,
const Type* sendData, // On master: send buffer. Ignored elsewhere
Type* recvData, // Local recv value
int count, // Per rank send/recv count. Globally consistent!
MPI_Datatype datatype, // The send/recv data type
const label comm, // Communicator
UPstream::Request* req = nullptr, // Non-null for non-blocking


@ -924,14 +924,12 @@ template<class Type>
void Foam::PstreamDetail::gather
(
const Type* sendData,
int sendCount,
Type* recvData,
int recvCount,
int count,
MPI_Datatype datatype,
const label comm,
const label comm,
UPstream::Request* req,
label* requestID
)
@ -940,13 +938,16 @@ void Foam::PstreamDetail::gather
const bool immediate = (req || requestID);
if (!UPstream::is_rank(comm))
if (!UPstream::is_rank(comm) || !count)
{
return;
}
if (!UPstream::is_parallel(comm))
{
std::memmove(recvData, sendData, recvCount*sizeof(Type));
if (recvData)
{
std::memmove(recvData, sendData, count*sizeof(Type));
}
return;
}
@ -963,7 +964,7 @@ void Foam::PstreamDetail::gather
Pout<< "** MPI_Gather (blocking):";
}
Pout<< " numProc:" << numProc
<< " recvCount:" << recvCount
<< " count:" << count
<< " with comm:" << comm
<< " warnComm:" << UPstream::warnComm
<< endl;
@ -982,13 +983,9 @@ void Foam::PstreamDetail::gather
(
MPI_Igather
(
const_cast<Type*>(sendData),
sendCount,
datatype,
recvData,
recvCount,
datatype,
0, // (root rank) == UPstream::masterNo()
const_cast<Type*>(sendData), count, datatype,
recvData, count, datatype,
0, // root: UPstream::masterNo()
PstreamGlobals::MPICommunicators_[comm],
&request
)
@ -996,8 +993,7 @@ void Foam::PstreamDetail::gather
{
FatalErrorInFunction
<< "MPI_Igather [comm: " << comm << "] failed."
<< " sendCount " << sendCount
<< " recvCount " << recvCount
<< " count:" << count << nl
<< Foam::abort(FatalError);
}
@ -1013,21 +1009,16 @@ void Foam::PstreamDetail::gather
(
MPI_Gather
(
const_cast<Type*>(sendData),
sendCount,
datatype,
recvData,
recvCount,
datatype,
0, // (root rank) == UPstream::masterNo()
const_cast<Type*>(sendData), count, datatype,
recvData, count, datatype,
0, // root: UPstream::masterNo()
PstreamGlobals::MPICommunicators_[comm]
)
)
{
FatalErrorInFunction
<< "MPI_Gather [comm: " << comm << "] failed."
<< " sendCount " << sendCount
<< " recvCount " << recvCount
<< " count:" << count << nl
<< Foam::abort(FatalError);
}
@ -1040,14 +1031,12 @@ template<class Type>
void Foam::PstreamDetail::scatter
(
const Type* sendData,
int sendCount,
Type* recvData,
int recvCount,
int count,
MPI_Datatype datatype,
const label comm,
const label comm,
UPstream::Request* req,
label* requestID
)
@ -1056,13 +1045,16 @@ void Foam::PstreamDetail::scatter
const bool immediate = (req || requestID);
if (!UPstream::is_rank(comm))
if (!UPstream::is_rank(comm) || !count)
{
return;
}
if (!UPstream::is_parallel(comm))
{
std::memmove(recvData, sendData, recvCount*sizeof(Type));
if (recvData)
{
std::memmove(recvData, sendData, count*sizeof(Type));
}
return;
}
@ -1079,7 +1071,7 @@ void Foam::PstreamDetail::scatter
Pout<< "** MPI_Scatter (blocking):";
}
Pout<< " numProc:" << numProc
<< " recvCount:" << recvCount
<< " count:" << count
<< " with comm:" << comm
<< " warnComm:" << UPstream::warnComm
<< endl;
@ -1098,13 +1090,9 @@ void Foam::PstreamDetail::scatter
(
MPI_Iscatter
(
const_cast<Type*>(sendData),
sendCount,
datatype,
recvData,
recvCount,
datatype,
0, // (root rank) == UPstream::masterNo()
const_cast<Type*>(sendData), count, datatype,
recvData, count, datatype,
0, // root: UPstream::masterNo()
PstreamGlobals::MPICommunicators_[comm],
&request
)
@ -1112,8 +1100,7 @@ void Foam::PstreamDetail::scatter
{
FatalErrorInFunction
<< "MPI_Iscatter [comm: " << comm << "] failed."
<< " sendCount " << sendCount
<< " recvCount " << recvCount
<< " count:" << count << nl
<< Foam::abort(FatalError);
}
@ -1129,21 +1116,16 @@ void Foam::PstreamDetail::scatter
(
MPI_Scatter
(
const_cast<Type*>(sendData),
sendCount,
datatype,
recvData,
recvCount,
datatype,
0, // (root rank) == UPstream::masterNo()
const_cast<Type*>(sendData), count, datatype,
recvData, count, datatype,
0, // root: UPstream::masterNo()
PstreamGlobals::MPICommunicators_[comm]
)
)
{
FatalErrorInFunction
<< "MPI_Scatter [comm: " << comm << "] failed."
<< " sendCount " << sendCount
<< " recvCount " << recvCount
<< " count:" << count << nl
<< Foam::abort(FatalError);
}
@ -1483,8 +1465,8 @@ void Foam::PstreamDetail::allGather
(
MPI_Iallgather
(
MPI_IN_PLACE, count, MPI_BYTE,
allData, count, MPI_BYTE,
MPI_IN_PLACE, count, datatype,
allData, count, datatype,
PstreamGlobals::MPICommunicators_[comm],
&request
)
@ -1507,8 +1489,8 @@ void Foam::PstreamDetail::allGather
(
MPI_Allgather
(
MPI_IN_PLACE, count, MPI_BYTE,
allData, count, MPI_BYTE,
MPI_IN_PLACE, count, datatype,
allData, count, datatype,
PstreamGlobals::MPICommunicators_[comm]
)
)


@ -240,17 +240,15 @@ void Foam::mappedPatchBase::collectSamples
{
labelList procToWorldIndex
(
UPstream::listGatherValues<label>(mySampleWorld, myComm)
UPstream::allGatherValues<label>(mySampleWorld, myComm)
);
labelList nPerProc
(
UPstream::listGatherValues<label>(patch_.size(), myComm)
UPstream::allGatherValues<label>(patch_.size(), myComm)
);
Pstream::broadcasts(myComm, procToWorldIndex, nPerProc);
patchFaceWorlds.setSize(patchFaces.size());
patchFaceProcs.setSize(patchFaces.size());
patchFaceWorlds.resize(patchFaces.size());
patchFaceProcs.resize(patchFaces.size());
label sampleI = 0;
forAll(nPerProc, proci)