re-factored balance command now works with group and time weights

This commit is contained in:
Axel Kohlmeyer
2016-08-03 18:34:24 -04:00
parent 8ff0085cba
commit 3f674e5062
12 changed files with 288 additions and 235 deletions

View File

@ -63,12 +63,6 @@ Balance::Balance(LAMMPS *lmp) : Pointers(lmp)
nimbalance = 0;
imbalance = NULL;
weight = NULL;
ngroup = 0;
group_id = NULL;
group_weight = NULL;
clock_imbalance = NULL;
}
/* ---------------------------------------------------------------------- */
@ -100,11 +94,6 @@ Balance::~Balance()
delete [] imbalance;
delete [] weight;
#if 1
delete [] group_id;
delete [] group_weight;
delete [] clock_imbalance;
#endif
if (fp) fclose(fp);
}
@ -238,37 +227,26 @@ void Balance::command(int narg, char **arg)
Imbalance *imb;
int nopt = 0;
if (strcmp(arg[iarg+1],"group") == 0) {
imb = new ImbalanceGroup;
nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
imb = new ImbalanceGroup(lmp);
nopt = imb->options(narg-iarg,arg+iarg+2);
imbalance[nimbalance] = imb;
} else if (strcmp(arg[iarg+1],"time") == 0) {
imb = new ImbalanceTime;
nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
imb = new ImbalanceTime(lmp);
nopt = imb->options(narg-iarg,arg+iarg+2);
imbalance[nimbalance] = imb;
} else if (strcmp(arg[iarg+1],"neigh") == 0) {
imb = new ImbalanceNeigh;
nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
imb = new ImbalanceNeigh(lmp);
nopt = imb->options(narg-iarg,arg+iarg+2);
imbalance[nimbalance] = imb;
} else if (strcmp(arg[iarg+1],"var") == 0) {
imb = new ImbalanceVar;
nopt = imb->options(lmp,narg-iarg-1,arg+iarg+1);
imb = new ImbalanceVar(lmp);
nopt = imb->options(narg-iarg,arg+iarg+2);
imbalance[nimbalance] = imb;
} else {
error->all(FLERR,"Unknown balance weight method");
}
++nimbalance;
iarg += 2+nopt;
#if 1
} else if (strcmp(arg[iarg],"clock") == 0) {
if (iarg+2 > narg) error->all(FLERR,"Illegal balance command");
double factor = force->numeric(FLERR,arg[iarg+1]);
if (factor < 0.0 || factor > 1.0)
error->all(FLERR,"Illegal balance command");
imbalance_clock(factor,0.0);
iarg += 2;
} else if (strcmp(arg[iarg],"group") == 0) {
group_setup(narg-iarg-1,arg+iarg+1);
iarg += 2*ngroup + 2;
#endif
} else error->all(FLERR,"Illegal balance command");
}
@ -321,7 +299,7 @@ void Balance::command(int narg, char **arg)
comm->exchange();
if (domain->triclinic) domain->lamda2x(atom->nlocal);
// compute and apply imbalance weights
// compute and apply imbalance weights for local atoms
if (nimbalance > 0) {
int i;
const int nlocal = atom->nlocal;
@ -329,7 +307,7 @@ void Balance::command(int narg, char **arg)
for (i = 0; i < nlocal; ++i)
weight[i] = 1.0;
for (i = 0; i < nimbalance; ++i)
imbalance[i]->compute(lmp,weight);
imbalance[i]->compute(weight);
} else weight = NULL;
// imbinit = initial imbalance
@ -435,35 +413,39 @@ void Balance::command(int narg, char **arg)
error->all(FLERR,str);
}
// recompute and apply imbalance weights for local atoms
if (nimbalance > 0) {
int i;
const int nlocal = atom->nlocal;
delete[] weight;
weight = new double[nlocal];
for (i = 0; i < nlocal; ++i)
weight[i] = 1.0;
for (i = 0; i < nimbalance; ++i)
imbalance[i]->compute(weight);
} else weight = NULL;
// imbfinal = final imbalance based on final (weighted) nlocal
int maxfinal;
double imbfinal = imbalance_nlocal(maxfinal);
if (me == 0) {
double stop_time = MPI_Wtime();
if (screen) {
fprintf(screen," rebalancing time: %g seconds\n",
MPI_Wtime()-start_time);
fprintf(screen," rebalancing time: %g seconds\n",stop_time-start_time);
fprintf(screen," iteration count = %d\n",niter);
if (ngroup > 0) {
fprintf(screen," group weights:");
for (int i=0; i < ngroup; ++i)
fprintf(screen," %s=%g", group->names[group_id[i]],group_weight[i]);
fprintf(screen,"\n");
}
for (int i = 0; i < nimbalance; ++i) imbalance[i]->info(screen);
fprintf(screen," initial/final max load/proc = %d %d\n",
maxinit,maxfinal);
fprintf(screen," initial/final imbalance factor = %g %g\n",
imbinit,imbfinal);
}
if (logfile) {
fprintf(logfile," rebalancing time: %g seconds\n",stop_time-start_time);
fprintf(logfile," iteration count = %d\n",niter);
if (ngroup > 0) {
fprintf(logfile," group weights:");
for (int i=0; i < ngroup; ++i)
fprintf(logfile," %s=%g", group->names[group_id[i]],group_weight[i]);
fprintf(logfile,"\n");
}
for (int i = 0; i < nimbalance; ++i) imbalance[i]->info(logfile);
fprintf(logfile," initial/final max load/proc = %d %d\n",
maxinit,maxfinal);
fprintf(logfile," initial/final imbalance factor = %g %g\n",
@ -505,69 +487,6 @@ void Balance::command(int narg, char **arg)
}
}
/* ----------------------------------------------------------------------
compute the computational load associated with an atom
i = atom index
return cost = product of group weights for this atom.
------------------------------------------------------------------------- */
double Balance::getcost(int i)
{
double cost = 1.0;
for (int j = 0; j < ngroup; ++j) {
if (atom->mask[i] & group->bitmask[group_id[j]])
cost *= group_weight[j];
}
return cost;
}
/* ----------------------------------------------------------------------
calculate imbalance based on timers for Pair+Bond+Kspace+Neighbor time.
------------------------------------------------------------------------- */
double Balance::imbalance_clock(double factor, double last_cost)
{
// Compute the cost function of based on relevant timers
if (timer->has_normal()) {
if (!clock_imbalance) clock_imbalance = new double[nprocs+1];
double cost = -last_cost;
cost += timer->get_wall(Timer::PAIR);
cost += timer->get_wall(Timer::NEIGH);
cost += timer->get_wall(Timer::BOND);
cost += timer->get_wall(Timer::KSPACE);
double *clock_cost = new double[nprocs+1];
for (int i = 0; i <= nprocs; ++i) clock_imbalance[i] = clock_cost[i] = 0.0;
clock_cost[me] = cost;
clock_cost[nprocs] = cost;
MPI_Allreduce(clock_cost,clock_imbalance,nprocs+1,MPI_DOUBLE,MPI_SUM,world);
const double avg_cost = clock_imbalance[nprocs]/nprocs;
if (avg_cost > 0.0) {
for (int i = 0; i < nprocs; ++i)
clock_imbalance[i] = (1.0-factor) + factor*clock_imbalance[i]/avg_cost;
} else {
for (int i = 0; i < nprocs; ++i)
clock_imbalance[i] = 1.0;
}
#if BALANCE_DEBUG
if (me == 0) {
fprintf(stderr,"Clock imbalance using factor %g\n",factor);
for (int i = 0; i < nprocs; ++i)
fprintf(stderr," % 2d: %4.2f",i,clock_imbalance[i]);
fputs("\n",stderr);
}
#endif
delete [] clock_cost;
return cost + last_cost;
}
return last_cost;
}
/* ----------------------------------------------------------------------
calculate imbalance based on (weighted) local atom counts
return max = max atom per proc
@ -579,10 +498,12 @@ double Balance::imbalance_nlocal(int &maxcost)
// Compute the cost function of local atoms
double cost = 0.0;
for (int i=0; i < atom->nlocal; ++i) {
cost += getcost(i);
if (weight == NULL) {
cost = atom->nlocal;
} else {
for (int i=0; i < atom->nlocal; ++i)
cost += weight[i];
}
if (clock_imbalance) cost *= clock_imbalance[me];
int intcost = (int)cost;
int sumcost = maxcost = 0;
@ -622,21 +543,23 @@ double Balance::imbalance_splits(int &max)
int nlocal = atom->nlocal;
int ix,iy,iz;
for (int i = 0; i < nlocal; i++) {
ix = binary(x[i][0],nx,xsplit);
iy = binary(x[i][1],ny,ysplit);
iz = binary(x[i][2],nz,zsplit);
proccost[iz*nx*ny + iy*nx + ix] += getcost(i);
}
for (int i = 0; i < nprocs; i++) {
if (clock_imbalance)
proccount[i] = static_cast<int>(proccost[i]*clock_imbalance[i]);
else
proccount[i] = static_cast<int>(proccost[i]);
if (weight) {
for (int i = 0; i < nlocal; i++) {
ix = binary(x[i][0],nx,xsplit);
iy = binary(x[i][1],ny,ysplit);
iz = binary(x[i][2],nz,zsplit);
proccost[iz*nx*ny + iy*nx + ix] += weight[i];
}
} else {
for (int i = 0; i < nlocal; i++) {
ix = binary(x[i][0],nx,xsplit);
iy = binary(x[i][1],ny,ysplit);
iz = binary(x[i][2],nz,zsplit);
proccost[iz*nx*ny + iy*nx + ix] += 1.0;
}
}
for (int i = 0; i < nprocs; i++) proccount[i] = (int)(proccost[i]);
MPI_Allreduce(proccount,allproccount,nprocs,MPI_INT,MPI_SUM,world);
bigint sum = 0;
max = 0;
@ -698,20 +621,10 @@ int *Balance::bisection(int sortflag)
// invoke RCB
// then invert() to create list of proc assignements for my atoms
// Use specified weightings for each atom rather than atom count
// Use compute weights for each atom, if available
#if 1
double factor = 1.0;
if (clock_imbalance) factor = clock_imbalance[me];
double *weights = new double[nlocal];
for (int i = 0; i < nlocal; i++)
weights[i] = getcost(i)*factor;
#endif
rcb->compute(dim,atom->nlocal,atom->x,weights,shrinklo,shrinkhi);
rcb->compute(dim,atom->nlocal,atom->x,weight,shrinklo,shrinkhi);
rcb->invert(sortflag);
delete[] weights;
// reset RCB lo/hi bounding box to full simulation box as needed
@ -820,28 +733,6 @@ void Balance::shift_setup(char *str, int nitermax_in, double thresh_in)
rho = 1;
}
/* ----------------------------------------------------------------------
setup group based load balance operations
called from balance->command() and fix balance
------------------------------------------------------------------------- */
int Balance::group_setup(int narg, char **arg)
{
if (narg < 3) error->all(FLERR,"Illegal balance command");
ngroup = force->inumeric(FLERR,arg[0]);
if (ngroup < 1) error->all(FLERR,"Illegal balance command");
if (2*ngroup+1 > narg) error->all(FLERR,"Illegal balance command");
group_id = new int[ngroup];
group_weight = new double[ngroup];
for (int i = 0; i < ngroup; ++i) {
group_id[i] = group->find(arg[2*i+1]);
if (group_id[i] < 0) error->all(FLERR,"Unknown group in balance command");
group_weight[i] = force->numeric(FLERR,arg[2*i+2]);
}
return ngroup;
}
/* ----------------------------------------------------------------------
load balance by changing xyz split proc boundaries in Comm
called one time from input script command or many times from fix balance
@ -886,10 +777,13 @@ int Balance::shift()
tally(bdim[idim],np,split);
double cost = 0.0;
for (i=0; i < atom->nlocal; i++)
cost += getcost(i);
if (weight == NULL) {
cost = atom->nlocal;
} else {
for (int i=0; i < atom->nlocal; ++i)
cost += weight[i];
}
if (clock_imbalance) cost *= clock_imbalance[me];
int intcost = (int)cost;
int totalcost;
MPI_Allreduce(&intcost,&totalcost,1,MPI_INT,MPI_SUM,world);
@ -1033,12 +927,16 @@ void Balance::tally(int dim, int n, double *split)
int nlocal = atom->nlocal;
int index;
double factor = 1.0;
if (clock_imbalance) factor = clock_imbalance[me];
for (int i = 0; i < nlocal; i++) {
index = binary(x[i][dim],n,split);
onecost[index] += getcost(i)*factor;
if (weight) {
for (int i = 0; i < nlocal; i++) {
index = binary(x[i][dim],n,split);
onecost[index] += weight[i];
}
} else {
for (int i = 0; i < nlocal; i++) {
index = binary(x[i][dim],n,split);
onecost[index] += 1.0;
}
}
for (int i = 0; i < n; i++) onecount[i] = static_cast<bigint>(onecost[i]);

View File

@ -32,12 +32,10 @@ class Balance : protected Pointers {
Balance(class LAMMPS *);
~Balance();
void command(int, char **);
int group_setup(int, char **);
void shift_setup(char *, int, double);
int shift();
int *bisection(int sortflag = 0);
double imbalance_nlocal(int &);
double imbalance_clock(double, double);
void dumpout(bigint, FILE *);
private:
@ -71,14 +69,6 @@ class Balance : protected Pointers {
class Imbalance **imbalance; // list of imbalance compute classes
double *weight; // per (local) atom weight factor or NULL
#if 1
int ngroup; // number of groups weights
int *group_id; // group ids for weights
double *group_weight; // weights of groups
double *clock_imbalance; // computed wall clock imbalance, NULL if not available
#endif
int outflag; // for output of balance results to file
FILE *fp;
int firststep;
@ -88,7 +78,6 @@ class Balance : protected Pointers {
void tally(int, int, double *);
int adjust(int, double *);
int binary(double, int, double *);
double getcost(int);
#ifdef BALANCE_DEBUG
void debug_shift_output(int, int, int, double *);
#endif

View File

@ -97,9 +97,11 @@ FixBalance::FixBalance(LAMMPS *lmp, int narg, char **arg) :
if (clock_factor < 0.0 || clock_factor > 1.0)
error->all(FLERR,"Illegal fix balance command");
iarg += 2;
#if 0
} else if (strcmp(arg[iarg],"group") == 0) {
int ngroup = balance->group_setup(narg-iarg-1,arg+iarg+1);
iarg += 2 + 2*ngroup;
#endif
} else error->all(FLERR,"Illegal fix balance command");
}
@ -229,10 +231,11 @@ void FixBalance::pre_exchange()
if (domain->triclinic) domain->lamda2x(atom->nlocal);
// return if imbalance < threshhold
#if 0
if (clock_factor > 0.0)
last_clock = balance->imbalance_clock(clock_factor,last_clock);
imbnow = balance->imbalance_nlocal(maxperproc);
#endif
if (imbnow <= thresh) {
if (nevery) next_reneighbor = (update->ntimestep/nevery)*nevery + nevery;
return;

View File

@ -14,25 +14,36 @@
#ifndef LMP_IMBALANCE_H
#define LMP_IMBALANCE_H
#include <stdio.h>
namespace LAMMPS_NS {
class LAMMPS;
class Imbalance {
public:
Imbalance() {};
Imbalance(LAMMPS *lmp) : _lmp(lmp) {};
virtual ~Imbalance() {};
// disallow copy constructor and assignment operator
// disallow default and copy constructor, assignment operator
private:
Imbalance() {};
Imbalance(const Imbalance &) {};
Imbalance &operator=(const Imbalance &) {return *this;};
// required member functions
// internal use only data members
protected:
LAMMPS *_lmp;
// public API
public:
// parse options. return number of arguments consumed.
virtual int options(LAMMPS *lmp, int narg, char **arg) = 0;
// compute and apply weigh factors to local atom array
virtual void compute(LAMMPS *lmp, double *weights) = 0;
// parse options. return number of arguments consumed. (required)
virtual int options(int narg, char **arg) = 0;
// reinitialize internal data (needed for fix balance) (optional)
virtual void init() {};
// compute and apply weight factors to local atom array (required)
virtual void compute(double *weights) = 0;
// print information about the state of this imbalance compute (required)
virtual void info(FILE *fp) = 0;
};
}

View File

@ -21,11 +21,11 @@
using namespace LAMMPS_NS;
int ImbalanceGroup::options(LAMMPS *lmp, int narg, char **arg)
int ImbalanceGroup::options(int narg, char **arg)
{
Error *error = lmp->error;
Force *force = lmp->force;
Group *group = lmp->group;
Error *error = _lmp->error;
Force *force = _lmp->force;
Group *group = _lmp->group;
if (narg < 3) error->all(FLERR,"Illegal balance weight command");
@ -41,14 +41,16 @@ int ImbalanceGroup::options(LAMMPS *lmp, int narg, char **arg)
error->all(FLERR,"Unknown group in balance weight command");
_factor[i] = force->numeric(FLERR,arg[2*i+2]);
}
return _num;
return 2*_num+1;
}
void ImbalanceGroup::compute(LAMMPS *lmp, double *weight)
/* -------------------------------------------------------------------- */
void ImbalanceGroup::compute(double *weight)
{
const int * const mask = lmp->atom->mask;
const int * const bitmask = lmp->group->bitmask;
const int nlocal = lmp->atom->nlocal;
const int * const mask = _lmp->atom->mask;
const int * const bitmask = _lmp->group->bitmask;
const int nlocal = _lmp->atom->nlocal;
if (_num == 0) return;
@ -62,3 +64,17 @@ void ImbalanceGroup::compute(LAMMPS *lmp, double *weight)
weight[i] = iweight;
}
}
/* -------------------------------------------------------------------- */
void ImbalanceGroup::info(FILE *fp)
{
if (_num > 0) {
const char * const * const names = _lmp->group->names;
fprintf(fp," group weights:");
for (int i = 0; i < _num; ++i)
fprintf(fp," %s=%g",names[_id[i]],_factor[i]);
fputs("\n",fp);
}
}

View File

@ -21,7 +21,7 @@ namespace LAMMPS_NS {
class ImbalanceGroup : public Imbalance {
public:
ImbalanceGroup() : Imbalance(), _num(0), _id(0), _factor(0) {};
ImbalanceGroup(LAMMPS *lmp) : Imbalance(lmp),_num(0),_id(0),_factor(0) {};
virtual ~ImbalanceGroup() { delete[] _id; delete[] _factor; };
// internal data members
@ -33,9 +33,11 @@ class ImbalanceGroup : public Imbalance {
// required member functions
public:
// parse options. return number of arguments consumed.
virtual int options(LAMMPS *lmp, int narg, char **arg);
// compute per-atom imbalance and apply to weight array
virtual void compute(LAMMPS *lmp, double *weight);
virtual int options(int narg, char **arg);
// compute and apply weight factors to local atom array
virtual void compute(double *weight);
// print information about the state of this imbalance compute
virtual void info(FILE *fp);
};
}

View File

@ -14,14 +14,40 @@
#include "pointers.h"
#include "imbalance_neigh.h"
#include "atom.h"
#include "error.h"
#include "comm.h"
#include "force.h"
using namespace LAMMPS_NS;
int ImbalanceNeigh::options(LAMMPS *lmp, int narg, char **arg)
int ImbalanceNeigh::options(int narg, char **arg)
{
return 0;
Error *error = _lmp->error;
Force *force = _lmp->force;
if (narg < 1) error->all(FLERR,"Illegal balance weight command");
_factor = force->numeric(FLERR,arg[0]);
if (_factor < 0.0 || _factor > 1.0)
error->all(FLERR,"Illegal balance weight command");
return 1;
}
/* -------------------------------------------------------------------- */
void ImbalanceNeigh::compute(LAMMPS *lmp, double *weight)
void ImbalanceNeigh::compute(double *weight)
{
const int nlocal = _lmp->atom->nlocal;
MPI_Comm world = _lmp->world;
if (_factor > 0.0) {
}
}
/* -------------------------------------------------------------------- */
void ImbalanceNeigh::info(FILE *fp)
{
if (_factor > 0.0)
fprintf(fp," neigh weight factor: %g\n",_factor);
}

View File

@ -20,20 +20,20 @@ namespace LAMMPS_NS {
class ImbalanceNeigh : public Imbalance {
public:
ImbalanceNeigh() : Imbalance() {};
ImbalanceNeigh(LAMMPS *lmp) : Imbalance(lmp), _factor(0.0) {};
virtual ~ImbalanceNeigh() {};
// disallow copy constructor and assignment operator
// internal data members
private:
ImbalanceNeigh(const ImbalanceNeigh &) {};
ImbalanceNeigh &operator=(const ImbalanceNeigh &) {return *this;};
double _factor; // weight factor for neighbor imbalance
// required member functions
public:
// parse options. return number of arguments consumed.
virtual int options(LAMMPS *lmp, int narg, char **arg);
// compute per-atom imbalance and apply to weight array
virtual void compute(LAMMPS *lmp, double *weight);
// parse options. return number of arguments consumed
virtual int options(int narg, char **arg);
// compute and apply weight factors to local atom array
virtual void compute(double *weights);
// print information about the state of this imbalance compute
virtual void info(FILE *fp);
};
}

View File

@ -14,14 +14,62 @@
#include "pointers.h"
#include "imbalance_time.h"
#include "atom.h"
#include "error.h"
#include "comm.h"
#include "force.h"
#include "timer.h"
using namespace LAMMPS_NS;
int ImbalanceTime::options(LAMMPS *lmp, int narg, char **arg)
int ImbalanceTime::options(int narg, char **arg)
{
return 0;
Error *error = _lmp->error;
Force *force = _lmp->force;
if (narg < 1) error->all(FLERR,"Illegal balance weight command");
_factor = force->numeric(FLERR,arg[0]);
if (_factor < 0.0 || _factor > 1.0)
error->all(FLERR,"Illegal balance weight command");
return 1;
}
/* -------------------------------------------------------------------- */
void ImbalanceTime::compute(LAMMPS *lmp, double *weight)
void ImbalanceTime::compute(double *weight)
{
const int nlocal = _lmp->atom->nlocal;
const int nprocs = _lmp->comm->nprocs;
MPI_Comm world = _lmp->world;
Timer *timer = _lmp->timer;
if (_factor > 0.0) {
// compute the cost function of based on relevant timers
if (timer->has_normal()) {
double cost = -_last;
cost += timer->get_wall(Timer::PAIR);
cost += timer->get_wall(Timer::NEIGH);
cost += timer->get_wall(Timer::BOND);
cost += timer->get_wall(Timer::KSPACE);
double allcost;
MPI_Allreduce(&cost,&allcost,1,MPI_DOUBLE,MPI_SUM,world);
if (allcost > 0.0) {
const double scale = (1.0-_factor) + _factor*cost*nprocs/allcost;
for (int i = 0; i < nlocal; ++i) weight[i] *= scale;
}
// record time up to this point
_last += cost;
}
}
}
/* -------------------------------------------------------------------- */
void ImbalanceTime::info(FILE *fp)
{
if (_factor > 0.0)
fprintf(fp," time weight factor: %g\n",_factor);
}

View File

@ -20,20 +20,24 @@ namespace LAMMPS_NS {
class ImbalanceTime : public Imbalance {
public:
ImbalanceTime() : Imbalance() {};
ImbalanceTime(LAMMPS *lmp) : Imbalance(lmp),_factor(0.0),_last(0.0) {};
virtual ~ImbalanceTime() {};
// disallow copy constructor and assignment operator
// internal data members
private:
ImbalanceTime(const ImbalanceTime &) {};
ImbalanceTime &operator=(const ImbalanceTime &) {return *this;};
double _factor; // weight factor for time imbalance
double _last; // combined wall time from last call
// required member functions
public:
// parse options. return number of arguments consumed.
virtual int options(LAMMPS *lmp, int narg, char **arg);
// compute per-atom imbalance and apply to weight array
virtual void compute(LAMMPS *lmp, double *weight);
virtual int options(int narg, char **arg);
// reinitialize internal data (needed for fix balance)
virtual void init() { _last = 0.0; };
// compute and apply weight factors to local atom array
virtual void compute(double *weight);
// print information about the state of this imbalance compute
virtual void info(FILE *fp);
};
}

View File

@ -11,17 +11,69 @@
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
#include <string.h>
#include "pointers.h"
#include "imbalance_var.h"
#include "atom.h"
#include "error.h"
#include "force.h"
#include "group.h"
#include "input.h"
#include "variable.h"
using namespace LAMMPS_NS;
int ImbalanceVar::options(LAMMPS *lmp, int narg, char **arg)
int ImbalanceVar::options(int narg, char **arg)
{
return 0;
Error *error = _lmp->error;
Force *force = _lmp->force;
if (narg < 1) error->all(FLERR,"Illegal balance weight command");
int len = strlen(arg[0])+1;
_name = new char[len];
memcpy(_name,arg[0],len);
this->init();
return 1;
}
/* -------------------------------------------------------------------- */
void ImbalanceVar::compute(LAMMPS *lmp, double *weight)
void ImbalanceVar::init()
{
Error *error = _lmp->error;
Variable *variable = _lmp->input->variable;
if (_name) {
_id = variable->find(_name);
if (_id < 0) {
error->all(FLERR,"Variable name for balance weight does not exist");
} else {
if (variable->atomstyle(_id) == 0)
error->all(FLERR,"Variable for balance weight has invalid style");
}
}
}
/* -------------------------------------------------------------------- */
void ImbalanceVar::compute(double *weight)
{
if (_id >= 0) {
const int all = _lmp->group->find("all");
const int nlocal = _lmp->atom->nlocal;
double *val = new double[nlocal];
_lmp->input->variable->compute_atom(_id,all,val,1,0);
for (int i = 0; i < nlocal; ++i) weight[i] *= val[i];
delete[] val;
}
}
/* -------------------------------------------------------------------- */
void ImbalanceVar::info(FILE *fp)
{
if (_id >= 0)
fprintf(fp," weight variable: %s\n",_name);
}

View File

@ -20,20 +20,24 @@ namespace LAMMPS_NS {
class ImbalanceVar : public Imbalance {
public:
ImbalanceVar() : Imbalance() {};
virtual ~ImbalanceVar() {};
ImbalanceVar(LAMMPS *lmp) : Imbalance(lmp), _name(0), _id(-1) {};
virtual ~ImbalanceVar() { delete[] _name; };
// disallow copy constructor and assignment operator
// internal data members
private:
ImbalanceVar(const ImbalanceVar &) {};
ImbalanceVar &operator=(const ImbalanceVar &) {return *this;};
char *_name; // variable name
int _id; // variable ID
// required member functions
public:
// parse options. return number of arguments consumed.
virtual int options(LAMMPS *lmp, int narg, char **arg);
virtual int options(int narg, char **arg);
// re-initialize internal data, e.g. variable ID
virtual void init();
// compute per-atom imbalance and apply to weight array
virtual void compute(LAMMPS *lmp, double *weight);
virtual void compute(double *weight);
// print information about the state of this imbalance compute (required)
virtual void info(FILE *fp);
};
}