refactor balance and fix balance to std::string to avoid buffer overflows
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
// clang-format off
|
||||
/* ----------------------------------------------------------------------
|
||||
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
|
||||
https://www.lammps.org/, Sandia National Laboratories
|
||||
@ -21,11 +20,10 @@
|
||||
|
||||
#include "balance.h"
|
||||
|
||||
#include "update.h"
|
||||
#include "atom.h"
|
||||
#include "neighbor.h"
|
||||
#include "comm.h"
|
||||
#include "domain.h"
|
||||
#include "error.h"
|
||||
#include "fix_store_atom.h"
|
||||
#include "force.h"
|
||||
#include "imbalance.h"
|
||||
@ -37,9 +35,10 @@
|
||||
#include "irregular.h"
|
||||
#include "memory.h"
|
||||
#include "modify.h"
|
||||
#include "neighbor.h"
|
||||
#include "pair.h"
|
||||
#include "rcb.h"
|
||||
#include "error.h"
|
||||
#include "update.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
@ -48,17 +47,16 @@ using namespace LAMMPS_NS;
|
||||
|
||||
double EPSNEIGH = 1.0e-3;
|
||||
|
||||
enum{XYZ,SHIFT,BISECTION};
|
||||
enum{NONE,UNIFORM,USER};
|
||||
enum{X,Y,Z};
|
||||
enum { XYZ, SHIFT, BISECTION };
|
||||
enum { NONE, UNIFORM, USER };
|
||||
enum { X, Y, Z };
|
||||
|
||||
// clang-format off
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
Balance::Balance(LAMMPS *lmp) : Command(lmp)
|
||||
{
|
||||
MPI_Comm_rank(world,&me);
|
||||
MPI_Comm_size(world,&nprocs);
|
||||
|
||||
user_xsplit = user_ysplit = user_zsplit = nullptr;
|
||||
shift_allocate = 0;
|
||||
proccost = allproccost = nullptr;
|
||||
@ -118,7 +116,7 @@ void Balance::command(int narg, char **arg)
|
||||
if (domain->box_exist == 0)
|
||||
error->all(FLERR,"Balance command before simulation box is defined");
|
||||
|
||||
if (me == 0) utils::logmesg(lmp,"Balancing ...\n");
|
||||
if (comm->me == 0) utils::logmesg(lmp,"Balancing ...\n");
|
||||
|
||||
// parse required arguments
|
||||
|
||||
@ -196,10 +194,10 @@ void Balance::command(int narg, char **arg)
|
||||
|
||||
} else if (strcmp(arg[iarg],"shift") == 0) {
|
||||
if (style != -1) error->all(FLERR,"Illegal balance command");
|
||||
if (iarg+4 > narg) error->all(FLERR,"Illegal balance command");
|
||||
if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "balance shift", error);
|
||||
style = SHIFT;
|
||||
if (strlen(arg[iarg+1]) > BSTR_SIZE) error->all(FLERR,"Illegal balance command");
|
||||
strncpy(bstr,arg[iarg+1],BSTR_SIZE+1);
|
||||
bstr = arg[iarg+1];
|
||||
if (bstr.size() > BSTR_SIZE) error->all(FLERR,"Illegal balance shift command");
|
||||
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
|
||||
if (nitermax <= 0) error->all(FLERR,"Illegal balance command");
|
||||
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
|
||||
@ -235,7 +233,7 @@ void Balance::command(int narg, char **arg)
|
||||
}
|
||||
|
||||
if (style == SHIFT) {
|
||||
const int blen=strlen(bstr);
|
||||
const int blen = bstr.size();
|
||||
for (int i = 0; i < blen; i++) {
|
||||
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
|
||||
error->all(FLERR,"Balance shift string is invalid");
|
||||
@ -336,7 +334,7 @@ void Balance::command(int narg, char **arg)
|
||||
|
||||
if (style == SHIFT) {
|
||||
comm->layout = Comm::LAYOUT_NONUNIFORM;
|
||||
shift_setup_static(bstr);
|
||||
shift_setup_static(bstr.c_str());
|
||||
niter = shift();
|
||||
}
|
||||
|
||||
@ -393,7 +391,7 @@ void Balance::command(int narg, char **arg)
|
||||
|
||||
// stats output
|
||||
|
||||
if (me == 0) {
|
||||
if (comm->me == 0) {
|
||||
std::string mesg = fmt::format(" rebalancing time: {:.3f} seconds\n",
|
||||
platform::walltime()-start_time);
|
||||
mesg += fmt::format(" iteration count = {}\n",niter);
|
||||
@ -571,7 +569,7 @@ double Balance::imbalance_factor(double &maxcost)
|
||||
MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world);
|
||||
|
||||
double imbalance = 1.0;
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost / comm->nprocs);
|
||||
return imbalance;
|
||||
}
|
||||
|
||||
@ -719,12 +717,12 @@ int *Balance::bisection()
|
||||
set rho = 0 for static balancing
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
void Balance::shift_setup_static(char *str)
|
||||
void Balance::shift_setup_static(const char *str)
|
||||
{
|
||||
shift_allocate = 1;
|
||||
|
||||
memory->create(proccost,nprocs,"balance:proccost");
|
||||
memory->create(allproccost,nprocs,"balance:allproccost");
|
||||
memory->create(proccost,comm->nprocs,"balance:proccost");
|
||||
memory->create(allproccost,comm->nprocs,"balance:allproccost");
|
||||
|
||||
ndim = strlen(str);
|
||||
bdim = new int[ndim];
|
||||
@ -771,7 +769,7 @@ void Balance::shift_setup_static(char *str)
|
||||
set rho = 1 to do dynamic balancing after call to shift_setup_static()
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
|
||||
void Balance::shift_setup(const char *str, int nitermax_in, double thresh_in)
|
||||
{
|
||||
shift_setup_static(str);
|
||||
nitermax = nitermax_in;
|
||||
@ -871,7 +869,7 @@ int Balance::shift()
|
||||
// iterate until balanced
|
||||
|
||||
#ifdef BALANCE_DEBUG
|
||||
if (me == 0) debug_shift_output(idim,0,np,split);
|
||||
if (comm->me == 0) debug_shift_output(idim,0,np,split);
|
||||
#endif
|
||||
|
||||
int doneflag;
|
||||
@ -882,7 +880,7 @@ int Balance::shift()
|
||||
niter++;
|
||||
|
||||
#ifdef BALANCE_DEBUG
|
||||
if (me == 0) debug_shift_output(idim,m+1,np,split);
|
||||
if (comm->me == 0) debug_shift_output(idim,m+1,np,split);
|
||||
if (outflag) dumpout(update->ntimestep);
|
||||
#endif
|
||||
|
||||
@ -1137,7 +1135,7 @@ double Balance::imbalance_splits()
|
||||
int ny = comm->procgrid[1];
|
||||
int nz = comm->procgrid[2];
|
||||
|
||||
for (int i = 0; i < nprocs; i++) proccost[i] = 0.0;
|
||||
for (int i = 0; i < comm->nprocs; i++) proccost[i] = 0.0;
|
||||
|
||||
double **x = atom->x;
|
||||
int nlocal = atom->nlocal;
|
||||
@ -1162,17 +1160,17 @@ double Balance::imbalance_splits()
|
||||
|
||||
// one proc's particles may map to many partitions, so must Allreduce
|
||||
|
||||
MPI_Allreduce(proccost,allproccost,nprocs,MPI_DOUBLE,MPI_SUM,world);
|
||||
MPI_Allreduce(proccost,allproccost,comm->nprocs,MPI_DOUBLE,MPI_SUM,world);
|
||||
|
||||
double maxcost = 0.0;
|
||||
double totalcost = 0.0;
|
||||
for (int i = 0; i < nprocs; i++) {
|
||||
for (int i = 0; i < comm->nprocs; i++) {
|
||||
maxcost = MAX(maxcost,allproccost[i]);
|
||||
totalcost += allproccost[i];
|
||||
}
|
||||
|
||||
double imbalance = 1.0;
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
|
||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/comm->nprocs);
|
||||
return imbalance;
|
||||
}
|
||||
|
||||
@ -1188,6 +1186,7 @@ void Balance::dumpout(bigint tstep)
|
||||
{
|
||||
int dimension = domain->dimension;
|
||||
int triclinic = domain->triclinic;
|
||||
int nprocs = comm->nprocs;
|
||||
|
||||
// Allgather each proc's sub-box
|
||||
// could use Gather, but that requires MPI to alloc memory
|
||||
@ -1209,7 +1208,7 @@ void Balance::dumpout(bigint tstep)
|
||||
memory->create(boxall,nprocs,6,"balance:dumpout");
|
||||
MPI_Allgather(box,6,MPI_DOUBLE,&boxall[0][0],6,MPI_DOUBLE,world);
|
||||
|
||||
if (me) {
|
||||
if (comm->me) {
|
||||
memory->destroy(boxall);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ class Balance : public Command {
|
||||
void init_imbalance(int);
|
||||
void set_weights();
|
||||
double imbalance_factor(double &);
|
||||
void shift_setup(char *, int, double);
|
||||
void shift_setup(const char *, int, double);
|
||||
int shift();
|
||||
int *bisection();
|
||||
void dumpout(bigint);
|
||||
@ -49,8 +49,6 @@ class Balance : public Command {
|
||||
static constexpr int BSTR_SIZE = 3;
|
||||
|
||||
private:
|
||||
int me, nprocs;
|
||||
|
||||
double thresh; // threshold to perform LB
|
||||
int style; // style of LB
|
||||
int xflag, yflag, zflag; // xyz LB flags
|
||||
@ -59,7 +57,7 @@ class Balance : public Command {
|
||||
|
||||
int nitermax; // params for shift LB
|
||||
double stopthresh;
|
||||
char bstr[BSTR_SIZE + 1];
|
||||
std::string bstr;
|
||||
|
||||
int shift_allocate; // 1 if SHIFT vectors have been allocated
|
||||
int ndim; // length of balance string bstr
|
||||
@ -84,7 +82,7 @@ class Balance : public Command {
|
||||
int firststep;
|
||||
|
||||
double imbalance_splits();
|
||||
void shift_setup_static(char *);
|
||||
void shift_setup_static(const char *);
|
||||
void tally(int, int, double *);
|
||||
int adjust(int, double *);
|
||||
#ifdef BALANCE_DEBUG
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
// clang-format off
|
||||
/* ----------------------------------------------------------------------
|
||||
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
|
||||
https://www.lammps.org/, Sandia National Laboratories
|
||||
@ -34,7 +33,9 @@
|
||||
using namespace LAMMPS_NS;
|
||||
using namespace FixConst;
|
||||
|
||||
enum{SHIFT,BISECTION};
|
||||
enum { SHIFT, BISECTION };
|
||||
|
||||
// clang-format off
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
@ -66,10 +67,9 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
|
||||
|
||||
int iarg = 5;
|
||||
if (lbstyle == SHIFT) {
|
||||
if (iarg+4 > narg) error->all(FLERR,"Illegal fix balance command");
|
||||
if (strlen(arg[iarg+1]) > Balance::BSTR_SIZE)
|
||||
error->all(FLERR,"Illegal fix balance command");
|
||||
strncpy(bstr,arg[iarg+1], Balance::BSTR_SIZE+1);
|
||||
if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "fix balance shift", error);
|
||||
bstr = arg[iarg+1];
|
||||
if (bstr.size() > Balance::BSTR_SIZE) error->all(FLERR,"Illegal fix balance shift command");
|
||||
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
|
||||
if (nitermax <= 0) error->all(FLERR,"Illegal fix balance command");
|
||||
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
|
||||
@ -83,7 +83,7 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
|
||||
// error checks
|
||||
|
||||
if (lbstyle == SHIFT) {
|
||||
int blen = strlen(bstr);
|
||||
int blen = bstr.size();
|
||||
for (int i = 0; i < blen; i++) {
|
||||
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
|
||||
error->all(FLERR,"Fix balance shift string is invalid");
|
||||
@ -103,7 +103,7 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
|
||||
// process remaining optional args via Balance
|
||||
|
||||
balance = new Balance(lmp);
|
||||
if (lbstyle == SHIFT) balance->shift_setup(bstr,nitermax,thresh);
|
||||
if (lbstyle == SHIFT) balance->shift_setup(bstr.c_str(),nitermax,thresh);
|
||||
balance->options(iarg,narg,arg,0);
|
||||
wtflag = balance->wtflag;
|
||||
sortflag = balance->sortflag;
|
||||
|
||||
@ -42,7 +42,7 @@ class FixBalance : public Fix {
|
||||
private:
|
||||
int nevery, lbstyle, nitermax;
|
||||
double thresh, stopthresh;
|
||||
char bstr[4];
|
||||
std::string bstr;
|
||||
int wtflag; // 1 for weighted balancing
|
||||
int sortflag; // 1 for sorting comm messages
|
||||
|
||||
|
||||
Reference in New Issue
Block a user