refactor balance and fix balance to std::string to avoid buffer overflows

This commit is contained in:
Axel Kohlmeyer
2023-12-02 09:32:53 -05:00
parent 2970d73d22
commit 318556497f
4 changed files with 40 additions and 43 deletions

View File

@ -1,4 +1,3 @@
// clang-format off
/* ----------------------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories
@ -21,11 +20,10 @@
#include "balance.h"
#include "update.h"
#include "atom.h"
#include "neighbor.h"
#include "comm.h"
#include "domain.h"
#include "error.h"
#include "fix_store_atom.h"
#include "force.h"
#include "imbalance.h"
@ -37,9 +35,10 @@
#include "irregular.h"
#include "memory.h"
#include "modify.h"
#include "neighbor.h"
#include "pair.h"
#include "rcb.h"
#include "error.h"
#include "update.h"
#include <cmath>
#include <cstring>
@ -48,17 +47,16 @@ using namespace LAMMPS_NS;
double EPSNEIGH = 1.0e-3;
enum{XYZ,SHIFT,BISECTION};
enum{NONE,UNIFORM,USER};
enum{X,Y,Z};
enum { XYZ, SHIFT, BISECTION };
enum { NONE, UNIFORM, USER };
enum { X, Y, Z };
// clang-format off
/* ---------------------------------------------------------------------- */
Balance::Balance(LAMMPS *lmp) : Command(lmp)
{
MPI_Comm_rank(world,&me);
MPI_Comm_size(world,&nprocs);
user_xsplit = user_ysplit = user_zsplit = nullptr;
shift_allocate = 0;
proccost = allproccost = nullptr;
@ -118,7 +116,7 @@ void Balance::command(int narg, char **arg)
if (domain->box_exist == 0)
error->all(FLERR,"Balance command before simulation box is defined");
if (me == 0) utils::logmesg(lmp,"Balancing ...\n");
if (comm->me == 0) utils::logmesg(lmp,"Balancing ...\n");
// parse required arguments
@ -196,10 +194,10 @@ void Balance::command(int narg, char **arg)
} else if (strcmp(arg[iarg],"shift") == 0) {
if (style != -1) error->all(FLERR,"Illegal balance command");
if (iarg+4 > narg) error->all(FLERR,"Illegal balance command");
if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "balance shift", error);
style = SHIFT;
if (strlen(arg[iarg+1]) > BSTR_SIZE) error->all(FLERR,"Illegal balance command");
strncpy(bstr,arg[iarg+1],BSTR_SIZE+1);
bstr = arg[iarg+1];
if (bstr.size() > BSTR_SIZE) error->all(FLERR,"Illegal balance shift command");
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
if (nitermax <= 0) error->all(FLERR,"Illegal balance command");
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
@ -235,7 +233,7 @@ void Balance::command(int narg, char **arg)
}
if (style == SHIFT) {
const int blen=strlen(bstr);
const int blen = bstr.size();
for (int i = 0; i < blen; i++) {
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
error->all(FLERR,"Balance shift string is invalid");
@ -336,7 +334,7 @@ void Balance::command(int narg, char **arg)
if (style == SHIFT) {
comm->layout = Comm::LAYOUT_NONUNIFORM;
shift_setup_static(bstr);
shift_setup_static(bstr.c_str());
niter = shift();
}
@ -393,7 +391,7 @@ void Balance::command(int narg, char **arg)
// stats output
if (me == 0) {
if (comm->me == 0) {
std::string mesg = fmt::format(" rebalancing time: {:.3f} seconds\n",
platform::walltime()-start_time);
mesg += fmt::format(" iteration count = {}\n",niter);
@ -571,7 +569,7 @@ double Balance::imbalance_factor(double &maxcost)
MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world);
double imbalance = 1.0;
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
if (maxcost > 0.0) imbalance = maxcost / (totalcost / comm->nprocs);
return imbalance;
}
@ -719,12 +717,12 @@ int *Balance::bisection()
set rho = 0 for static balancing
------------------------------------------------------------------------- */
void Balance::shift_setup_static(char *str)
void Balance::shift_setup_static(const char *str)
{
shift_allocate = 1;
memory->create(proccost,nprocs,"balance:proccost");
memory->create(allproccost,nprocs,"balance:allproccost");
memory->create(proccost,comm->nprocs,"balance:proccost");
memory->create(allproccost,comm->nprocs,"balance:allproccost");
ndim = strlen(str);
bdim = new int[ndim];
@ -771,7 +769,7 @@ void Balance::shift_setup_static(char *str)
set rho = 1 to do dynamic balancing after call to shift_setup_static()
------------------------------------------------------------------------- */
void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
void Balance::shift_setup(const char *str, int nitermax_in, double thresh_in)
{
shift_setup_static(str);
nitermax = nitermax_in;
@ -871,7 +869,7 @@ int Balance::shift()
// iterate until balanced
#ifdef BALANCE_DEBUG
if (me == 0) debug_shift_output(idim,0,np,split);
if (comm->me == 0) debug_shift_output(idim,0,np,split);
#endif
int doneflag;
@ -882,7 +880,7 @@ int Balance::shift()
niter++;
#ifdef BALANCE_DEBUG
if (me == 0) debug_shift_output(idim,m+1,np,split);
if (comm->me == 0) debug_shift_output(idim,m+1,np,split);
if (outflag) dumpout(update->ntimestep);
#endif
@ -1137,7 +1135,7 @@ double Balance::imbalance_splits()
int ny = comm->procgrid[1];
int nz = comm->procgrid[2];
for (int i = 0; i < nprocs; i++) proccost[i] = 0.0;
for (int i = 0; i < comm->nprocs; i++) proccost[i] = 0.0;
double **x = atom->x;
int nlocal = atom->nlocal;
@ -1162,17 +1160,17 @@ double Balance::imbalance_splits()
// one proc's particles may map to many partitions, so must Allreduce
MPI_Allreduce(proccost,allproccost,nprocs,MPI_DOUBLE,MPI_SUM,world);
MPI_Allreduce(proccost,allproccost,comm->nprocs,MPI_DOUBLE,MPI_SUM,world);
double maxcost = 0.0;
double totalcost = 0.0;
for (int i = 0; i < nprocs; i++) {
for (int i = 0; i < comm->nprocs; i++) {
maxcost = MAX(maxcost,allproccost[i]);
totalcost += allproccost[i];
}
double imbalance = 1.0;
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
if (maxcost > 0.0) imbalance = maxcost / (totalcost/comm->nprocs);
return imbalance;
}
@ -1188,6 +1186,7 @@ void Balance::dumpout(bigint tstep)
{
int dimension = domain->dimension;
int triclinic = domain->triclinic;
int nprocs = comm->nprocs;
// Allgather each proc's sub-box
// could use Gather, but that requires MPI to alloc memory
@ -1209,7 +1208,7 @@ void Balance::dumpout(bigint tstep)
memory->create(boxall,nprocs,6,"balance:dumpout");
MPI_Allgather(box,6,MPI_DOUBLE,&boxall[0][0],6,MPI_DOUBLE,world);
if (me) {
if (comm->me) {
memory->destroy(boxall);
return;
}