added rendezvous via all2all

This commit is contained in:
Steve Plimpton
2019-01-23 14:49:52 -07:00
committed by Axel Kohlmeyer
parent 981f12ebeb
commit fc002e30d3
7 changed files with 566 additions and 105 deletions

View File

@ -729,34 +729,78 @@ void Comm::ring(int n, int nper, void *inbuf, int messtag,
/* ----------------------------------------------------------------------
rendezvous communication operation
three stages:
first Irregular converts inbuf from caller decomp to rvous decomp
first comm sends inbuf from caller decomp to rvous decomp
callback operates on data in rendevous decomp
last Irregular converts outbuf from rvous decomp back to caller decomp
second comm sends outbuf from rvous decomp back to caller decomp
inputs:
n = # of input datums
proclist = proc that owns each input datum in rendezvous decomposition
inbuf = list of input datums
insize = size in bytes of each input datum
which = perform (0) irregular or (1) MPI_All2allv communication
n = # of datums in inbuf
inbuf = vector of input datums
insize = byte size of each input datum
inorder = 0 for inbuf in random proc order, 1 for datums ordered by proc
procs: inorder 0 = proc to send each datum to, 1 = # of datums/proc,
callback = caller function to invoke in rendezvous decomposition
takes input datums, returns output datums
outorder = same as inorder, but for datums returned by callback()
ptr = pointer to caller class, passed to callback()
outputs:
nout = # of output datums (function return)
outbuf = list of output datums
outsize = size in bytes of each output datum
outbuf = vector of output datums
outsize = byte size of each output datum
callback inputs:
nrvous = # of rvous decomp datums in inbuf_rvous
inbuf_rvous = vector of rvous decomp input datums
ptr = pointer to caller class
callback outputs:
nrvous_out = # of rvous decomp output datums (function return)
flag = 0 for no second comm, 1 for outbuf_rvous = inbuf_rvous,
2 for second comm with new outbuf_rvous
procs_rvous = outorder 0 = proc to send each datum to, 1 = # of datums/proc
allocated
outbuf_rvous = vector of rvous decomp output datums
NOTE: could use MPI_INT or MPI_DOUBLE insead of MPI_CHAR
to avoid checked-for overflow in MPI_Alltoallv?
------------------------------------------------------------------------- */
int Comm::rendezvous(int n, int *proclist, char *inbuf, int insize,
int (*callback)(int, char *, int &, int *&, char *&, void *),
char *&outbuf, int outsize, void *ptr)
int Comm::
rendezvous(int which, int n, char *inbuf, int insize,
int inorder, int *procs,
int (*callback)(int, char *, int &, int *&, char *&, void *),
int outorder, char *&outbuf, int outsize, void *ptr)
{
// comm inbuf from caller decomposition to rendezvous decomposition
int nout;
if (which == 0)
nout = rendezvous_irregular(n,inbuf,insize,inorder,procs,callback,
outorder,outbuf,outsize,ptr);
else
nout = rendezvous_all2all(n,inbuf,insize,inorder,procs,callback,
outorder,outbuf,outsize,ptr);
return nout;
}
/* ---------------------------------------------------------------------- */
int Comm::
rendezvous_irregular(int n, char *inbuf, int insize, int inorder, int *procs,
int (*callback)(int, char *, int &, int *&, char *&, void *),
int outorder, char *&outbuf,
int outsize, void *ptr)
{
// irregular comm of inbuf from caller decomp to rendezvous decomp
Irregular *irregular = new Irregular(lmp);
int n_rvous = irregular->create_data(n,proclist); // add sort
char *inbuf_rvous = (char *) memory->smalloc((bigint) n_rvous*insize,
"rendezvous:inbuf_rvous");
int nrvous;
if (inorder) nrvous = irregular->create_data_grouped(n,procs);
else nrvous = irregular->create_data(n,procs);
char *inbuf_rvous = (char *) memory->smalloc((bigint) nrvous*insize,
"rendezvous:inbuf");
irregular->exchange_data(inbuf,insize,inbuf_rvous);
bigint irregular1_bytes = 0; //irregular->irregular_bytes;
irregular->destroy_data();
delete irregular;
@ -764,29 +808,253 @@ int Comm::rendezvous(int n, int *proclist, char *inbuf, int insize,
// callback() allocates/populates proclist_rvous and outbuf_rvous
int flag;
int *proclist_rvous;
int *procs_rvous;
char *outbuf_rvous;
int nout_rvous =
callback(n_rvous,inbuf_rvous,flag,proclist_rvous,outbuf_rvous,ptr);
int nrvous_out = callback(nrvous,inbuf_rvous,flag,
procs_rvous,outbuf_rvous,ptr);
if (flag != 1) memory->sfree(inbuf_rvous); // outbuf_rvous = inbuf_vous
if (flag == 0) return 0; // all nout_rvous are 0, no 2nd irregular
if (flag == 0) return 0; // all nout_rvous are 0, no 2nd comm stage
// comm outbuf from rendezvous decomposition back to caller
// irregular comm of outbuf from rendezvous decomp back to caller decomp
// caller will free outbuf
irregular = new Irregular(lmp);
int nout = irregular->create_data(nout_rvous,proclist_rvous);
outbuf = (char *) memory->smalloc((bigint) nout*outsize,"rendezvous:outbuf");
int nout;
if (outorder)
nout = irregular->create_data_grouped(nrvous_out,procs_rvous);
else nout = irregular->create_data(nrvous_out,procs_rvous);
outbuf = (char *) memory->smalloc((bigint) nout*outsize,
"rendezvous:outbuf");
irregular->exchange_data(outbuf_rvous,outsize,outbuf);
bigint irregular2_bytes = 0; //irregular->irregular_bytes;
irregular->destroy_data();
delete irregular;
memory->destroy(proclist_rvous);
memory->destroy(procs_rvous);
memory->sfree(outbuf_rvous);
// approximate memory tally
bigint rvous_bytes = 0;
rvous_bytes += n*insize; // inbuf
rvous_bytes += nout*outsize; // outbuf
rvous_bytes += nrvous*insize; // inbuf_rvous
rvous_bytes += nrvous_out*outsize; // outbuf_rvous
rvous_bytes += nrvous_out*sizeof(int); // procs_rvous
rvous_bytes += MAX(irregular1_bytes,irregular2_bytes); // max of 2 comms
// return number of output datums
return nout;
}
/* ---------------------------------------------------------------------- */
int Comm::
rendezvous_all2all(int n, char *inbuf, int insize, int inorder, int *procs,
int (*callback)(int, char *, int &, int *&, char *&, void *),
int outorder, char *&outbuf, int outsize, void *ptr)
{
int iproc;
bigint all2all1_bytes,all2all2_bytes;
int *sendcount,*sdispls,*recvcount,*rdispls;
int *procs_a2a;
bigint *offsets;
char *inbuf_a2a,*outbuf_a2a;
// create procs and inbuf for All2all if necesary
if (!inorder) {
memory->create(procs_a2a,nprocs,"rendezvous:procs");
inbuf_a2a = (char *) memory->smalloc((bigint) n*insize,
"rendezvous:inbuf");
memory->create(offsets,nprocs,"rendezvous:offsets");
for (int i = 0; i < nprocs; i++) procs_a2a[i] = 0;
for (int i = 0; i < n; i++) procs_a2a[procs[i]]++;
offsets[0] = 0;
for (int i = 1; i < nprocs; i++)
offsets[i] = offsets[i-1] + insize*procs_a2a[i-1];
bigint offset = 0;
for (int i = 0; i < n; i++) {
iproc = procs[i];
memcpy(&inbuf_a2a[offsets[iproc]],&inbuf[offset],insize);
offsets[iproc] += insize;
offset += insize;
}
all2all1_bytes = nprocs*sizeof(int) + nprocs*sizeof(bigint) + n*insize;
} else {
procs_a2a = procs;
inbuf_a2a = inbuf;
all2all1_bytes = 0;
}
// create args for MPI_Alltoallv() on input data
memory->create(sendcount,nprocs,"rendezvous:sendcount");
memcpy(sendcount,procs_a2a,nprocs*sizeof(int));
memory->create(recvcount,nprocs,"rendezvous:recvcount");
MPI_Alltoall(sendcount,1,MPI_INT,recvcount,1,MPI_INT,world);
memory->create(sdispls,nprocs,"rendezvous:sdispls");
memory->create(rdispls,nprocs,"rendezvous:rdispls");
sdispls[0] = rdispls[0] = 0;
for (int i = 1; i < nprocs; i++) {
sdispls[i] = sdispls[i-1] + sendcount[i-1];
rdispls[i] = rdispls[i-1] + recvcount[i-1];
}
int nrvous = rdispls[nprocs-1] + recvcount[nprocs-1];
// test for overflow of input data due to imbalance or insize
// means that individual sdispls or rdispls values overflow
int overflow = 0;
if ((bigint) n*insize > MAXSMALLINT) overflow = 1;
if ((bigint) nrvous*insize > MAXSMALLINT) overflow = 1;
int overflowall;
MPI_Allreduce(&overflow,&overflowall,1,MPI_INT,MPI_MAX,world);
if (overflowall) error->all(FLERR,"Overflow input size in rendezvous_a2a");
for (int i = 0; i < nprocs; i++) {
sendcount[i] *= insize;
sdispls[i] *= insize;
recvcount[i] *= insize;
rdispls[i] *= insize;
}
// all2all comm of inbuf from caller decomp to rendezvous decomp
char *inbuf_rvous = (char *) memory->smalloc((bigint) nrvous*insize,
"rendezvous:inbuf");
MPI_Alltoallv(inbuf_a2a,sendcount,sdispls,MPI_CHAR,
inbuf_rvous,recvcount,rdispls,MPI_CHAR,world);
if (!inorder) {
memory->destroy(procs_a2a);
memory->sfree(inbuf_a2a);
memory->destroy(offsets);
}
// peform rendezvous computation via callback()
// callback() allocates/populates proclist_rvous and outbuf_rvous
int flag;
int *procs_rvous;
char *outbuf_rvous;
int nrvous_out = callback(nrvous,inbuf_rvous,flag,
procs_rvous,outbuf_rvous,ptr);
if (flag != 1) memory->sfree(inbuf_rvous); // outbuf_rvous = inbuf_vous
if (flag == 0) return 0; // all nout_rvous are 0, no 2nd irregular
// create procs and outbuf for All2all if necesary
if (!outorder) {
memory->create(procs_a2a,nprocs,"rendezvous_a2a:procs");
outbuf_a2a = (char *) memory->smalloc((bigint) nrvous_out*outsize,
"rendezvous:outbuf");
memory->create(offsets,nprocs,"rendezvous:offsets");
for (int i = 0; i < nprocs; i++) procs_a2a[i] = 0;
for (int i = 0; i < nrvous_out; i++) procs_a2a[procs_rvous[i]]++;
offsets[0] = 0;
for (int i = 1; i < nprocs; i++)
offsets[i] = offsets[i-1] + outsize*procs_a2a[i-1];
bigint offset = 0;
for (int i = 0; i < nrvous_out; i++) {
iproc = procs_rvous[i];
memcpy(&outbuf_a2a[offsets[iproc]],&outbuf_rvous[offset],outsize);
offsets[iproc] += outsize;
offset += outsize;
}
all2all2_bytes = nprocs*sizeof(int) + nprocs*sizeof(bigint) +
nrvous_out*outsize;
} else {
procs_a2a = procs_rvous;
outbuf_a2a = outbuf_rvous;
all2all2_bytes = 0;
}
// comm outbuf from rendezvous decomposition back to caller
memcpy(sendcount,procs_a2a,nprocs*sizeof(int));
MPI_Alltoall(sendcount,1,MPI_INT,recvcount,1,MPI_INT,world);
sdispls[0] = rdispls[0] = 0;
for (int i = 1; i < nprocs; i++) {
sdispls[i] = sdispls[i-1] + sendcount[i-1];
rdispls[i] = rdispls[i-1] + recvcount[i-1];
}
int nout = rdispls[nprocs-1] + recvcount[nprocs-1];
// test for overflow of outbuf due to imbalance or outsize
// means that individual sdispls or rdispls values overflow
overflow = 0;
if ((bigint) nrvous*outsize > MAXSMALLINT) overflow = 1;
if ((bigint) nout*outsize > MAXSMALLINT) overflow = 1;
MPI_Allreduce(&overflow,&overflowall,1,MPI_INT,MPI_MAX,world);
if (overflowall) error->all(FLERR,"Overflow output in rendezvous_a2a");
for (int i = 0; i < nprocs; i++) {
sendcount[i] *= outsize;
sdispls[i] *= outsize;
recvcount[i] *= outsize;
rdispls[i] *= outsize;
}
// all2all comm of outbuf from rendezvous decomp back to caller decomp
// caller will free outbuf
outbuf = (char *) memory->smalloc((bigint) nout*outsize,"rendezvous:outbuf");
MPI_Alltoallv(outbuf_a2a,sendcount,sdispls,MPI_CHAR,
outbuf,recvcount,rdispls,MPI_CHAR,world);
memory->destroy(procs_rvous);
memory->sfree(outbuf_rvous);
if (!outorder) {
memory->destroy(procs_a2a);
memory->sfree(outbuf_a2a);
memory->destroy(offsets);
}
// clean up
memory->destroy(sendcount);
memory->destroy(recvcount);
memory->destroy(sdispls);
memory->destroy(rdispls);
// approximate memory tally
bigint rvous_bytes = 0;
rvous_bytes += n*insize; // inbuf
rvous_bytes += nout*outsize; // outbuf
rvous_bytes += nrvous*insize; // inbuf_rvous
rvous_bytes += nrvous_out*outsize; // outbuf_rvous
rvous_bytes += nrvous_out*sizeof(int); // procs_rvous
rvous_bytes += 4*nprocs*sizeof(int); // all2all vectors
rvous_bytes += MAX(all2all1_bytes,all2all2_bytes); // reorder ops
// return number of datums
return nout;