From 3f674e5062aa8533f91da871f36f7bdcd90845db Mon Sep 17 00:00:00 2001 From: Axel Kohlmeyer Date: Wed, 3 Aug 2016 18:34:24 -0400 Subject: [PATCH] re-factored balance command now works with group and time weights --- src/balance.cpp | 236 ++++++++++++---------------------------- src/balance.h | 11 -- src/fix_balance.cpp | 5 +- src/imbalance.h | 25 +++-- src/imbalance_group.cpp | 36 ++++-- src/imbalance_group.h | 10 +- src/imbalance_neigh.cpp | 32 +++++- src/imbalance_neigh.h | 18 +-- src/imbalance_time.cpp | 54 ++++++++- src/imbalance_time.h | 18 +-- src/imbalance_var.cpp | 60 +++++++++- src/imbalance_var.h | 18 +-- 12 files changed, 288 insertions(+), 235 deletions(-) diff --git a/src/balance.cpp b/src/balance.cpp index cc765a0b21..c92ab6c1f6 100644 --- a/src/balance.cpp +++ b/src/balance.cpp @@ -63,12 +63,6 @@ Balance::Balance(LAMMPS *lmp) : Pointers(lmp) nimbalance = 0; imbalance = NULL; weight = NULL; - - ngroup = 0; - group_id = NULL; - group_weight = NULL; - - clock_imbalance = NULL; } /* ---------------------------------------------------------------------- */ @@ -100,11 +94,6 @@ Balance::~Balance() delete [] imbalance; delete [] weight; -#if 1 - delete [] group_id; - delete [] group_weight; - delete [] clock_imbalance; -#endif if (fp) fclose(fp); } @@ -238,37 +227,26 @@ void Balance::command(int narg, char **arg) Imbalance *imb; int nopt = 0; if (strcmp(arg[iarg+1],"group") == 0) { - imb = new ImbalanceGroup; - nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1); + imb = new ImbalanceGroup(lmp); + nopt = imb->options(narg-iarg,arg+iarg+2); imbalance[nimbalance] = imb; } else if (strcmp(arg[iarg+1],"time") == 0) { - imb = new ImbalanceTime; - nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1); + imb = new ImbalanceTime(lmp); + nopt = imb->options(narg-iarg,arg+iarg+2); imbalance[nimbalance] = imb; } else if (strcmp(arg[iarg+1],"neigh") == 0) { - imb = new ImbalanceNeigh; - nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1); + imb = new ImbalanceNeigh(lmp); + nopt = imb->options(narg-iarg,arg+iarg+2); imbalance[nimbalance] = imb; } else if (strcmp(arg[iarg+1],"var") == 0) { - imb = new ImbalanceVar; - nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1); + imb = new ImbalanceVar(lmp); + nopt = imb->options(narg-iarg,arg+iarg+2); imbalance[nimbalance] = imb; } else { error->all(FLERR,"Unknown balance weight method"); } + ++nimbalance; iarg += 2+nopt; -#if 1 - } else if (strcmp(arg[iarg],"clock") == 0) { - if (iarg+2 > narg) error->all(FLERR,"Illegal balance command"); - double factor = force->numeric(FLERR,arg[iarg+1]); - if (factor < 0.0 || factor > 1.0) - error->all(FLERR,"Illegal balance command"); - imbalance_clock(factor,0.0); - iarg += 2; - } else if (strcmp(arg[iarg],"group") == 0) { - group_setup(narg-iarg-1,arg+iarg+1); - iarg += 2*ngroup + 2; -#endif } else error->all(FLERR,"Illegal balance command"); } @@ -321,7 +299,7 @@ void Balance::command(int narg, char **arg) comm->exchange(); if (domain->triclinic) domain->lamda2x(atom->nlocal); - // compute and apply imbalance weights + // compute and apply imbalance weights for local atoms if (nimbalance > 0) { int i; const int nlocal = atom->nlocal; @@ -329,7 +307,7 @@ void Balance::command(int narg, char **arg) for (i = 0; i < nlocal; ++i) weight[i] = 1.0; for (i = 0; i < nimbalance; ++i) - imbalance[i]->compute(lmp,weight); + imbalance[i]->compute(weight); } else weight = NULL; // imbinit = initial imbalance @@ -435,35 +413,39 @@ void Balance::command(int narg, char **arg) error->all(FLERR,str); } + // recompute and apply imbalance weights for local atoms + + if (nimbalance > 0) { + int i; + const int nlocal = atom->nlocal; + delete[] weight; + weight = new double[nlocal]; + for (i = 0; i < nlocal; ++i) + weight[i] = 1.0; + for (i = 0; i < nimbalance; ++i) + imbalance[i]->compute(weight); + } else weight = NULL; + // imbfinal = final imbalance based on final (weighted) nlocal int maxfinal; double imbfinal = imbalance_nlocal(maxfinal); if (me == 0) { + double stop_time = MPI_Wtime(); if (screen) { - fprintf(screen," rebalancing time: %g seconds\n", - MPI_Wtime()-start_time); + fprintf(screen," rebalancing time: %g seconds\n",stop_time-start_time); fprintf(screen," iteration count = %d\n",niter); - if (ngroup > 0) { - fprintf(screen," group weights:"); - for (int i=0; i < ngroup; ++i) - fprintf(screen," %s=%g", group->names[group_id[i]],group_weight[i]); - fprintf(screen,"\n"); - } + for (int i = 0; i < nimbalance; ++i) imbalance[i]->info(screen); fprintf(screen," initial/final max load/proc = %d %d\n", maxinit,maxfinal); fprintf(screen," initial/final imbalance factor = %g %g\n", imbinit,imbfinal); } if (logfile) { + fprintf(logfile," rebalancing time: %g seconds\n",stop_time-start_time); fprintf(logfile," iteration count = %d\n",niter); - if (ngroup > 0) { - fprintf(logfile," group weights:"); - for (int i=0; i < ngroup; ++i) - fprintf(logfile," %s=%g", group->names[group_id[i]],group_weight[i]); - fprintf(logfile,"\n"); - } + for (int i = 0; i < nimbalance; ++i) imbalance[i]->info(logfile); fprintf(logfile," initial/final max load/proc = %d %d\n", maxinit,maxfinal); fprintf(logfile," initial/final imbalance factor = %g %g\n", @@ -505,69 +487,6 @@ void Balance::command(int narg, char **arg) } } - /* ---------------------------------------------------------------------- - compute the computational load associated with an atom - i = atom index - return cost = product of group weights for this atom. -------------------------------------------------------------------------- */ - -double Balance::getcost(int i) -{ - double cost = 1.0; - for (int j = 0; j < ngroup; ++j) { - if (atom->mask[i] & group->bitmask[group_id[j]]) - cost *= group_weight[j]; - } - return cost; -} - -/* ---------------------------------------------------------------------- - calculate imbalance based on timers for Pair+Bond+Kspace+Neighbor time. -------------------------------------------------------------------------- */ - -double Balance::imbalance_clock(double factor, double last_cost) -{ - - // Compute the cost function of based on relevant timers - if (timer->has_normal()) { - if (!clock_imbalance) clock_imbalance = new double[nprocs+1]; - - double cost = -last_cost; - cost += timer->get_wall(Timer::PAIR); - cost += timer->get_wall(Timer::NEIGH); - cost += timer->get_wall(Timer::BOND); - cost += timer->get_wall(Timer::KSPACE); - - double *clock_cost = new double[nprocs+1]; - for (int i = 0; i <= nprocs; ++i) clock_imbalance[i] = clock_cost[i] = 0.0; - clock_cost[me] = cost; - clock_cost[nprocs] = cost; - MPI_Allreduce(clock_cost,clock_imbalance,nprocs+1,MPI_DOUBLE,MPI_SUM,world); - - const double avg_cost = clock_imbalance[nprocs]/nprocs; - if (avg_cost > 0.0) { - for (int i = 0; i < nprocs; ++i) - clock_imbalance[i] = (1.0-factor) + factor*clock_imbalance[i]/avg_cost; - } else { - for (int i = 0; i < nprocs; ++i) - clock_imbalance[i] = 1.0; - } - -#if BALANCE_DEBUG - if (me == 0) { - fprintf(stderr,"Clock imbalance using factor %g\n",factor); - for (int i = 0; i < nprocs; ++i) - fprintf(stderr," % 2d: %4.2f",i,clock_imbalance[i]); - fputs("\n",stderr); - } -#endif - - delete [] clock_cost; - return cost + last_cost; - } - return last_cost; -} - /* ---------------------------------------------------------------------- calculate imbalance based on (weighted) local atom counts return max = max atom per proc @@ -579,10 +498,12 @@ double Balance::imbalance_nlocal(int &maxcost) // Compute the cost function of local atoms double cost = 0.0; - for (int i=0; i < atom->nlocal; ++i) { - cost += getcost(i); + if (weight == NULL) { + cost = atom->nlocal; + } else { + for (int i=0; i < atom->nlocal; ++i) + cost += weight[i]; } - if (clock_imbalance) cost *= clock_imbalance[me]; int intcost = (int)cost; int sumcost = maxcost = 0; @@ -622,21 +543,23 @@ double Balance::imbalance_splits(int &max) int nlocal = atom->nlocal; int ix,iy,iz; - for (int i = 0; i < nlocal; i++) { - ix = binary(x[i][0],nx,xsplit); - iy = binary(x[i][1],ny,ysplit); - iz = binary(x[i][2],nz,zsplit); - - proccost[iz*nx*ny + iy*nx + ix] += getcost(i); - } - - for (int i = 0; i < nprocs; i++) { - if (clock_imbalance) - proccount[i] = static_cast(proccost[i]*clock_imbalance[i]); - else - proccount[i] = static_cast(proccost[i]); + if (weight) { + for (int i = 0; i < nlocal; i++) { + ix = binary(x[i][0],nx,xsplit); + iy = binary(x[i][1],ny,ysplit); + iz = binary(x[i][2],nz,zsplit); + proccost[iz*nx*ny + iy*nx + ix] += weight[i]; + } + } else { + for (int i = 0; i < nlocal; i++) { + ix = binary(x[i][0],nx,xsplit); + iy = binary(x[i][1],ny,ysplit); + iz = binary(x[i][2],nz,zsplit); + proccost[iz*nx*ny + iy*nx + ix] += 1.0; + } } + for (int i = 0; i < nprocs; i++) proccount[i] = (int)(proccost[i]); MPI_Allreduce(proccount,allproccount,nprocs,MPI_INT,MPI_SUM,world); bigint sum = 0; max = 0; @@ -698,20 +621,10 @@ int *Balance::bisection(int sortflag) // invoke RCB // then invert() to create list of proc assignements for my atoms - // Use specified weightings for each atom rather than atom count + // Use compute weights for each atom, if available -#if 1 - double factor = 1.0; - if (clock_imbalance) factor = clock_imbalance[me]; - - double *weights = new double[nlocal]; - for (int i = 0; i < nlocal; i++) - weights[i] = getcost(i)*factor; -#endif - - rcb->compute(dim,atom->nlocal,atom->x,weights,shrinklo,shrinkhi); + rcb->compute(dim,atom->nlocal,atom->x,weight,shrinklo,shrinkhi); rcb->invert(sortflag); - delete[] weights; // reset RCB lo/hi bounding box to full simulation box as needed @@ -820,28 +733,6 @@ void Balance::shift_setup(char *str, int nitermax_in, double thresh_in) rho = 1; } -/* ---------------------------------------------------------------------- - setup group based load balance operations - called from balance->command() and fix balance -------------------------------------------------------------------------- */ -int Balance::group_setup(int narg, char **arg) -{ - if (narg < 3) error->all(FLERR,"Illegal balance command"); - - ngroup = force->inumeric(FLERR,arg[0]); - if (ngroup < 1) error->all(FLERR,"Illegal balance command"); - if (2*ngroup+1 > narg) error->all(FLERR,"Illegal balance command"); - - group_id = new int[ngroup]; - group_weight = new double[ngroup]; - for (int i = 0; i < ngroup; ++i) { - group_id[i] = group->find(arg[2*i+1]); - if (group_id[i] < 0) error->all(FLERR,"Unknown group in balance command"); - group_weight[i] = force->numeric(FLERR,arg[2*i+2]); - } - return ngroup; -} - /* ---------------------------------------------------------------------- load balance by changing xyz split proc boundaries in Comm called one time from input script command or many times from fix balance @@ -886,10 +777,13 @@ int Balance::shift() tally(bdim[idim],np,split); double cost = 0.0; - for (i=0; i < atom->nlocal; i++) - cost += getcost(i); + if (weight == NULL) { + cost = atom->nlocal; + } else { + for (int i=0; i < atom->nlocal; ++i) + cost += weight[i]; + } - if (clock_imbalance) cost *= clock_imbalance[me]; int intcost = (int)cost; int totalcost; MPI_Allreduce(&intcost,&totalcost,1,MPI_INT,MPI_SUM,world); @@ -1033,12 +927,16 @@ void Balance::tally(int dim, int n, double *split) int nlocal = atom->nlocal; int index; - double factor = 1.0; - if (clock_imbalance) factor = clock_imbalance[me]; - - for (int i = 0; i < nlocal; i++) { - index = binary(x[i][dim],n,split); - onecost[index] += getcost(i)*factor; + if (weight) { + for (int i = 0; i < nlocal; i++) { + index = binary(x[i][dim],n,split); + onecost[index] += weight[i]; + } + } else { + for (int i = 0; i < nlocal; i++) { + index = binary(x[i][dim],n,split); + onecost[index] += 1.0; + } } for (int i = 0; i < n; i++) onecount[i] = static_cast(onecost[i]); diff --git a/src/balance.h b/src/balance.h index eed77a6d85..0efea62e21 100644 --- a/src/balance.h +++ b/src/balance.h @@ -32,12 +32,10 @@ class Balance : protected Pointers { Balance(class LAMMPS *); ~Balance(); void command(int, char **); - int group_setup(int, char **); void shift_setup(char *, int, double); int shift(); int *bisection(int sortflag = 0); double imbalance_nlocal(int &); - double imbalance_clock(double, double); void dumpout(bigint, FILE *); private: @@ -71,14 +69,6 @@ class Balance : protected Pointers { class Imbalance **imbalance; // list of imbalance compute classes double *weight; // per (local) atom weight factor or NULL -#if 1 - int ngroup; // number of groups weights - int *group_id; // group ids for weights - double *group_weight; // weights of groups - - double *clock_imbalance; // computed wall clock imbalance, NULL if not available -#endif - int outflag; // for output of balance results to file FILE *fp; int firststep; @@ -88,7 +78,6 @@ class Balance : protected Pointers { void tally(int, int, double *); int adjust(int, double *); int binary(double, int, double *); - double getcost(int); #ifdef BALANCE_DEBUG void debug_shift_output(int, int, int, double *); #endif diff --git a/src/fix_balance.cpp b/src/fix_balance.cpp index 8ad716a8bc..37328cb617 100644 --- a/src/fix_balance.cpp +++ b/src/fix_balance.cpp @@ -97,9 +97,11 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) : if (clock_factor < 0.0 || clock_factor > 1.0) error->all(FLERR,"Illegal fix balance command"); iarg += 2; +#if 0 } else if (strcmp(arg[iarg],"group") == 0) { int ngroup = balance->group_setup(narg-iarg-1,arg+iarg+1); iarg += 2 + 2*ngroup; +#endif } else error->all(FLERR,"Illegal fix balance command"); } @@ -229,10 +231,11 @@ void FixBalance::pre_exchange() if (domain->triclinic) domain->lamda2x(atom->nlocal); // return if imbalance < threshhold - +#if 0 if (clock_factor > 0.0) last_clock = balance->imbalance_clock(clock_factor,last_clock); imbnow = balance->imbalance_nlocal(maxperproc); +#endif if (imbnow <= thresh) { if (nevery) next_reneighbor = (update->ntimestep/nevery)*nevery + nevery; return; diff --git a/src/imbalance.h b/src/imbalance.h index ec3fce0e0b..99eaee74cd 100644 --- a/src/imbalance.h +++ b/src/imbalance.h @@ -14,25 +14,36 @@ #ifndef LMP_IMBALANCE_H #define LMP_IMBALANCE_H +#include + namespace LAMMPS_NS { class LAMMPS; class Imbalance { public: - Imbalance() {}; + Imbalance(LAMMPS *lmp) : _lmp(lmp) {}; virtual ~Imbalance() {}; - // disallow copy constructor and assignment operator + // disallow default and copy constructor, assignment operator private: + Imbalance() {}; Imbalance(const Imbalance &) {}; Imbalance &operator=(const Imbalance &) {return *this;}; - // required member functions + // internal use only data members + protected: + LAMMPS *_lmp; + + // public API public: - // parse options. return number of arguments consumed. - virtual int options(LAMMPS *lmp, int narg, char **arg) = 0; - // compute and apply weigh factors to local atom array - virtual void compute(LAMMPS *lmp, double *weights) = 0; + // parse options. return number of arguments consumed. (required) + virtual int options(int narg, char **arg) = 0; + // reinitialize internal data (needed for fix balance) (optional) + virtual void init() {}; + // compute and apply weight factors to local atom array (required) + virtual void compute(double *weights) = 0; + // print information about the state of this imbalance compute (required) + virtual void info(FILE *fp) = 0; }; } diff --git a/src/imbalance_group.cpp b/src/imbalance_group.cpp index cc95c282a6..f0ddc13d2e 100644 --- a/src/imbalance_group.cpp +++ b/src/imbalance_group.cpp @@ -21,11 +21,11 @@ using namespace LAMMPS_NS; -int ImbalanceGroup::options(LAMMPS *lmp, int narg, char **arg) +int ImbalanceGroup::options(int narg, char **arg) { - Error *error = lmp->error; - Force *force = lmp->force; - Group *group = lmp->group; + Error *error = _lmp->error; + Force *force = _lmp->force; + Group *group = _lmp->group; if (narg < 3) error->all(FLERR,"Illegal balance weight command"); @@ -41,14 +41,16 @@ int ImbalanceGroup::options(LAMMPS *lmp, int narg, char **arg) error->all(FLERR,"Unknown group in balance weight command"); _factor[i] = force->numeric(FLERR,arg[2*i+2]); } - return _num; + return 2*_num+1; } - -void ImbalanceGroup::compute(LAMMPS *lmp, double *weight) + +/* -------------------------------------------------------------------- */ + +void ImbalanceGroup::compute(double *weight) { - const int * const mask = lmp->atom->mask; - const int * const bitmask = lmp->group->bitmask; - const int nlocal = lmp->atom->nlocal; + const int * const mask = _lmp->atom->mask; + const int * const bitmask = _lmp->group->bitmask; + const int nlocal = _lmp->atom->nlocal; if (_num == 0) return; @@ -62,3 +64,17 @@ void ImbalanceGroup::compute(LAMMPS *lmp, double *weight) weight[i] = iweight; } } + +/* -------------------------------------------------------------------- */ + +void ImbalanceGroup::info(FILE *fp) +{ + if (_num > 0) { + const char * const * const names = _lmp->group->names; + + fprintf(fp," group weights:"); + for (int i = 0; i < _num; ++i) + fprintf(fp," %s=%g",names[_id[i]],_factor[i]); + fputs("\n",fp); + } +} diff --git a/src/imbalance_group.h b/src/imbalance_group.h index 52869f1c61..fdd68d1fab 100644 --- a/src/imbalance_group.h +++ b/src/imbalance_group.h @@ -21,7 +21,7 @@ namespace LAMMPS_NS { class ImbalanceGroup : public Imbalance { public: - ImbalanceGroup() : Imbalance(), _num(0), _id(0), _factor(0) {}; + ImbalanceGroup(LAMMPS *lmp) : Imbalance(lmp),_num(0),_id(0),_factor(0) {}; virtual ~ImbalanceGroup() { delete[] _id; delete[] _factor; }; // internal data members @@ -33,9 +33,11 @@ class ImbalanceGroup : public Imbalance { // required member functions public: // parse options. return number of arguments consumed. - virtual int options(LAMMPS *lmp, int narg, char **arg); - // compute per-atom imbalance and apply to weight array - virtual void compute(LAMMPS *lmp, double *weight); + virtual int options(int narg, char **arg); + // compute and apply weight factors to local atom array + virtual void compute(double *weight); + // print information about the state of this imbalance compute + virtual void info(FILE *fp); }; } diff --git a/src/imbalance_neigh.cpp b/src/imbalance_neigh.cpp index 18ef0c17e7..d2a31f80e0 100644 --- a/src/imbalance_neigh.cpp +++ b/src/imbalance_neigh.cpp @@ -14,14 +14,40 @@ #include "pointers.h" #include "imbalance_neigh.h" +#include "atom.h" +#include "error.h" +#include "comm.h" +#include "force.h" using namespace LAMMPS_NS; -int ImbalanceNeigh::options(LAMMPS *lmp, int narg, char **arg) +int ImbalanceNeigh::options(int narg, char **arg) { - return 0; + Error *error = _lmp->error; + Force *force = _lmp->force; + + if (narg < 1) error->all(FLERR,"Illegal balance weight command"); + _factor = force->numeric(FLERR,arg[0]); + if (_factor < 0.0 || _factor > 1.0) + error->all(FLERR,"Illegal balance weight command"); + return 1; } + +/* -------------------------------------------------------------------- */ -void ImbalanceNeigh::compute(LAMMPS *lmp, double *weight) +void ImbalanceNeigh::compute(double *weight) { + const int nlocal = _lmp->atom->nlocal; + MPI_Comm world = _lmp->world; + + if (_factor > 0.0) { + } +} + +/* -------------------------------------------------------------------- */ + +void ImbalanceNeigh::info(FILE *fp) +{ + if (_factor > 0.0) + fprintf(fp," neigh weight factor: %g\n",_factor); } diff --git a/src/imbalance_neigh.h b/src/imbalance_neigh.h index 7d276b0904..b4e1e576a9 100644 --- a/src/imbalance_neigh.h +++ b/src/imbalance_neigh.h @@ -20,20 +20,20 @@ namespace LAMMPS_NS { class ImbalanceNeigh : public Imbalance { public: - ImbalanceNeigh() : Imbalance() {}; + ImbalanceNeigh(LAMMPS *lmp) : Imbalance(lmp), _factor(0.0) {}; virtual ~ImbalanceNeigh() {}; - // disallow copy constructor and assignment operator + // internal data members private: - ImbalanceNeigh(const ImbalanceNeigh &) {}; - ImbalanceNeigh &operator=(const ImbalanceNeigh &) {return *this;}; + double _factor; // weight factor for neighbor imbalance - // required member functions public: - // parse options. return number of arguments consumed. - virtual int options(LAMMPS *lmp, int narg, char **arg); - // compute per-atom imbalance and apply to weight array - virtual void compute(LAMMPS *lmp, double *weight); + // parse options. return number of arguments consumed + virtual int options(int narg, char **arg); + // compute and apply weight factors to local atom array + virtual void compute(double *weights); + // print information about the state of this imbalance compute + virtual void info(FILE *fp); }; } diff --git a/src/imbalance_time.cpp b/src/imbalance_time.cpp index 1150ab7bfd..82050186f0 100644 --- a/src/imbalance_time.cpp +++ b/src/imbalance_time.cpp @@ -14,14 +14,62 @@ #include "pointers.h" #include "imbalance_time.h" +#include "atom.h" +#include "error.h" +#include "comm.h" +#include "force.h" +#include "timer.h" using namespace LAMMPS_NS; -int ImbalanceTime::options(LAMMPS *lmp, int narg, char **arg) +int ImbalanceTime::options(int narg, char **arg) { - return 0; + Error *error = _lmp->error; + Force *force = _lmp->force; + + if (narg < 1) error->all(FLERR,"Illegal balance weight command"); + _factor = force->numeric(FLERR,arg[0]); + if (_factor < 0.0 || _factor > 1.0) + error->all(FLERR,"Illegal balance weight command"); + return 1; } + +/* -------------------------------------------------------------------- */ -void ImbalanceTime::compute(LAMMPS *lmp, double *weight) +void ImbalanceTime::compute(double *weight) { + const int nlocal = _lmp->atom->nlocal; + const int nprocs = _lmp->comm->nprocs; + MPI_Comm world = _lmp->world; + Timer *timer = _lmp->timer; + + if (_factor > 0.0) { + // compute the cost function of based on relevant timers + if (timer->has_normal()) { + + double cost = -_last; + cost += timer->get_wall(Timer::PAIR); + cost += timer->get_wall(Timer::NEIGH); + cost += timer->get_wall(Timer::BOND); + cost += timer->get_wall(Timer::KSPACE); + + double allcost; + MPI_Allreduce(&cost,&allcost,1,MPI_DOUBLE,MPI_SUM,world); + + if (allcost > 0.0) { + const double scale = (1.0-_factor) + _factor*cost*nprocs/allcost; + for (int i = 0; i < nlocal; ++i) weight[i] *= scale; + } + // record time up to this point + _last += cost; + } + } +} + +/* -------------------------------------------------------------------- */ + +void ImbalanceTime::info(FILE *fp) +{ + if (_factor > 0.0) + fprintf(fp," time weight factor: %g\n",_factor); } diff --git a/src/imbalance_time.h b/src/imbalance_time.h index c9a52c93dc..78c5068a70 100644 --- a/src/imbalance_time.h +++ b/src/imbalance_time.h @@ -20,20 +20,24 @@ namespace LAMMPS_NS { class ImbalanceTime : public Imbalance { public: - ImbalanceTime() : Imbalance() {}; + ImbalanceTime(LAMMPS *lmp) : Imbalance(lmp),_factor(0.0),_last(0.0) {}; virtual ~ImbalanceTime() {}; - // disallow copy constructor and assignment operator + // internal data members private: - ImbalanceTime(const ImbalanceTime &) {}; - ImbalanceTime &operator=(const ImbalanceTime &) {return *this;}; + double _factor; // weight factor for time imbalance + double _last; // combined wall time from last call // required member functions public: // parse options. return number of arguments consumed. - virtual int options(LAMMPS *lmp, int narg, char **arg); - // compute per-atom imbalance and apply to weight array - virtual void compute(LAMMPS *lmp, double *weight); + virtual int options(int narg, char **arg); + // reinitialize internal data (needed for fix balance) + virtual void init() { _last = 0.0; }; + // compute and apply weight factors to local atom array + virtual void compute(double *weight); + // print information about the state of this imbalance compute + virtual void info(FILE *fp); }; } diff --git a/src/imbalance_var.cpp b/src/imbalance_var.cpp index c98dc3e095..5324635259 100644 --- a/src/imbalance_var.cpp +++ b/src/imbalance_var.cpp @@ -11,17 +11,69 @@ See the README file in the top-level LAMMPS directory. ------------------------------------------------------------------------- */ - +#include #include "pointers.h" #include "imbalance_var.h" +#include "atom.h" +#include "error.h" +#include "force.h" +#include "group.h" +#include "input.h" +#include "variable.h" using namespace LAMMPS_NS; -int ImbalanceVar::options(LAMMPS *lmp, int narg, char **arg) +int ImbalanceVar::options(int narg, char **arg) { - return 0; + Error *error = _lmp->error; + Force *force = _lmp->force; + + if (narg < 1) error->all(FLERR,"Illegal balance weight command"); + int len = strlen(arg[0])+1; + _name = new char[len]; + memcpy(_name,arg[0],len); + this->init(); + + return 1; } + +/* -------------------------------------------------------------------- */ -void ImbalanceVar::compute(LAMMPS *lmp, double *weight) +void ImbalanceVar::init() { + Error *error = _lmp->error; + Variable *variable = _lmp->input->variable; + + if (_name) { + _id = variable->find(_name); + if (_id < 0) { + error->all(FLERR,"Variable name for balance weight does not exist"); + } else { + if (variable->atomstyle(_id) == 0) + error->all(FLERR,"Variable for balance weight has invalid style"); + } + } +} + +/* -------------------------------------------------------------------- */ + +void ImbalanceVar::compute(double *weight) +{ + if (_id >= 0) { + const int all = _lmp->group->find("all"); + const int nlocal = _lmp->atom->nlocal; + + double *val = new double[nlocal]; + _lmp->input->variable->compute_atom(_id,all,val,1,0); + for (int i = 0; i < nlocal; ++i) weight[i] *= val[i]; + delete[] val; + } +} + +/* -------------------------------------------------------------------- */ + +void ImbalanceVar::info(FILE *fp) +{ + if (_id >= 0) + fprintf(fp," weight variable: %s\n",_name); } diff --git a/src/imbalance_var.h b/src/imbalance_var.h index 3543725aa1..43e2dfe849 100644 --- a/src/imbalance_var.h +++ b/src/imbalance_var.h @@ -20,20 +20,24 @@ namespace LAMMPS_NS { class ImbalanceVar : public Imbalance { public: - ImbalanceVar() : Imbalance() {}; - virtual ~ImbalanceVar() {}; + ImbalanceVar(LAMMPS *lmp) : Imbalance(lmp), _name(0), _id(-1) {}; + virtual ~ImbalanceVar() { delete[] _name; }; - // disallow copy constructor and assignment operator + // internal data members private: - ImbalanceVar(const ImbalanceVar &) {}; - ImbalanceVar &operator=(const ImbalanceVar &) {return *this;}; + char *_name; // variable name + int _id; // variable ID // required member functions public: // parse options. return number of arguments consumed. - virtual int options(LAMMPS *lmp, int narg, char **arg); + virtual int options(int narg, char **arg); + // re-initialize internal data, e.g. variable ID + virtual void init(); // compute per-atom imbalance and apply to weight array - virtual void compute(LAMMPS *lmp, double *weight); + virtual void compute(double *weight); + // print information about the state of this imbalance compute (required) + virtual void info(FILE *fp); }; }