refactor balance and fix balance to std::string to avoid buffer overflows

This commit is contained in:
Axel Kohlmeyer
2023-12-02 09:32:53 -05:00
parent 2970d73d22
commit 318556497f
4 changed files with 40 additions and 43 deletions

View File

@ -1,4 +1,3 @@
// clang-format off
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories https://www.lammps.org/, Sandia National Laboratories
@ -21,11 +20,10 @@
#include "balance.h" #include "balance.h"
#include "update.h"
#include "atom.h" #include "atom.h"
#include "neighbor.h"
#include "comm.h" #include "comm.h"
#include "domain.h" #include "domain.h"
#include "error.h"
#include "fix_store_atom.h" #include "fix_store_atom.h"
#include "force.h" #include "force.h"
#include "imbalance.h" #include "imbalance.h"
@ -37,9 +35,10 @@
#include "irregular.h" #include "irregular.h"
#include "memory.h" #include "memory.h"
#include "modify.h" #include "modify.h"
#include "neighbor.h"
#include "pair.h" #include "pair.h"
#include "rcb.h" #include "rcb.h"
#include "error.h" #include "update.h"
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
@ -48,17 +47,16 @@ using namespace LAMMPS_NS;
double EPSNEIGH = 1.0e-3; double EPSNEIGH = 1.0e-3;
enum{XYZ,SHIFT,BISECTION}; enum { XYZ, SHIFT, BISECTION };
enum{NONE,UNIFORM,USER}; enum { NONE, UNIFORM, USER };
enum{X,Y,Z}; enum { X, Y, Z };
// clang-format off
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
Balance::Balance(LAMMPS *lmp) : Command(lmp) Balance::Balance(LAMMPS *lmp) : Command(lmp)
{ {
MPI_Comm_rank(world,&me);
MPI_Comm_size(world,&nprocs);
user_xsplit = user_ysplit = user_zsplit = nullptr; user_xsplit = user_ysplit = user_zsplit = nullptr;
shift_allocate = 0; shift_allocate = 0;
proccost = allproccost = nullptr; proccost = allproccost = nullptr;
@ -118,7 +116,7 @@ void Balance::command(int narg, char **arg)
if (domain->box_exist == 0) if (domain->box_exist == 0)
error->all(FLERR,"Balance command before simulation box is defined"); error->all(FLERR,"Balance command before simulation box is defined");
if (me == 0) utils::logmesg(lmp,"Balancing ...\n"); if (comm->me == 0) utils::logmesg(lmp,"Balancing ...\n");
// parse required arguments // parse required arguments
@ -196,10 +194,10 @@ void Balance::command(int narg, char **arg)
} else if (strcmp(arg[iarg],"shift") == 0) { } else if (strcmp(arg[iarg],"shift") == 0) {
if (style != -1) error->all(FLERR,"Illegal balance command"); if (style != -1) error->all(FLERR,"Illegal balance command");
if (iarg+4 > narg) error->all(FLERR,"Illegal balance command"); if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "balance shift", error);
style = SHIFT; style = SHIFT;
if (strlen(arg[iarg+1]) > BSTR_SIZE) error->all(FLERR,"Illegal balance command"); bstr = arg[iarg+1];
strncpy(bstr,arg[iarg+1],BSTR_SIZE+1); if (bstr.size() > BSTR_SIZE) error->all(FLERR,"Illegal balance shift command");
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp); nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
if (nitermax <= 0) error->all(FLERR,"Illegal balance command"); if (nitermax <= 0) error->all(FLERR,"Illegal balance command");
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp); stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
@ -235,7 +233,7 @@ void Balance::command(int narg, char **arg)
} }
if (style == SHIFT) { if (style == SHIFT) {
const int blen=strlen(bstr); const int blen = bstr.size();
for (int i = 0; i < blen; i++) { for (int i = 0; i < blen; i++) {
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z') if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
error->all(FLERR,"Balance shift string is invalid"); error->all(FLERR,"Balance shift string is invalid");
@ -336,7 +334,7 @@ void Balance::command(int narg, char **arg)
if (style == SHIFT) { if (style == SHIFT) {
comm->layout = Comm::LAYOUT_NONUNIFORM; comm->layout = Comm::LAYOUT_NONUNIFORM;
shift_setup_static(bstr); shift_setup_static(bstr.c_str());
niter = shift(); niter = shift();
} }
@ -393,7 +391,7 @@ void Balance::command(int narg, char **arg)
// stats output // stats output
if (me == 0) { if (comm->me == 0) {
std::string mesg = fmt::format(" rebalancing time: {:.3f} seconds\n", std::string mesg = fmt::format(" rebalancing time: {:.3f} seconds\n",
platform::walltime()-start_time); platform::walltime()-start_time);
mesg += fmt::format(" iteration count = {}\n",niter); mesg += fmt::format(" iteration count = {}\n",niter);
@ -571,7 +569,7 @@ double Balance::imbalance_factor(double &maxcost)
MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world); MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world);
double imbalance = 1.0; double imbalance = 1.0;
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs); if (maxcost > 0.0) imbalance = maxcost / (totalcost / comm->nprocs);
return imbalance; return imbalance;
} }
@ -719,12 +717,12 @@ int *Balance::bisection()
set rho = 0 for static balancing set rho = 0 for static balancing
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void Balance::shift_setup_static(char *str) void Balance::shift_setup_static(const char *str)
{ {
shift_allocate = 1; shift_allocate = 1;
memory->create(proccost,nprocs,"balance:proccost"); memory->create(proccost,comm->nprocs,"balance:proccost");
memory->create(allproccost,nprocs,"balance:allproccost"); memory->create(allproccost,comm->nprocs,"balance:allproccost");
ndim = strlen(str); ndim = strlen(str);
bdim = new int[ndim]; bdim = new int[ndim];
@ -771,7 +769,7 @@ void Balance::shift_setup_static(char *str)
set rho = 1 to do dynamic balancing after call to shift_setup_static() set rho = 1 to do dynamic balancing after call to shift_setup_static()
------------------------------------------------------------------------- */ ------------------------------------------------------------------------- */
void Balance::shift_setup(char *str, int nitermax_in, double thresh_in) void Balance::shift_setup(const char *str, int nitermax_in, double thresh_in)
{ {
shift_setup_static(str); shift_setup_static(str);
nitermax = nitermax_in; nitermax = nitermax_in;
@ -871,7 +869,7 @@ int Balance::shift()
// iterate until balanced // iterate until balanced
#ifdef BALANCE_DEBUG #ifdef BALANCE_DEBUG
if (me == 0) debug_shift_output(idim,0,np,split); if (comm->me == 0) debug_shift_output(idim,0,np,split);
#endif #endif
int doneflag; int doneflag;
@ -882,7 +880,7 @@ int Balance::shift()
niter++; niter++;
#ifdef BALANCE_DEBUG #ifdef BALANCE_DEBUG
if (me == 0) debug_shift_output(idim,m+1,np,split); if (comm->me == 0) debug_shift_output(idim,m+1,np,split);
if (outflag) dumpout(update->ntimestep); if (outflag) dumpout(update->ntimestep);
#endif #endif
@ -1137,7 +1135,7 @@ double Balance::imbalance_splits()
int ny = comm->procgrid[1]; int ny = comm->procgrid[1];
int nz = comm->procgrid[2]; int nz = comm->procgrid[2];
for (int i = 0; i < nprocs; i++) proccost[i] = 0.0; for (int i = 0; i < comm->nprocs; i++) proccost[i] = 0.0;
double **x = atom->x; double **x = atom->x;
int nlocal = atom->nlocal; int nlocal = atom->nlocal;
@ -1162,17 +1160,17 @@ double Balance::imbalance_splits()
// one proc's particles may map to many partitions, so must Allreduce // one proc's particles may map to many partitions, so must Allreduce
MPI_Allreduce(proccost,allproccost,nprocs,MPI_DOUBLE,MPI_SUM,world); MPI_Allreduce(proccost,allproccost,comm->nprocs,MPI_DOUBLE,MPI_SUM,world);
double maxcost = 0.0; double maxcost = 0.0;
double totalcost = 0.0; double totalcost = 0.0;
for (int i = 0; i < nprocs; i++) { for (int i = 0; i < comm->nprocs; i++) {
maxcost = MAX(maxcost,allproccost[i]); maxcost = MAX(maxcost,allproccost[i]);
totalcost += allproccost[i]; totalcost += allproccost[i];
} }
double imbalance = 1.0; double imbalance = 1.0;
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs); if (maxcost > 0.0) imbalance = maxcost / (totalcost/comm->nprocs);
return imbalance; return imbalance;
} }
@ -1188,6 +1186,7 @@ void Balance::dumpout(bigint tstep)
{ {
int dimension = domain->dimension; int dimension = domain->dimension;
int triclinic = domain->triclinic; int triclinic = domain->triclinic;
int nprocs = comm->nprocs;
// Allgather each proc's sub-box // Allgather each proc's sub-box
// could use Gather, but that requires MPI to alloc memory // could use Gather, but that requires MPI to alloc memory
@ -1209,7 +1208,7 @@ void Balance::dumpout(bigint tstep)
memory->create(boxall,nprocs,6,"balance:dumpout"); memory->create(boxall,nprocs,6,"balance:dumpout");
MPI_Allgather(box,6,MPI_DOUBLE,&boxall[0][0],6,MPI_DOUBLE,world); MPI_Allgather(box,6,MPI_DOUBLE,&boxall[0][0],6,MPI_DOUBLE,world);
if (me) { if (comm->me) {
memory->destroy(boxall); memory->destroy(boxall);
return; return;
} }

View File

@ -41,7 +41,7 @@ class Balance : public Command {
void init_imbalance(int); void init_imbalance(int);
void set_weights(); void set_weights();
double imbalance_factor(double &); double imbalance_factor(double &);
void shift_setup(char *, int, double); void shift_setup(const char *, int, double);
int shift(); int shift();
int *bisection(); int *bisection();
void dumpout(bigint); void dumpout(bigint);
@ -49,8 +49,6 @@ class Balance : public Command {
static constexpr int BSTR_SIZE = 3; static constexpr int BSTR_SIZE = 3;
private: private:
int me, nprocs;
double thresh; // threshold to perform LB double thresh; // threshold to perform LB
int style; // style of LB int style; // style of LB
int xflag, yflag, zflag; // xyz LB flags int xflag, yflag, zflag; // xyz LB flags
@ -59,7 +57,7 @@ class Balance : public Command {
int nitermax; // params for shift LB int nitermax; // params for shift LB
double stopthresh; double stopthresh;
char bstr[BSTR_SIZE + 1]; std::string bstr;
int shift_allocate; // 1 if SHIFT vectors have been allocated int shift_allocate; // 1 if SHIFT vectors have been allocated
int ndim; // length of balance string bstr int ndim; // length of balance string bstr
@ -84,7 +82,7 @@ class Balance : public Command {
int firststep; int firststep;
double imbalance_splits(); double imbalance_splits();
void shift_setup_static(char *); void shift_setup_static(const char *);
void tally(int, int, double *); void tally(int, int, double *);
int adjust(int, double *); int adjust(int, double *);
#ifdef BALANCE_DEBUG #ifdef BALANCE_DEBUG

View File

@ -1,4 +1,3 @@
// clang-format off
/* ---------------------------------------------------------------------- /* ----------------------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories https://www.lammps.org/, Sandia National Laboratories
@ -34,7 +33,9 @@
using namespace LAMMPS_NS; using namespace LAMMPS_NS;
using namespace FixConst; using namespace FixConst;
enum{SHIFT,BISECTION}; enum { SHIFT, BISECTION };
// clang-format off
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
@ -66,10 +67,9 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
int iarg = 5; int iarg = 5;
if (lbstyle == SHIFT) { if (lbstyle == SHIFT) {
if (iarg+4 > narg) error->all(FLERR,"Illegal fix balance command"); if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "fix balance shift", error);
if (strlen(arg[iarg+1]) > Balance::BSTR_SIZE) bstr = arg[iarg+1];
error->all(FLERR,"Illegal fix balance command"); if (bstr.size() > Balance::BSTR_SIZE) error->all(FLERR,"Illegal fix balance shift command");
strncpy(bstr,arg[iarg+1], Balance::BSTR_SIZE+1);
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp); nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
if (nitermax <= 0) error->all(FLERR,"Illegal fix balance command"); if (nitermax <= 0) error->all(FLERR,"Illegal fix balance command");
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp); stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
@ -83,7 +83,7 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
// error checks // error checks
if (lbstyle == SHIFT) { if (lbstyle == SHIFT) {
int blen = strlen(bstr); int blen = bstr.size();
for (int i = 0; i < blen; i++) { for (int i = 0; i < blen; i++) {
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z') if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
error->all(FLERR,"Fix balance shift string is invalid"); error->all(FLERR,"Fix balance shift string is invalid");
@ -103,7 +103,7 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
// process remaining optional args via Balance // process remaining optional args via Balance
balance = new Balance(lmp); balance = new Balance(lmp);
if (lbstyle == SHIFT) balance->shift_setup(bstr,nitermax,thresh); if (lbstyle == SHIFT) balance->shift_setup(bstr.c_str(),nitermax,thresh);
balance->options(iarg,narg,arg,0); balance->options(iarg,narg,arg,0);
wtflag = balance->wtflag; wtflag = balance->wtflag;
sortflag = balance->sortflag; sortflag = balance->sortflag;

View File

@ -42,7 +42,7 @@ class FixBalance : public Fix {
private: private:
int nevery, lbstyle, nitermax; int nevery, lbstyle, nitermax;
double thresh, stopthresh; double thresh, stopthresh;
char bstr[4]; std::string bstr;
int wtflag; // 1 for weighted balancing int wtflag; // 1 for weighted balancing
int sortflag; // 1 for sorting comm messages int sortflag; // 1 for sorting comm messages