diff --git a/src/MISC/fix_imd.cpp b/src/MISC/fix_imd.cpp index 9192f8a24c..6f2bb53a8c 100644 --- a/src/MISC/fix_imd.cpp +++ b/src/MISC/fix_imd.cpp @@ -76,7 +76,6 @@ negotiate an appropriate license for such distribution." #endif #include -#include using namespace LAMMPS_NS; using namespace FixConst; @@ -461,7 +460,6 @@ MPI_Datatype MPI_CommData; FixIMD::FixIMD(LAMMPS *lmp, int narg, char **arg) : Fix(lmp, narg, arg) { - if (narg < 4) error->all(FLERR,"Illegal fix imd command"); @@ -508,7 +506,9 @@ FixIMD::FixIMD(LAMMPS *lmp, int narg, char **arg) : vel_flag = utils::logical(FLERR, arg[iarg+1], false, lmp); } else if (0 == strcmp(arg[iarg], "forces")) { force_flag = utils::logical(FLERR, arg[iarg+1], false, lmp); - } else error->all(FLERR,"Unknown fix imd parameter"); + } else { + error->all(FLERR,"Unknown fix imd parameter"); + } ++iarg; ++iarg; } @@ -542,7 +542,6 @@ FixIMD::FixIMD(LAMMPS *lmp, int narg, char **arg) : imdsinfo->energies = false; } - bigint n = group->count(igroup); if (n > MAXSMALLINT) error->all(FLERR,"Too many atoms for fix imd"); num_coords = static_cast (n); @@ -568,25 +567,30 @@ FixIMD::FixIMD(LAMMPS *lmp, int narg, char **arg) : idmap = nullptr; rev_idmap = nullptr; + if (imd_version == 3) { + msglen = 0; + if (imdsinfo->time) { + msglen += 24+IMDHEADERSIZE; + } + if (imdsinfo->box) { + msglen += 9*4+IMDHEADERSIZE; + } + if (imdsinfo->coords) { + msglen += 3*4*num_coords+IMDHEADERSIZE; + } + if (imdsinfo->velocities) { + msglen += 3*4*num_coords+IMDHEADERSIZE; + } + if (imdsinfo->forces) { + msglen += 3*4*num_coords+IMDHEADERSIZE; + } + msgdata = new char[msglen]; + } + else { + msglen = 3*sizeof(float)*num_coords+IMDHEADERSIZE; + msgdata = new char[msglen]; + } - msglen = 0; - if (imdsinfo->time) { - msglen += 24+IMDHEADERSIZE; - } - if (imdsinfo->box) { - msglen += 9*4+IMDHEADERSIZE; - } - if (imdsinfo->coords) { - msglen += 3*4*num_coords+IMDHEADERSIZE; - } - if (imdsinfo->velocities) { - msglen += 3*4*num_coords+IMDHEADERSIZE; - } - if (imdsinfo->forces) { - msglen += 3*4*num_coords+IMDHEADERSIZE; - } - msgdata = new char[msglen]; - if (me == 0) { /* set up incoming socket on MPI rank 0. */ imdsock_init(); @@ -638,6 +642,7 @@ FixIMD::FixIMD(LAMMPS *lmp, int narg, char **arg) : *********************************/ FixIMD::~FixIMD() { + #if defined(LAMMPS_ASYNC_IMD) if (me == 0) { pthread_mutex_lock(&write_mutex); @@ -657,12 +662,12 @@ FixIMD::~FixIMD() memory->destroy(vel_data); memory->destroy(force_data); - memory->destroy(msgdata); + delete[] msgdata; memory->destroy(recv_force_buf); taginthash_destroy(hashtable); delete hashtable; free(rev_idmap); - free(imdsinfo); + delete imdsinfo; // close sockets imdsock_shutdown(clientsock); imdsock_destroy(clientsock); @@ -763,6 +768,103 @@ int FixIMD::reconnect() /* wait for IMD client (e.g. VMD) to respond, initialize communication * buffers and collect tag/id maps. */ void FixIMD::setup(int) +{ + if (imd_version == 2) { + setup_v2(); + } + else { + setup_v3(); + } +} + +void FixIMD::setup_v2() { + /* nme: number of atoms in group on this MPI task + * nmax: max number of atoms in group across all MPI tasks + * nlocal: all local atoms + */ + int i,j; + int nmax,nme,nlocal; + int *mask = atom->mask; + tagint *tag = atom->tag; + nlocal = atom->nlocal; + nme=0; + for (i=0; i < nlocal; ++i) + if (mask[i] & groupbit) ++nme; + + MPI_Allreduce(&nme,&nmax,1,MPI_INT,MPI_MAX,world); + memory->destroy(coord_data); + maxbuf = nmax*size_one; + coord_data = (void *) memory->smalloc(maxbuf,"imd:coord_data"); + + connect_msg = 1; + reconnect(); + MPI_Bcast(&imd_inactive, 1, MPI_INT, 0, world); + MPI_Bcast(&imd_terminate, 1, MPI_INT, 0, world); + if (imd_terminate) + error->all(FLERR,"LAMMPS terminated on error in setting up IMD connection."); + + /* initialize and build hashtable. */ + auto hashtable=new taginthash_t; + taginthash_init(hashtable, num_coords); + idmap = (void *)hashtable; + + int tmp, ndata; + auto buf = static_cast(coord_data); + + if (me == 0) { + MPI_Status status; + MPI_Request request; + auto taglist = new tagint[num_coords]; + int numtag=0; /* counter to map atom tags to a 0-based consecutive index list */ + + for (i=0; i < nlocal; ++i) { + if (mask[i] & groupbit) { + taglist[numtag] = tag[i]; + ++numtag; + } + } + + /* loop over procs to receive remote data */ + for (i=1; i < comm->nprocs; ++i) { + MPI_Irecv(coord_data, maxbuf, MPI_BYTE, i, 0, world, &request); + MPI_Send(&tmp, 0, MPI_INT, i, 0, world); + MPI_Wait(&request, &status); + MPI_Get_count(&status, MPI_BYTE, &ndata); + ndata /= size_one; + + for (j=0; j < ndata; ++j) { + taglist[numtag] = buf[j].tag; + ++numtag; + } + } + + /* sort list of tags by value to have consistently the + * same list when running in parallel and build hash table. */ + id_sort(taglist, 0, num_coords-1); + for (i=0; i < num_coords; ++i) { + taginthash_insert(hashtable, taglist[i], i); + } + delete[] taglist; + + /* generate reverse index-to-tag map for communicating + * IMD forces back to the proper atoms */ + rev_idmap=taginthash_keys(hashtable); + } else { + nme=0; + for (i=0; i < nlocal; ++i) { + if (mask[i] & groupbit) { + buf[nme].tag = tag[i]; + ++nme; + } + } + /* blocking receive to wait until it is our turn to send data. */ + MPI_Recv(&tmp, 0, MPI_INT, 0, 0, world, MPI_STATUS_IGNORE); + MPI_Rsend(coord_data, nme*size_one, MPI_BYTE, 0, 0, world); + } + + } + +void FixIMD::setup_v3() { /* nme: number of atoms in group on this MPI task * nmax: max number of atoms in group across all MPI tasks @@ -807,11 +909,24 @@ void FixIMD::setup(int) idmap = (void *)hashtable; int tmp, ndata; - auto buf = static_cast(coord_data); + + struct commdata *buf = nullptr; + if (imdsinfo->coords) { + buf = static_cast(coord_data); + } + else if (imdsinfo->velocities) { + buf = static_cast(vel_data); + } + else if (imdsinfo->forces) { + buf = static_cast(force_data); + } if (me == 0) { - std::vector statuses; - std::vector requests; + if (buf == nullptr) { + return; + } + MPI_Status status; + MPI_Request request; auto taglist = new tagint[num_coords]; int numtag=0; /* counter to map atom tags to a 0-based consecutive index list */ @@ -824,39 +939,15 @@ void FixIMD::setup(int) /* loop over procs to receive remote data */ for (i=1; i < comm->nprocs; ++i) { - /* We're assuming tags are consistent across x,v,f */ - bool tag_recvd = false; - statuses.clear(); - requests.clear(); - - if (imdsinfo->coords) { - requests.push_back(MPI_Request()); - MPI_Irecv(coord_data, maxbuf, MPI_BYTE, i, 0, world, &requests.back()); - } - if (imdsinfo->velocities) { - requests.push_back(MPI_Request()); - MPI_Irecv(vel_data, maxbuf, MPI_BYTE, i, 0, world, &requests.back()); - } - if (imdsinfo->forces) { - requests.push_back(MPI_Request()); - MPI_Irecv(vel_data, maxbuf, MPI_BYTE, i, 0, world, &requests.back()); - } - statuses.resize(requests.size()); + MPI_Irecv(coord_data, maxbuf, MPI_BYTE, i, 0, world, &request); MPI_Send(&tmp, 0, MPI_INT, i, 0, world); - MPI_Waitall(requests.size(), requests.data(), statuses.data()); + MPI_Wait(&request, &status); + MPI_Get_count(&status, MPI_BYTE, &ndata); + ndata /= size_one; - for (size_t k=0; k < statuses.size(); ++k) { - if (!tag_recvd) { - MPI_Get_count(&statuses[k], MPI_BYTE, &ndata); - ndata /= size_one; - for (j=0; j < ndata; ++j) { - taglist[numtag] = buf[j].tag; - ++numtag; - } - tag_recvd = true; - } else { - break; - } + for (j=0; j < ndata; ++j) { + taglist[numtag] = buf[j].tag; + ++numtag; } } @@ -881,19 +972,10 @@ void FixIMD::setup(int) } /* blocking receive to wait until it is our turn to send data. */ MPI_Recv(&tmp, 0, MPI_INT, 0, 0, world, MPI_STATUS_IGNORE); - if (imdsinfo->coords) { - MPI_Rsend(coord_data, nme*size_one, MPI_BYTE, 0, 0, world); - } - if (imdsinfo->velocities) { - MPI_Rsend(vel_data, nme*size_one, MPI_BYTE, 0, 0, world); - } - if (imdsinfo->forces) { - MPI_Rsend(force_data, nme*size_one, MPI_BYTE, 0, 0, world); - } + MPI_Rsend(coord_data, nme*size_one, MPI_BYTE, 0, 0, world); } } - /* worker threads for asynchronous i/o */ #if defined(LAMMPS_ASYNC_IMD) /* c bindings wrapper */ diff --git a/src/MISC/fix_imd.h b/src/MISC/fix_imd.h index 6b8778dbf0..03d242f32b 100644 --- a/src/MISC/fix_imd.h +++ b/src/MISC/fix_imd.h @@ -116,6 +116,8 @@ class FixIMD : public Fix { char *msgdata; private: + void setup_v2(); + void setup_v3(); void handle_step_v2(); void handle_client_input_v3(); void handle_output_v3();