refactor balance and fix balance to std::string to avoid buffer overflows
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
// clang-format off
|
||||
/* ----------------------------------------------------------------------
|
||||
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
|
||||
https://www.lammps.org/, Sandia National Laboratories
|
||||
@ -21,11 +20,10 @@
|
||||
|
||||
#include "balance.h"
|
||||
|
||||
#include "update.h"
|
||||
#include "atom.h"
|
||||
#include "neighbor.h"
|
||||
#include "comm.h"
|
||||
#include "domain.h"
|
||||
#include "error.h"
|
||||
#include "fix_store_atom.h"
|
||||
#include "force.h"
|
||||
#include "imbalance.h"
|
||||
@ -37,9 +35,10 @@
|
||||
#include "irregular.h"
|
||||
#include "memory.h"
|
||||
#include "modify.h"
|
||||
#include "neighbor.h"
|
||||
#include "pair.h"
|
||||
#include "rcb.h"
|
||||
#include "error.h"
|
||||
#include "update.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
@ -48,17 +47,16 @@ using namespace LAMMPS_NS;
|
||||
|
||||
double EPSNEIGH = 1.0e-3;
|
||||
|
||||
enum{XYZ,SHIFT,BISECTION};
|
||||
enum{NONE,UNIFORM,USER};
|
||||
enum{X,Y,Z};
|
||||
enum { XYZ, SHIFT, BISECTION };
|
||||
enum { NONE, UNIFORM, USER };
|
||||
enum { X, Y, Z };
|
||||
|
||||
// clang-format off
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
Balance::Balance(LAMMPS *lmp) : Command(lmp)
|
||||
{
|
||||
MPI_Comm_rank(world,&me);
|
||||
MPI_Comm_size(world,&nprocs);
|
||||
|
||||
user_xsplit = user_ysplit = user_zsplit = nullptr;
|
||||
shift_allocate = 0;
|
||||
proccost = allproccost = nullptr;
|
||||
@ -118,7 +116,7 @@ void Balance::command(int narg, char **arg)
|
||||
if (domain->box_exist == 0)
|
||||
error->all(FLERR,"Balance command before simulation box is defined");
|
||||
|
||||
if (me == 0) utils::logmesg(lmp,"Balancing ...\n");
|
||||
if (comm->me == 0) utils::logmesg(lmp,"Balancing ...\n");
|
||||
|
||||
// parse required arguments
|
||||
|
||||
@ -196,10 +194,10 @@ void Balance::command(int narg, char **arg)
|
||||
|
||||
} else if (strcmp(arg[iarg],"shift") == 0) {
|
||||
if (style != -1) error->all(FLERR,"Illegal balance command");
|
||||
if (iarg+4 > narg) error->all(FLERR,"Illegal balance command");
|
||||
if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "balance shift", error);
|
||||
style = SHIFT;
|
||||
if (strlen(arg[iarg+1]) > BSTR_SIZE) error->all(FLERR,"Illegal balance command");
|
||||
strncpy(bstr,arg[iarg+1],BSTR_SIZE+1);
|
||||
bstr = arg[iarg+1];
|
||||
if (bstr.size() > BSTR_SIZE) error->all(FLERR,"Illegal balance shift command");
|
||||
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
|
||||
if (nitermax <= 0) error->all(FLERR,"Illegal balance command");
|
||||
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
|
||||
@ -235,7 +233,7 @@ void Balance::command(int narg, char **arg)
|
||||
}
|
||||
|
||||
if (style == SHIFT) {
|
||||
const int blen=strlen(bstr);
|
||||
const int blen = bstr.size();
|
||||
for (int i = 0; i < blen; i++) {
|
||||
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
|
||||
error->all(FLERR,"Balance shift string is invalid");
|
||||
@ -336,7 +334,7 @@ void Balance::command(int narg, char **arg)
|
||||
|
||||
if (style == SHIFT) {
|
||||
comm->layout = Comm::LAYOUT_NONUNIFORM;
|
||||
shift_setup_static(bstr);
|
||||
shift_setup_static(bstr.c_str());
|
||||
niter = shift();
|
||||
}
|
||||
|
||||
@ -393,7 +391,7 @@ void Balance::command(int narg, char **arg)
|
||||
|
||||
// stats output
|
||||
|
||||
if (me == 0) {
|
||||
if (comm->me == 0) {
|
||||
std::string mesg = fmt::format(" rebalancing time: {:.3f} seconds\n",
|
||||
platform::walltime()-start_time);
|
||||
mesg += fmt::format(" iteration count = {}\n",niter);
|
||||
@ -571,7 +569,7 @@ double Balance::imbalance_factor(double &maxcost)
|
||||
MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world);
|
||||
|
||||
double imbalance = 1.0;
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost / comm->nprocs);
|
||||
return imbalance;
|
||||
}
|
||||
|
||||
@ -719,12 +717,12 @@ int *Balance::bisection()
|
||||
set rho = 0 for static balancing
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
void Balance::shift_setup_static(char *str)
|
||||
void Balance::shift_setup_static(const char *str)
|
||||
{
|
||||
shift_allocate = 1;
|
||||
|
||||
memory->create(proccost,nprocs,"balance:proccost");
|
||||
memory->create(allproccost,nprocs,"balance:allproccost");
|
||||
memory->create(proccost,comm->nprocs,"balance:proccost");
|
||||
memory->create(allproccost,comm->nprocs,"balance:allproccost");
|
||||
|
||||
ndim = strlen(str);
|
||||
bdim = new int[ndim];
|
||||
@ -771,7 +769,7 @@ void Balance::shift_setup_static(char *str)
|
||||
set rho = 1 to do dynamic balancing after call to shift_setup_static()
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
|
||||
void Balance::shift_setup(const char *str, int nitermax_in, double thresh_in)
|
||||
{
|
||||
shift_setup_static(str);
|
||||
nitermax = nitermax_in;
|
||||
@ -871,7 +869,7 @@ int Balance::shift()
|
||||
// iterate until balanced
|
||||
|
||||
#ifdef BALANCE_DEBUG
|
||||
if (me == 0) debug_shift_output(idim,0,np,split);
|
||||
if (comm->me == 0) debug_shift_output(idim,0,np,split);
|
||||
#endif
|
||||
|
||||
int doneflag;
|
||||
@ -882,7 +880,7 @@ int Balance::shift()
|
||||
niter++;
|
||||
|
||||
#ifdef BALANCE_DEBUG
|
||||
if (me == 0) debug_shift_output(idim,m+1,np,split);
|
||||
if (comm->me == 0) debug_shift_output(idim,m+1,np,split);
|
||||
if (outflag) dumpout(update->ntimestep);
|
||||
#endif
|
||||
|
||||
@ -1137,7 +1135,7 @@ double Balance::imbalance_splits()
|
||||
int ny = comm->procgrid[1];
|
||||
int nz = comm->procgrid[2];
|
||||
|
||||
for (int i = 0; i < nprocs; i++) proccost[i] = 0.0;
|
||||
for (int i = 0; i < comm->nprocs; i++) proccost[i] = 0.0;
|
||||
|
||||
double **x = atom->x;
|
||||
int nlocal = atom->nlocal;
|
||||
@ -1162,17 +1160,17 @@ double Balance::imbalance_splits()
|
||||
|
||||
// one proc's particles may map to many partitions, so must Allreduce
|
||||
|
||||
MPI_Allreduce(proccost,allproccost,nprocs,MPI_DOUBLE,MPI_SUM,world);
|
||||
MPI_Allreduce(proccost,allproccost,comm->nprocs,MPI_DOUBLE,MPI_SUM,world);
|
||||
|
||||
double maxcost = 0.0;
|
||||
double totalcost = 0.0;
|
||||
for (int i = 0; i < nprocs; i++) {
|
||||
for (int i = 0; i < comm->nprocs; i++) {
|
||||
maxcost = MAX(maxcost,allproccost[i]);
|
||||
totalcost += allproccost[i];
|
||||
}
|
||||
|
||||
double imbalance = 1.0;
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/comm->nprocs);
|
||||
return imbalance;
|
||||
}
|
||||
|
||||
@ -1188,6 +1186,7 @@ void Balance::dumpout(bigint tstep)
|
||||
{
|
||||
int dimension = domain->dimension;
|
||||
int triclinic = domain->triclinic;
|
||||
int nprocs = comm->nprocs;
|
||||
|
||||
// Allgather each proc's sub-box
|
||||
// could use Gather, but that requires MPI to alloc memory
|
||||
@ -1209,7 +1208,7 @@ void Balance::dumpout(bigint tstep)
|
||||
memory->create(boxall,nprocs,6,"balance:dumpout");
|
||||
MPI_Allgather(box,6,MPI_DOUBLE,&boxall[0][0],6,MPI_DOUBLE,world);
|
||||
|
||||
if (me) {
|
||||
if (comm->me) {
|
||||
memory->destroy(boxall);
|
||||
return;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user