diff --git a/src/comm.cpp b/src/comm.cpp index 80e5388989..681b54fe97 100644 --- a/src/comm.cpp +++ b/src/comm.cpp @@ -1505,15 +1505,26 @@ void Comm::forward_comm_array(int n, double **array) } /* ---------------------------------------------------------------------- - reverse communication invoked by a Dump + communicate inbuf around full ring of processors with messtag + nbytes = size of inbuf = n datums, each of size nper bytes + use callback() to allow caller to process each proc's inbuf ------------------------------------------------------------------------- */ -void Comm::ring(int n, int nmax, char *buf, char *bufcopy, int messtag, +void Comm::ring(int n, int nper, void *inbuf, int messtag, void (*callback)(int, char *)) { MPI_Request request; MPI_Status status; + int nbytes = n*nper; + int maxbytes; + MPI_Allreduce(&nbytes,&maxbytes,1,MPI_INT,MPI_MAX,world); + + char *buf,*bufcopy; + memory->create(buf,maxbytes,"comm:buf"); + memory->create(bufcopy,maxbytes,"comm:bufcopy"); + memcpy(buf,inbuf,nbytes); + int next = me + 1; int prev = me - 1; if (next == nprocs) next = 0; @@ -1521,14 +1532,17 @@ void Comm::ring(int n, int nmax, char *buf, char *bufcopy, int messtag, for (int loop = 0; loop < nprocs; loop++) { if (me != next) { - MPI_Irecv(bufcopy,nmax,MPI_CHAR,prev,messtag,world,&request); - MPI_Send(buf,n,MPI_CHAR,next,messtag,world); + MPI_Irecv(bufcopy,maxbytes,MPI_CHAR,prev,messtag,world,&request); + MPI_Send(buf,nbytes,MPI_CHAR,next,messtag,world); MPI_Wait(&request,&status); - MPI_Get_count(&status,MPI_CHAR,&n); - memcpy(buf,bufcopy,n); - callback(n,buf); + MPI_Get_count(&status,MPI_CHAR,&nbytes); + memcpy(buf,bufcopy,nbytes); + callback(nbytes/nper,buf); } } + + memory->destroy(buf); + memory->destroy(bufcopy); } /* ---------------------------------------------------------------------- diff --git a/src/comm.h b/src/comm.h index 60c8e6c38c..17c8b0ac50 100644 --- a/src/comm.h +++ b/src/comm.h @@ -61,8 +61,7 @@ class Comm : protected Pointers { virtual void reverse_comm_dump(class Dump *); // reverse comm from a Dump void forward_comm_array(int, double **); // forward comm of array - void ring(int, int, char *, char *, int, - void (*)(int, char *)); // ring communication + void ring(int, int, void *, int, void (*)(int, char *)); // ring comm virtual void set(int, char **); // set communication style void set_processors(int, char **); // set 3d processor grid attributes