refactor balance and fix balance to std::string to avoid buffer overflows
This commit is contained in:
@ -1,4 +1,3 @@
|
|||||||
// clang-format off
|
|
||||||
/* ----------------------------------------------------------------------
|
/* ----------------------------------------------------------------------
|
||||||
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
|
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
|
||||||
https://www.lammps.org/, Sandia National Laboratories
|
https://www.lammps.org/, Sandia National Laboratories
|
||||||
@ -21,11 +20,10 @@
|
|||||||
|
|
||||||
#include "balance.h"
|
#include "balance.h"
|
||||||
|
|
||||||
#include "update.h"
|
|
||||||
#include "atom.h"
|
#include "atom.h"
|
||||||
#include "neighbor.h"
|
|
||||||
#include "comm.h"
|
#include "comm.h"
|
||||||
#include "domain.h"
|
#include "domain.h"
|
||||||
|
#include "error.h"
|
||||||
#include "fix_store_atom.h"
|
#include "fix_store_atom.h"
|
||||||
#include "force.h"
|
#include "force.h"
|
||||||
#include "imbalance.h"
|
#include "imbalance.h"
|
||||||
@ -37,9 +35,10 @@
|
|||||||
#include "irregular.h"
|
#include "irregular.h"
|
||||||
#include "memory.h"
|
#include "memory.h"
|
||||||
#include "modify.h"
|
#include "modify.h"
|
||||||
|
#include "neighbor.h"
|
||||||
#include "pair.h"
|
#include "pair.h"
|
||||||
#include "rcb.h"
|
#include "rcb.h"
|
||||||
#include "error.h"
|
#include "update.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
@ -48,17 +47,16 @@ using namespace LAMMPS_NS;
|
|||||||
|
|
||||||
double EPSNEIGH = 1.0e-3;
|
double EPSNEIGH = 1.0e-3;
|
||||||
|
|
||||||
enum{XYZ,SHIFT,BISECTION};
|
enum { XYZ, SHIFT, BISECTION };
|
||||||
enum{NONE,UNIFORM,USER};
|
enum { NONE, UNIFORM, USER };
|
||||||
enum{X,Y,Z};
|
enum { X, Y, Z };
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
|
||||||
/* ---------------------------------------------------------------------- */
|
/* ---------------------------------------------------------------------- */
|
||||||
|
|
||||||
Balance::Balance(LAMMPS *lmp) : Command(lmp)
|
Balance::Balance(LAMMPS *lmp) : Command(lmp)
|
||||||
{
|
{
|
||||||
MPI_Comm_rank(world,&me);
|
|
||||||
MPI_Comm_size(world,&nprocs);
|
|
||||||
|
|
||||||
user_xsplit = user_ysplit = user_zsplit = nullptr;
|
user_xsplit = user_ysplit = user_zsplit = nullptr;
|
||||||
shift_allocate = 0;
|
shift_allocate = 0;
|
||||||
proccost = allproccost = nullptr;
|
proccost = allproccost = nullptr;
|
||||||
@ -118,7 +116,7 @@ void Balance::command(int narg, char **arg)
|
|||||||
if (domain->box_exist == 0)
|
if (domain->box_exist == 0)
|
||||||
error->all(FLERR,"Balance command before simulation box is defined");
|
error->all(FLERR,"Balance command before simulation box is defined");
|
||||||
|
|
||||||
if (me == 0) utils::logmesg(lmp,"Balancing ...\n");
|
if (comm->me == 0) utils::logmesg(lmp,"Balancing ...\n");
|
||||||
|
|
||||||
// parse required arguments
|
// parse required arguments
|
||||||
|
|
||||||
@ -196,10 +194,10 @@ void Balance::command(int narg, char **arg)
|
|||||||
|
|
||||||
} else if (strcmp(arg[iarg],"shift") == 0) {
|
} else if (strcmp(arg[iarg],"shift") == 0) {
|
||||||
if (style != -1) error->all(FLERR,"Illegal balance command");
|
if (style != -1) error->all(FLERR,"Illegal balance command");
|
||||||
if (iarg+4 > narg) error->all(FLERR,"Illegal balance command");
|
if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "balance shift", error);
|
||||||
style = SHIFT;
|
style = SHIFT;
|
||||||
if (strlen(arg[iarg+1]) > BSTR_SIZE) error->all(FLERR,"Illegal balance command");
|
bstr = arg[iarg+1];
|
||||||
strncpy(bstr,arg[iarg+1],BSTR_SIZE+1);
|
if (bstr.size() > BSTR_SIZE) error->all(FLERR,"Illegal balance shift command");
|
||||||
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
|
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
|
||||||
if (nitermax <= 0) error->all(FLERR,"Illegal balance command");
|
if (nitermax <= 0) error->all(FLERR,"Illegal balance command");
|
||||||
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
|
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
|
||||||
@ -235,7 +233,7 @@ void Balance::command(int narg, char **arg)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (style == SHIFT) {
|
if (style == SHIFT) {
|
||||||
const int blen=strlen(bstr);
|
const int blen = bstr.size();
|
||||||
for (int i = 0; i < blen; i++) {
|
for (int i = 0; i < blen; i++) {
|
||||||
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
|
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
|
||||||
error->all(FLERR,"Balance shift string is invalid");
|
error->all(FLERR,"Balance shift string is invalid");
|
||||||
@ -336,7 +334,7 @@ void Balance::command(int narg, char **arg)
|
|||||||
|
|
||||||
if (style == SHIFT) {
|
if (style == SHIFT) {
|
||||||
comm->layout = Comm::LAYOUT_NONUNIFORM;
|
comm->layout = Comm::LAYOUT_NONUNIFORM;
|
||||||
shift_setup_static(bstr);
|
shift_setup_static(bstr.c_str());
|
||||||
niter = shift();
|
niter = shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -393,7 +391,7 @@ void Balance::command(int narg, char **arg)
|
|||||||
|
|
||||||
// stats output
|
// stats output
|
||||||
|
|
||||||
if (me == 0) {
|
if (comm->me == 0) {
|
||||||
std::string mesg = fmt::format(" rebalancing time: {:.3f} seconds\n",
|
std::string mesg = fmt::format(" rebalancing time: {:.3f} seconds\n",
|
||||||
platform::walltime()-start_time);
|
platform::walltime()-start_time);
|
||||||
mesg += fmt::format(" iteration count = {}\n",niter);
|
mesg += fmt::format(" iteration count = {}\n",niter);
|
||||||
@ -571,7 +569,7 @@ double Balance::imbalance_factor(double &maxcost)
|
|||||||
MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world);
|
MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world);
|
||||||
|
|
||||||
double imbalance = 1.0;
|
double imbalance = 1.0;
|
||||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
|
if (maxcost > 0.0) imbalance = maxcost / (totalcost / comm->nprocs);
|
||||||
return imbalance;
|
return imbalance;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -719,12 +717,12 @@ int *Balance::bisection()
|
|||||||
set rho = 0 for static balancing
|
set rho = 0 for static balancing
|
||||||
------------------------------------------------------------------------- */
|
------------------------------------------------------------------------- */
|
||||||
|
|
||||||
void Balance::shift_setup_static(char *str)
|
void Balance::shift_setup_static(const char *str)
|
||||||
{
|
{
|
||||||
shift_allocate = 1;
|
shift_allocate = 1;
|
||||||
|
|
||||||
memory->create(proccost,nprocs,"balance:proccost");
|
memory->create(proccost,comm->nprocs,"balance:proccost");
|
||||||
memory->create(allproccost,nprocs,"balance:allproccost");
|
memory->create(allproccost,comm->nprocs,"balance:allproccost");
|
||||||
|
|
||||||
ndim = strlen(str);
|
ndim = strlen(str);
|
||||||
bdim = new int[ndim];
|
bdim = new int[ndim];
|
||||||
@ -771,7 +769,7 @@ void Balance::shift_setup_static(char *str)
|
|||||||
set rho = 1 to do dynamic balancing after call to shift_setup_static()
|
set rho = 1 to do dynamic balancing after call to shift_setup_static()
|
||||||
------------------------------------------------------------------------- */
|
------------------------------------------------------------------------- */
|
||||||
|
|
||||||
void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
|
void Balance::shift_setup(const char *str, int nitermax_in, double thresh_in)
|
||||||
{
|
{
|
||||||
shift_setup_static(str);
|
shift_setup_static(str);
|
||||||
nitermax = nitermax_in;
|
nitermax = nitermax_in;
|
||||||
@ -871,7 +869,7 @@ int Balance::shift()
|
|||||||
// iterate until balanced
|
// iterate until balanced
|
||||||
|
|
||||||
#ifdef BALANCE_DEBUG
|
#ifdef BALANCE_DEBUG
|
||||||
if (me == 0) debug_shift_output(idim,0,np,split);
|
if (comm->me == 0) debug_shift_output(idim,0,np,split);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int doneflag;
|
int doneflag;
|
||||||
@ -882,7 +880,7 @@ int Balance::shift()
|
|||||||
niter++;
|
niter++;
|
||||||
|
|
||||||
#ifdef BALANCE_DEBUG
|
#ifdef BALANCE_DEBUG
|
||||||
if (me == 0) debug_shift_output(idim,m+1,np,split);
|
if (comm->me == 0) debug_shift_output(idim,m+1,np,split);
|
||||||
if (outflag) dumpout(update->ntimestep);
|
if (outflag) dumpout(update->ntimestep);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -1137,7 +1135,7 @@ double Balance::imbalance_splits()
|
|||||||
int ny = comm->procgrid[1];
|
int ny = comm->procgrid[1];
|
||||||
int nz = comm->procgrid[2];
|
int nz = comm->procgrid[2];
|
||||||
|
|
||||||
for (int i = 0; i < nprocs; i++) proccost[i] = 0.0;
|
for (int i = 0; i < comm->nprocs; i++) proccost[i] = 0.0;
|
||||||
|
|
||||||
double **x = atom->x;
|
double **x = atom->x;
|
||||||
int nlocal = atom->nlocal;
|
int nlocal = atom->nlocal;
|
||||||
@ -1162,17 +1160,17 @@ double Balance::imbalance_splits()
|
|||||||
|
|
||||||
// one proc's particles may map to many partitions, so must Allreduce
|
// one proc's particles may map to many partitions, so must Allreduce
|
||||||
|
|
||||||
MPI_Allreduce(proccost,allproccost,nprocs,MPI_DOUBLE,MPI_SUM,world);
|
MPI_Allreduce(proccost,allproccost,comm->nprocs,MPI_DOUBLE,MPI_SUM,world);
|
||||||
|
|
||||||
double maxcost = 0.0;
|
double maxcost = 0.0;
|
||||||
double totalcost = 0.0;
|
double totalcost = 0.0;
|
||||||
for (int i = 0; i < nprocs; i++) {
|
for (int i = 0; i < comm->nprocs; i++) {
|
||||||
maxcost = MAX(maxcost,allproccost[i]);
|
maxcost = MAX(maxcost,allproccost[i]);
|
||||||
totalcost += allproccost[i];
|
totalcost += allproccost[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
double imbalance = 1.0;
|
double imbalance = 1.0;
|
||||||
if (maxcost > 0.0) imbalance = maxcost / (totalcost/nprocs);
|
if (maxcost > 0.0) imbalance = maxcost / (totalcost/comm->nprocs);
|
||||||
return imbalance;
|
return imbalance;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1188,6 +1186,7 @@ void Balance::dumpout(bigint tstep)
|
|||||||
{
|
{
|
||||||
int dimension = domain->dimension;
|
int dimension = domain->dimension;
|
||||||
int triclinic = domain->triclinic;
|
int triclinic = domain->triclinic;
|
||||||
|
int nprocs = comm->nprocs;
|
||||||
|
|
||||||
// Allgather each proc's sub-box
|
// Allgather each proc's sub-box
|
||||||
// could use Gather, but that requires MPI to alloc memory
|
// could use Gather, but that requires MPI to alloc memory
|
||||||
@ -1209,7 +1208,7 @@ void Balance::dumpout(bigint tstep)
|
|||||||
memory->create(boxall,nprocs,6,"balance:dumpout");
|
memory->create(boxall,nprocs,6,"balance:dumpout");
|
||||||
MPI_Allgather(box,6,MPI_DOUBLE,&boxall[0][0],6,MPI_DOUBLE,world);
|
MPI_Allgather(box,6,MPI_DOUBLE,&boxall[0][0],6,MPI_DOUBLE,world);
|
||||||
|
|
||||||
if (me) {
|
if (comm->me) {
|
||||||
memory->destroy(boxall);
|
memory->destroy(boxall);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -41,7 +41,7 @@ class Balance : public Command {
|
|||||||
void init_imbalance(int);
|
void init_imbalance(int);
|
||||||
void set_weights();
|
void set_weights();
|
||||||
double imbalance_factor(double &);
|
double imbalance_factor(double &);
|
||||||
void shift_setup(char *, int, double);
|
void shift_setup(const char *, int, double);
|
||||||
int shift();
|
int shift();
|
||||||
int *bisection();
|
int *bisection();
|
||||||
void dumpout(bigint);
|
void dumpout(bigint);
|
||||||
@ -49,8 +49,6 @@ class Balance : public Command {
|
|||||||
static constexpr int BSTR_SIZE = 3;
|
static constexpr int BSTR_SIZE = 3;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int me, nprocs;
|
|
||||||
|
|
||||||
double thresh; // threshold to perform LB
|
double thresh; // threshold to perform LB
|
||||||
int style; // style of LB
|
int style; // style of LB
|
||||||
int xflag, yflag, zflag; // xyz LB flags
|
int xflag, yflag, zflag; // xyz LB flags
|
||||||
@ -59,7 +57,7 @@ class Balance : public Command {
|
|||||||
|
|
||||||
int nitermax; // params for shift LB
|
int nitermax; // params for shift LB
|
||||||
double stopthresh;
|
double stopthresh;
|
||||||
char bstr[BSTR_SIZE + 1];
|
std::string bstr;
|
||||||
|
|
||||||
int shift_allocate; // 1 if SHIFT vectors have been allocated
|
int shift_allocate; // 1 if SHIFT vectors have been allocated
|
||||||
int ndim; // length of balance string bstr
|
int ndim; // length of balance string bstr
|
||||||
@ -84,7 +82,7 @@ class Balance : public Command {
|
|||||||
int firststep;
|
int firststep;
|
||||||
|
|
||||||
double imbalance_splits();
|
double imbalance_splits();
|
||||||
void shift_setup_static(char *);
|
void shift_setup_static(const char *);
|
||||||
void tally(int, int, double *);
|
void tally(int, int, double *);
|
||||||
int adjust(int, double *);
|
int adjust(int, double *);
|
||||||
#ifdef BALANCE_DEBUG
|
#ifdef BALANCE_DEBUG
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
// clang-format off
|
|
||||||
/* ----------------------------------------------------------------------
|
/* ----------------------------------------------------------------------
|
||||||
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
|
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
|
||||||
https://www.lammps.org/, Sandia National Laboratories
|
https://www.lammps.org/, Sandia National Laboratories
|
||||||
@ -34,7 +33,9 @@
|
|||||||
using namespace LAMMPS_NS;
|
using namespace LAMMPS_NS;
|
||||||
using namespace FixConst;
|
using namespace FixConst;
|
||||||
|
|
||||||
enum{SHIFT,BISECTION};
|
enum { SHIFT, BISECTION };
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
|
||||||
/* ---------------------------------------------------------------------- */
|
/* ---------------------------------------------------------------------- */
|
||||||
|
|
||||||
@ -66,10 +67,9 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
|
|||||||
|
|
||||||
int iarg = 5;
|
int iarg = 5;
|
||||||
if (lbstyle == SHIFT) {
|
if (lbstyle == SHIFT) {
|
||||||
if (iarg+4 > narg) error->all(FLERR,"Illegal fix balance command");
|
if (iarg+4 > narg) utils::missing_cmd_args(FLERR, "fix balance shift", error);
|
||||||
if (strlen(arg[iarg+1]) > Balance::BSTR_SIZE)
|
bstr = arg[iarg+1];
|
||||||
error->all(FLERR,"Illegal fix balance command");
|
if (bstr.size() > Balance::BSTR_SIZE) error->all(FLERR,"Illegal fix balance shift command");
|
||||||
strncpy(bstr,arg[iarg+1], Balance::BSTR_SIZE+1);
|
|
||||||
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
|
nitermax = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
|
||||||
if (nitermax <= 0) error->all(FLERR,"Illegal fix balance command");
|
if (nitermax <= 0) error->all(FLERR,"Illegal fix balance command");
|
||||||
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
|
stopthresh = utils::numeric(FLERR,arg[iarg+3],false,lmp);
|
||||||
@ -83,7 +83,7 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
|
|||||||
// error checks
|
// error checks
|
||||||
|
|
||||||
if (lbstyle == SHIFT) {
|
if (lbstyle == SHIFT) {
|
||||||
int blen = strlen(bstr);
|
int blen = bstr.size();
|
||||||
for (int i = 0; i < blen; i++) {
|
for (int i = 0; i < blen; i++) {
|
||||||
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
|
if (bstr[i] != 'x' && bstr[i] != 'y' && bstr[i] != 'z')
|
||||||
error->all(FLERR,"Fix balance shift string is invalid");
|
error->all(FLERR,"Fix balance shift string is invalid");
|
||||||
@ -103,7 +103,7 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
|
|||||||
// process remaining optional args via Balance
|
// process remaining optional args via Balance
|
||||||
|
|
||||||
balance = new Balance(lmp);
|
balance = new Balance(lmp);
|
||||||
if (lbstyle == SHIFT) balance->shift_setup(bstr,nitermax,thresh);
|
if (lbstyle == SHIFT) balance->shift_setup(bstr.c_str(),nitermax,thresh);
|
||||||
balance->options(iarg,narg,arg,0);
|
balance->options(iarg,narg,arg,0);
|
||||||
wtflag = balance->wtflag;
|
wtflag = balance->wtflag;
|
||||||
sortflag = balance->sortflag;
|
sortflag = balance->sortflag;
|
||||||
|
|||||||
@ -42,7 +42,7 @@ class FixBalance : public Fix {
|
|||||||
private:
|
private:
|
||||||
int nevery, lbstyle, nitermax;
|
int nevery, lbstyle, nitermax;
|
||||||
double thresh, stopthresh;
|
double thresh, stopthresh;
|
||||||
char bstr[4];
|
std::string bstr;
|
||||||
int wtflag; // 1 for weighted balancing
|
int wtflag; // 1 for weighted balancing
|
||||||
int sortflag; // 1 for sorting comm messages
|
int sortflag; // 1 for sorting comm messages
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user