diff --git a/src/compute_reduce.cpp b/src/compute_reduce.cpp index c8426673ab..41a1831dc8 100644 --- a/src/compute_reduce.cpp +++ b/src/compute_reduce.cpp @@ -15,6 +15,7 @@ #include "arg_info.h" #include "atom.h" +#include "comm.h" #include "domain.h" #include "error.h" #include "fix.h" @@ -34,16 +35,15 @@ using namespace LAMMPS_NS; /* ---------------------------------------------------------------------- */ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) : - Compute(lmp, narg, arg), nvalues(0), which(nullptr), argindex(nullptr), flavor(nullptr), - value2index(nullptr), ids(nullptr), onevec(nullptr), replace(nullptr), indices(nullptr), + Compute(lmp, narg, arg), nvalues(0), onevec(nullptr), replace(nullptr), indices(nullptr), owner(nullptr), idregion(nullptr), region(nullptr), varatom(nullptr) { int iarg = 0; if (strcmp(style, "reduce") == 0) { - if (narg < 5) error->all(FLERR, "Illegal compute reduce command"); + if (narg < 5) utils::missing_cmd_args(FLERR, "compute reduce", error); iarg = 3; } else if (strcmp(style, "reduce/region") == 0) { - if (narg < 6) error->all(FLERR, "Illegal compute reduce/region command"); + if (narg < 6) utils::missing_cmd_args(FLERR, "compute reduce/region", error); if (!domain->get_region_by_id(arg[3])) error->all(FLERR, "Region {} for compute reduce/region does not exist", arg[3]); idregion = utils::strdup(arg[3]); @@ -67,11 +67,9 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) : else if (strcmp(arg[iarg], "aveabs") == 0) mode = AVEABS; else - error->all(FLERR, "Illegal compute {} operation {}", style, arg[iarg]); + error->all(FLERR, "Unknown compute {} mode: {}", style, arg[iarg]); iarg++; - MPI_Comm_rank(world, &me); - // expand args if any have wildcard character "*" int expand = 0; @@ -81,95 +79,92 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) : if (earg != &arg[iarg]) expand = 1; arg = earg; - // parse values until one isn't recognized + // parse values - which = new int[nargnew]; - argindex = new int[nargnew]; - flavor = new int[nargnew]; - ids = new char *[nargnew]; - value2index = new int[nargnew]; - for (int i = 0; i < nargnew; ++i) { - which[i] = argindex[i] = flavor[i] = value2index[i] = ArgInfo::UNKNOWN; - ids[i] = nullptr; - } + values.clear(); nvalues = 0; + for (int iarg = 0; iarg < nargnew; ++iarg) { + value_t val; - iarg = 0; - while (iarg < nargnew) { - ids[nvalues] = nullptr; + val.id = ""; + val.flavor = 0; + val.val.c = nullptr; if (strcmp(arg[iarg], "x") == 0) { - which[nvalues] = ArgInfo::X; - argindex[nvalues++] = 0; + val.which = ArgInfo::X; + val.argindex = 0; } else if (strcmp(arg[iarg], "y") == 0) { - which[nvalues] = ArgInfo::X; - argindex[nvalues++] = 1; + val.which = ArgInfo::X; + val.argindex = 1; } else if (strcmp(arg[iarg], "z") == 0) { - which[nvalues] = ArgInfo::X; - argindex[nvalues++] = 2; + val.which = ArgInfo::X; + val.argindex = 2; } else if (strcmp(arg[iarg], "vx") == 0) { - which[nvalues] = ArgInfo::V; - argindex[nvalues++] = 0; + val.which = ArgInfo::V; + val.argindex = 0; } else if (strcmp(arg[iarg], "vy") == 0) { - which[nvalues] = ArgInfo::V; - argindex[nvalues++] = 1; + val.which = ArgInfo::V; + val.argindex = 1; } else if (strcmp(arg[iarg], "vz") == 0) { - which[nvalues] = ArgInfo::V; - argindex[nvalues++] = 2; + val.which = ArgInfo::V; + val.argindex = 2; } else if (strcmp(arg[iarg], "fx") == 0) { - which[nvalues] = ArgInfo::F; - argindex[nvalues++] = 0; + val.which = ArgInfo::F; + val.argindex = 0; } else if (strcmp(arg[iarg], "fy") == 0) { - which[nvalues] = ArgInfo::F; - argindex[nvalues++] = 1; + val.which = ArgInfo::F; + val.argindex = 1; } else if (strcmp(arg[iarg], "fz") == 0) { - which[nvalues] = ArgInfo::F; - argindex[nvalues++] = 2; + val.which = ArgInfo::F; + val.argindex = 2; } else { ArgInfo argi(arg[iarg]); - which[nvalues] = argi.get_type(); - argindex[nvalues] = argi.get_index1(); - ids[nvalues] = argi.copy_name(); + val.which = argi.get_type(); + val.argindex = argi.get_index1(); + val.id = argi.get_name(); - if ((which[nvalues] == ArgInfo::UNKNOWN) || (argi.get_dim() > 1)) - error->all(FLERR, "Illegal compute reduce command"); + if ((val.which == ArgInfo::UNKNOWN) || (argi.get_dim() > 1)) + error->all(FLERR, "Illegal compute {} argument: {}", style, arg[iarg]); - if (which[nvalues] == ArgInfo::NONE) break; - nvalues++; + if (val.which == ArgInfo::NONE) break; } - - iarg++; + values.push_back(val); } // optional args + nvalues = values.size(); replace = new int[nvalues]; - for (int i = 0; i < nvalues; i++) replace[i] = -1; + for (int i = 0; i < nvalues; ++i) replace[i] = -1; + std::string mycmd = "compute "; + mycmd += style; - while (iarg < nargnew) { + for (int iarg = nvalues; iarg < nargnew; iarg++) { if (strcmp(arg[iarg], "replace") == 0) { - if (iarg + 3 > narg) error->all(FLERR, "Illegal compute reduce command"); + if (iarg + 3 > narg) utils::missing_cmd_args(FLERR, mycmd + " replace", error); if (mode != MINN && mode != MAXX) - error->all(FLERR, "Compute reduce replace requires min or max mode"); + error->all(FLERR, "Compute {} replace requires min or max mode", style); int col1 = utils::inumeric(FLERR, arg[iarg + 1], false, lmp) - 1; int col2 = utils::inumeric(FLERR, arg[iarg + 2], false, lmp) - 1; - if (col1 < 0 || col1 >= nvalues || col2 < 0 || col2 >= nvalues) - error->all(FLERR, "Illegal compute reduce command"); - if (col1 == col2) error->all(FLERR, "Illegal compute reduce command"); - if (replace[col1] >= 0 || replace[col2] >= 0) - error->all(FLERR, "Invalid replace values in compute reduce"); + if ((col1 < 0) || (col1 >= nvalues)) + error->all(FLERR, "Invalid compute {} replace first column index {}", style, col1); + if ((col2 < 0) || (col2 >= nvalues)) + error->all(FLERR, "Invalid compute {} replace second column index {}", style, col2); + if (col1 == col2) error->all(FLERR, "Compute {} replace columns must be different"); + if ((replace[col1] >= 0) || (replace[col2] >= 0)) + error->all(FLERR, "Compute {} replace column already used for another replacement"); replace[col1] = col2; - iarg += 3; + iarg += 2; } else - error->all(FLERR, "Illegal compute reduce command"); + error->all(FLERR, "Unknown compute {} keyword: {}", style, arg[iarg]); } - // delete replace if not set + // delete replace list if not set int flag = 0; for (int i = 0; i < nvalues; i++) @@ -188,68 +183,67 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) : // setup and error check - for (int i = 0; i < nvalues; i++) { - if (which[i] == ArgInfo::X || which[i] == ArgInfo::V || which[i] == ArgInfo::F) - flavor[i] = PERATOM; + for (auto &val : values) { + if (val.which == ArgInfo::X || val.which == ArgInfo::V || val.which == ArgInfo::F) + val.flavor = PERATOM; - else if (which[i] == ArgInfo::COMPUTE) { - int icompute = modify->find_compute(ids[i]); - if (icompute < 0) error->all(FLERR, "Compute ID for compute reduce does not exist"); - if (modify->compute[icompute]->peratom_flag) { - flavor[i] = PERATOM; - if (argindex[i] == 0 && modify->compute[icompute]->size_peratom_cols != 0) - error->all(FLERR, - "Compute reduce compute does not " - "calculate a per-atom vector"); - if (argindex[i] && modify->compute[icompute]->size_peratom_cols == 0) - error->all(FLERR, - "Compute reduce compute does not " - "calculate a per-atom array"); - if (argindex[i] && argindex[i] > modify->compute[icompute]->size_peratom_cols) - error->all(FLERR, "Compute reduce compute array is accessed out-of-range"); - } else if (modify->compute[icompute]->local_flag) { - flavor[i] = LOCAL; - if (argindex[i] == 0 && modify->compute[icompute]->size_local_cols != 0) - error->all(FLERR, - "Compute reduce compute does not " - "calculate a local vector"); - if (argindex[i] && modify->compute[icompute]->size_local_cols == 0) - error->all(FLERR, - "Compute reduce compute does not " - "calculate a local array"); - if (argindex[i] && argindex[i] > modify->compute[icompute]->size_local_cols) - error->all(FLERR, "Compute reduce compute array is accessed out-of-range"); + else if (val.which == ArgInfo::COMPUTE) { + val.val.c = modify->get_compute_by_id(val.id); + if (!val.val.c) + error->all(FLERR, "Compute ID {} for compute {} does not exist", val.id, style); + if (val.val.c->peratom_flag) { + val.flavor = PERATOM; + if (val.argindex == 0 && val.val.c->size_peratom_cols != 0) + error->all(FLERR, "Compute {} compute {} does not calculate a per-atom vector", style, + val.id); + if (val.argindex && val.val.c->size_peratom_cols == 0) + error->all(FLERR, "Compute {} compute {} does not calculate a per-atom array", style, + val.id); + if (val.argindex && val.argindex > val.val.c->size_peratom_cols) + error->all(FLERR, "Compute {} compute {} array is accessed out-of-range", style, val.id); + } else if (val.val.c->local_flag) { + val.flavor = LOCAL; + if (val.argindex == 0 && val.val.c->size_local_cols != 0) + error->all(FLERR, "Compute {} compute {} does not calculate a local vector", style, + val.id); + if (val.argindex && val.val.c->size_local_cols == 0) + error->all(FLERR, "Compute {} compute {} does not calculate a local array", style, + val.id); + if (val.argindex && val.argindex > val.val.c->size_local_cols) + error->all(FLERR, "Compute {} compute {} array is accessed out-of-range", style, val.id); } else - error->all(FLERR, "Compute reduce compute calculates global values"); + error->all(FLERR, "Compute {} compute {} calculates global values", style, val.id); - } else if (which[i] == ArgInfo::FIX) { - auto ifix = modify->get_fix_by_id(ids[i]); - if (!ifix) error->all(FLERR, "Fix ID {} for compute reduce does not exist", ids[i]); - if (ifix->peratom_flag) { - flavor[i] = PERATOM; - if (argindex[i] == 0 && (ifix->size_peratom_cols != 0)) - error->all(FLERR, "Compute reduce fix {} does not calculate a per-atom vector", ids[i]); - if (argindex[i] && (ifix->size_peratom_cols == 0)) - error->all(FLERR, "Compute reduce fix {} does not calculate a per-atom array", ids[i]); - if (argindex[i] && (argindex[i] > ifix->size_peratom_cols)) - error->all(FLERR, "Compute reduce fix {} array is accessed out-of-range", ids[i]); - } else if (ifix->local_flag) { - flavor[i] = LOCAL; - if (argindex[i] == 0 && (ifix->size_local_cols != 0)) - error->all(FLERR, "Compute reduce fix {} does not calculate a local vector", ids[i]); - if (argindex[i] && (ifix->size_local_cols == 0)) - error->all(FLERR, "Compute reduce fix {} does not calculate a local array", ids[i]); - if (argindex[i] && (argindex[i] > ifix->size_local_cols)) - error->all(FLERR, "Compute reduce fix {} array is accessed out-of-range", ids[i]); + } else if (val.which == ArgInfo::FIX) { + val.val.f = modify->get_fix_by_id(val.id); + if (!val.val.f) error->all(FLERR, "Fix ID {} for compute {} does not exist", val.id, style); + if (val.val.f->peratom_flag) { + val.flavor = PERATOM; + if (val.argindex == 0 && (val.val.f->size_peratom_cols != 0)) + error->all(FLERR, "Compute {} fix {} does not calculate a per-atom vector", style, + val.id); + if (val.argindex && (val.val.f->size_peratom_cols == 0)) + error->all(FLERR, "Compute {} fix {} does not calculate a per-atom array", style, val.id); + if (val.argindex && (val.argindex > val.val.f->size_peratom_cols)) + error->all(FLERR, "Compute {} fix {} array is accessed out-of-range", style, val.id); + } else if (val.val.f->local_flag) { + val.flavor = LOCAL; + if (val.argindex == 0 && (val.val.f->size_local_cols != 0)) + error->all(FLERR, "Compute {} fix {} does not calculate a local vector", style, val.id); + if (val.argindex && (val.val.f->size_local_cols == 0)) + error->all(FLERR, "Compute {} fix {} does not calculate a local array", style, val.id); + if (val.argindex && (val.argindex > val.val.f->size_local_cols)) + error->all(FLERR, "Compute {} fix {} array is accessed out-of-range", style, val.id); } else - error->all(FLERR, "Compute reduce fix {} calculates global values", ids[i]); + error->all(FLERR, "Compute {} fix {} calculates global values", style, val.id); - } else if (which[i] == ArgInfo::VARIABLE) { - int ivariable = input->variable->find(ids[i]); - if (ivariable < 0) error->all(FLERR, "Variable name for compute reduce does not exist"); - if (input->variable->atomstyle(ivariable) == 0) - error->all(FLERR, "Compute reduce variable is not atom-style variable"); - flavor[i] = PERATOM; + } else if (val.which == ArgInfo::VARIABLE) { + val.val.v = input->variable->find(val.id.c_str()); + if (val.val.v < 0) + error->all(FLERR, "Variable name {} for compute {} does not exist", val.id, style); + if (input->variable->atomstyle(val.val.v) == 0) + error->all(FLERR, "Compute {} variable {} is not atom-style variable", style, val.id); + val.flavor = PERATOM; } } @@ -284,12 +278,6 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) : ComputeReduce::~ComputeReduce() { - delete[] which; - delete[] argindex; - delete[] flavor; - for (int m = 0; m < nvalues; m++) delete[] ids[m]; - delete[] ids; - delete[] value2index; delete[] replace; delete[] idregion; @@ -307,24 +295,21 @@ void ComputeReduce::init() { // set indices of all computes,fixes,variables - for (int m = 0; m < nvalues; m++) { - if (which[m] == ArgInfo::COMPUTE) { - int icompute = modify->find_compute(ids[m]); - if (icompute < 0) error->all(FLERR, "Compute ID for compute reduce does not exist"); - value2index[m] = icompute; + for (auto &val : values) { + if (val.which == ArgInfo::COMPUTE) { + val.val.c = modify->get_compute_by_id(val.id); + if (!val.val.c) + error->all(FLERR, "Compute ID {} for compute {} does not exist", val.id, style); - } else if (which[m] == ArgInfo::FIX) { - int ifix = modify->find_fix(ids[m]); - if (ifix < 0) error->all(FLERR, "Fix ID for compute reduce does not exist"); - value2index[m] = ifix; + } else if (val.which == ArgInfo::FIX) { + val.val.f = modify->get_fix_by_id(val.id); + if (!val.val.f) error->all(FLERR, "Fix ID {} for compute {} does not exist", val.id, style); - } else if (which[m] == ArgInfo::VARIABLE) { - int ivariable = input->variable->find(ids[m]); - if (ivariable < 0) error->all(FLERR, "Variable name for compute reduce does not exist"); - value2index[m] = ivariable; - - } else - value2index[m] = ArgInfo::UNKNOWN; + } else if (val.which == ArgInfo::VARIABLE) { + val.val.v = input->variable->find(val.id.c_str()); + if (val.val.v < 0) + error->all(FLERR, "Variable name {} for compute {} does not exist", val.id, style); + } } // set index and check validity of region @@ -383,14 +368,14 @@ void ComputeReduce::compute_vector() for (int m = 0; m < nvalues; m++) if (replace[m] < 0) { pairme.value = onevec[m]; - pairme.proc = me; + pairme.proc = comm->me; MPI_Allreduce(&pairme, &pairall, 1, MPI_DOUBLE_INT, MPI_MINLOC, world); vector[m] = pairall.value; owner[m] = pairall.proc; } for (int m = 0; m < nvalues; m++) if (replace[m] >= 0) { - if (me == owner[replace[m]]) vector[m] = compute_one(m, indices[replace[m]]); + if (comm->me == owner[replace[m]]) vector[m] = compute_one(m, indices[replace[m]]); MPI_Bcast(&vector[m], 1, MPI_DOUBLE, owner[replace[m]], world); } } @@ -404,14 +389,14 @@ void ComputeReduce::compute_vector() for (int m = 0; m < nvalues; m++) if (replace[m] < 0) { pairme.value = onevec[m]; - pairme.proc = me; + pairme.proc = comm->me; MPI_Allreduce(&pairme, &pairall, 1, MPI_DOUBLE_INT, MPI_MAXLOC, world); vector[m] = pairall.value; owner[m] = pairall.proc; } for (int m = 0; m < nvalues; m++) if (replace[m] >= 0) { - if (me == owner[replace[m]]) vector[m] = compute_one(m, indices[replace[m]]); + if (comm->me == owner[replace[m]]) vector[m] = compute_one(m, indices[replace[m]]); MPI_Bcast(&vector[m], 1, MPI_DOUBLE, owner[replace[m]], world); } } @@ -436,24 +421,21 @@ void ComputeReduce::compute_vector() double ComputeReduce::compute_one(int m, int flag) { - int i; - // invoke the appropriate attribute,compute,fix,variable // for flag = -1, compute scalar quantity by scanning over atom properties // only include atoms in group for atom properties and per-atom quantities index = -1; - int vidx = value2index[m]; + auto &val = values[m]; // initialization in case it has not yet been run, e.g. when // the compute was invoked right after it has been created - if (vidx == ArgInfo::UNKNOWN) { - init(); - vidx = value2index[m]; + if ((val.which == ArgInfo::COMPUTE) || (val.which == ArgInfo::FIX)) { + if (val.val.c == nullptr) init(); } - int aidx = argindex[m]; + int aidx = val.argindex; int *mask = atom->mask; int nlocal = atom->nlocal; @@ -461,77 +443,76 @@ double ComputeReduce::compute_one(int m, int flag) if (mode == MINN) one = BIG; if (mode == MAXX) one = -BIG; - if (which[m] == ArgInfo::X) { + if (val.which == ArgInfo::X) { double **x = atom->x; if (flag < 0) { - for (i = 0; i < nlocal; i++) + for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit) combine(one, x[i][aidx], i); } else one = x[flag][aidx]; - } else if (which[m] == ArgInfo::V) { + } else if (val.which == ArgInfo::V) { double **v = atom->v; if (flag < 0) { - for (i = 0; i < nlocal; i++) + for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit) combine(one, v[i][aidx], i); } else one = v[flag][aidx]; - } else if (which[m] == ArgInfo::F) { + } else if (val.which == ArgInfo::F) { double **f = atom->f; if (flag < 0) { - for (i = 0; i < nlocal; i++) + for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit) combine(one, f[i][aidx], i); } else one = f[flag][aidx]; // invoke compute if not previously invoked - } else if (which[m] == ArgInfo::COMPUTE) { - Compute *compute = modify->compute[vidx]; + } else if (val.which == ArgInfo::COMPUTE) { - if (flavor[m] == PERATOM) { - if (!(compute->invoked_flag & Compute::INVOKED_PERATOM)) { - compute->compute_peratom(); - compute->invoked_flag |= Compute::INVOKED_PERATOM; + if (val.flavor == PERATOM) { + if (!(val.val.c->invoked_flag & Compute::INVOKED_PERATOM)) { + val.val.c->compute_peratom(); + val.val.c->invoked_flag |= Compute::INVOKED_PERATOM; } if (aidx == 0) { - double *comp_vec = compute->vector_atom; + double *comp_vec = val.val.c->vector_atom; int n = nlocal; if (flag < 0) { - for (i = 0; i < n; i++) + for (int i = 0; i < n; i++) if (mask[i] & groupbit) combine(one, comp_vec[i], i); } else one = comp_vec[flag]; } else { - double **carray_atom = compute->array_atom; + double **carray_atom = val.val.c->array_atom; int n = nlocal; int aidxm1 = aidx - 1; if (flag < 0) { - for (i = 0; i < n; i++) + for (int i = 0; i < n; i++) if (mask[i] & groupbit) combine(one, carray_atom[i][aidxm1], i); } else one = carray_atom[flag][aidxm1]; } - } else if (flavor[m] == LOCAL) { - if (!(compute->invoked_flag & Compute::INVOKED_LOCAL)) { - compute->compute_local(); - compute->invoked_flag |= Compute::INVOKED_LOCAL; + } else if (val.flavor == LOCAL) { + if (!(val.val.c->invoked_flag & Compute::INVOKED_LOCAL)) { + val.val.c->compute_local(); + val.val.c->invoked_flag |= Compute::INVOKED_LOCAL; } if (aidx == 0) { - double *comp_vec = compute->vector_local; - int n = compute->size_local_rows; + double *comp_vec = val.val.c->vector_local; + int n = val.val.c->size_local_rows; if (flag < 0) - for (i = 0; i < n; i++) combine(one, comp_vec[i], i); + for (int i = 0; i < n; i++) combine(one, comp_vec[i], i); else one = comp_vec[flag]; } else { - double **carray_local = compute->array_local; - int n = compute->size_local_rows; + double **carray_local = val.val.c->array_local; + int n = val.val.c->size_local_rows; int aidxm1 = aidx - 1; if (flag < 0) - for (i = 0; i < n; i++) combine(one, carray_local[i][aidxm1], i); + for (int i = 0; i < n; i++) combine(one, carray_local[i][aidxm1], i); else one = carray_local[flag][aidxm1]; } @@ -539,46 +520,43 @@ double ComputeReduce::compute_one(int m, int flag) // access fix fields, check if fix frequency is a match - } else if (which[m] == ArgInfo::FIX) { - if (update->ntimestep % modify->fix[vidx]->peratom_freq) - error->all(FLERR, - "Fix used in compute reduce not " - "computed at compatible time"); - Fix *fix = modify->fix[vidx]; + } else if (val.which == ArgInfo::FIX) { + if (update->ntimestep % val.val.f->peratom_freq) + error->all(FLERR, "Fix {} used in compute {} not computed at compatible time", val.id, style); - if (flavor[m] == PERATOM) { + if (val.flavor == PERATOM) { if (aidx == 0) { - double *fix_vector = fix->vector_atom; + double *fix_vector = val.val.f->vector_atom; int n = nlocal; if (flag < 0) { - for (i = 0; i < n; i++) + for (int i = 0; i < n; i++) if (mask[i] & groupbit) combine(one, fix_vector[i], i); } else one = fix_vector[flag]; } else { - double **fix_array = fix->array_atom; + double **fix_array = val.val.f->array_atom; int aidxm1 = aidx - 1; if (flag < 0) { - for (i = 0; i < nlocal; i++) + for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit) combine(one, fix_array[i][aidxm1], i); } else one = fix_array[flag][aidxm1]; } - } else if (flavor[m] == LOCAL) { + } else if (val.flavor == LOCAL) { if (aidx == 0) { - double *fix_vector = fix->vector_local; - int n = fix->size_local_rows; + double *fix_vector = val.val.f->vector_local; + int n = val.val.f->size_local_rows; if (flag < 0) - for (i = 0; i < n; i++) combine(one, fix_vector[i], i); + for (int i = 0; i < n; i++) combine(one, fix_vector[i], i); else one = fix_vector[flag]; } else { - double **fix_array = fix->array_local; - int n = fix->size_local_rows; + double **fix_array = val.val.f->array_local; + int n = val.val.f->size_local_rows; int aidxm1 = aidx - 1; if (flag < 0) - for (i = 0; i < n; i++) combine(one, fix_array[i][aidxm1], i); + for (int i = 0; i < n; i++) combine(one, fix_array[i][aidxm1], i); else one = fix_array[flag][aidxm1]; } @@ -586,16 +564,16 @@ double ComputeReduce::compute_one(int m, int flag) // evaluate atom-style variable - } else if (which[m] == ArgInfo::VARIABLE) { + } else if (val.which == ArgInfo::VARIABLE) { if (atom->nmax > maxatom) { maxatom = atom->nmax; memory->destroy(varatom); memory->create(varatom, maxatom, "reduce:varatom"); } - input->variable->compute_atom(vidx, igroup, varatom, 1, 0); + input->variable->compute_atom(val.val.v, igroup, varatom, 1, 0); if (flag < 0) { - for (i = 0; i < nlocal; i++) + for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit) combine(one, varatom[i], i); } else one = varatom[flag]; @@ -608,31 +586,28 @@ double ComputeReduce::compute_one(int m, int flag) bigint ComputeReduce::count(int m) { - int vidx = value2index[m]; - - if (which[m] == ArgInfo::X || which[m] == ArgInfo::V || which[m] == ArgInfo::F) + auto &val = values[m]; + if ((val.which == ArgInfo::X) || (val.which == ArgInfo::V) || (val.which == ArgInfo::F)) return group->count(igroup); - else if (which[m] == ArgInfo::COMPUTE) { - Compute *compute = modify->compute[vidx]; - if (flavor[m] == PERATOM) { + else if (val.which == ArgInfo::COMPUTE) { + if (val.flavor == PERATOM) { return group->count(igroup); - } else if (flavor[m] == LOCAL) { - bigint ncount = compute->size_local_rows; + } else if (val.flavor == LOCAL) { + bigint ncount = val.val.c->size_local_rows; bigint ncountall; MPI_Allreduce(&ncount, &ncountall, 1, MPI_LMP_BIGINT, MPI_SUM, world); return ncountall; } - } else if (which[m] == ArgInfo::FIX) { - Fix *fix = modify->fix[vidx]; - if (flavor[m] == PERATOM) { + } else if (val.which == ArgInfo::FIX) { + if (val.flavor == PERATOM) { return group->count(igroup); - } else if (flavor[m] == LOCAL) { - bigint ncount = fix->size_local_rows; + } else if (val.flavor == LOCAL) { + bigint ncount = val.val.f->size_local_rows; bigint ncountall; MPI_Allreduce(&ncount, &ncountall, 1, MPI_LMP_BIGINT, MPI_SUM, world); return ncountall; } - } else if (which[m] == ArgInfo::VARIABLE) + } else if (val.which == ArgInfo::VARIABLE) return group->count(igroup); bigint dummy = 0; diff --git a/src/compute_reduce.h b/src/compute_reduce.h index dc4ee1ef2c..bfaa1e2a72 100644 --- a/src/compute_reduce.h +++ b/src/compute_reduce.h @@ -37,30 +37,38 @@ class ComputeReduce : public Compute { double memory_usage() override; protected: - int me; int mode, nvalues; - int *which, *argindex, *flavor, *value2index; - char **ids; + struct value_t { + int which; + int argindex; + std::string id; + int flavor; + union { + class Compute *c; + class Fix *f; + int v; + } val; + }; + std::vector values; double *onevec; int *replace, *indices, *owner; + int index; char *idregion; class Region *region; int maxatom; double *varatom; - struct Pair { + struct valpair { double value; int proc; }; - Pair pairme, pairall; + valpair pairme, pairall; virtual double compute_one(int, int); virtual bigint count(int); void combine(double &, double, int); }; - } // namespace LAMMPS_NS - #endif #endif diff --git a/src/compute_reduce_chunk.cpp b/src/compute_reduce_chunk.cpp index 2f8212bc0a..f3bb57d33d 100644 --- a/src/compute_reduce_chunk.cpp +++ b/src/compute_reduce_chunk.cpp @@ -65,7 +65,7 @@ ComputeReduceChunk::ComputeReduceChunk(LAMMPS *lmp, int narg, char **arg) : if (earg != &arg[iarg]) expand = 1; arg = earg; - // parse values until + // parse values values.clear(); for (iarg = 0; iarg < nargnew; iarg++) { diff --git a/src/compute_reduce_region.cpp b/src/compute_reduce_region.cpp index f8a92c7bf3..81d0f1acf5 100644 --- a/src/compute_reduce_region.cpp +++ b/src/compute_reduce_region.cpp @@ -55,62 +55,58 @@ double ComputeReduceRegion::compute_one(int m, int flag) // only include atoms in group index = -1; + auto &val = values[m]; + + // initialization in case it has not yet been run, e.g. when + // the compute was invoked right after it has been created + if ((val.which == ArgInfo::COMPUTE) || (val.which == ArgInfo::FIX)) { + if (val.val.c == nullptr) init(); + } + + int aidx = val.argindex; double **x = atom->x; int *mask = atom->mask; int nlocal = atom->nlocal; - int n = value2index[m]; - - // initialization in case it has not yet been run, - // e.g. when invoked - if (n == ArgInfo::UNKNOWN) { - init(); - n = value2index[m]; - } - - int j = argindex[m]; - double one = 0.0; if (mode == MINN) one = BIG; if (mode == MAXX) one = -BIG; - if (which[m] == ArgInfo::X) { + if (val.which == ArgInfo::X) { if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) - combine(one, x[i][j], i); + combine(one, x[i][aidx], i); } else - one = x[flag][j]; - } else if (which[m] == ArgInfo::V) { + one = x[flag][aidx]; + } else if (val.which == ArgInfo::V) { double **v = atom->v; if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) - combine(one, v[i][j], i); + combine(one, v[i][aidx], i); } else - one = v[flag][j]; - } else if (which[m] == ArgInfo::F) { + one = v[flag][aidx]; + } else if (val.which == ArgInfo::F) { double **f = atom->f; if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) - combine(one, f[i][j], i); + combine(one, f[i][aidx], i); } else - one = f[flag][j]; + one = f[flag][aidx]; // invoke compute if not previously invoked - } else if (which[m] == ArgInfo::COMPUTE) { - Compute *compute = modify->compute[n]; - - if (flavor[m] == PERATOM) { - if (!(compute->invoked_flag & Compute::INVOKED_PERATOM)) { - compute->compute_peratom(); - compute->invoked_flag |= Compute::INVOKED_PERATOM; + } else if (val.which == ArgInfo::COMPUTE) { + if (val.flavor == PERATOM) { + if (!(val.val.c->invoked_flag & Compute::INVOKED_PERATOM)) { + val.val.c->compute_peratom(); + val.val.c->invoked_flag |= Compute::INVOKED_PERATOM; } - if (j == 0) { - double *compute_vector = compute->vector_atom; + if (aidx == 0) { + double *compute_vector = val.val.c->vector_atom; if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) @@ -118,48 +114,48 @@ double ComputeReduceRegion::compute_one(int m, int flag) } else one = compute_vector[flag]; } else { - double **compute_array = compute->array_atom; - int jm1 = j - 1; + double **compute_array = val.val.c->array_atom; + int aidxm1 = aidx - 1; if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) - combine(one, compute_array[i][jm1], i); + combine(one, compute_array[i][aidxm1], i); } else - one = compute_array[flag][jm1]; + one = compute_array[flag][aidxm1]; } - } else if (flavor[m] == LOCAL) { - if (!(compute->invoked_flag & Compute::INVOKED_LOCAL)) { - compute->compute_local(); - compute->invoked_flag |= Compute::INVOKED_LOCAL; + } else if (val.flavor == LOCAL) { + if (!(val.val.c->invoked_flag & Compute::INVOKED_LOCAL)) { + val.val.c->compute_local(); + val.val.c->invoked_flag |= Compute::INVOKED_LOCAL; } - if (j == 0) { - double *compute_vector = compute->vector_local; + if (aidx == 0) { + double *compute_vector = val.val.c->vector_local; if (flag < 0) - for (int i = 0; i < compute->size_local_rows; i++) combine(one, compute_vector[i], i); + for (int i = 0; i < val.val.c->size_local_rows; i++) combine(one, compute_vector[i], i); else one = compute_vector[flag]; } else { - double **compute_array = compute->array_local; - int jm1 = j - 1; + double **compute_array = val.val.c->array_local; + int aidxm1 = aidx - 1; if (flag < 0) - for (int i = 0; i < compute->size_local_rows; i++) combine(one, compute_array[i][jm1], i); + for (int i = 0; i < val.val.c->size_local_rows; i++) + combine(one, compute_array[i][aidxm1], i); else - one = compute_array[flag][jm1]; + one = compute_array[flag][aidxm1]; } } // check if fix frequency is a match - } else if (which[m] == ArgInfo::FIX) { - if (update->ntimestep % modify->fix[n]->peratom_freq) - error->all(FLERR, "Fix used in compute reduce not computed at compatible time"); - Fix *fix = modify->fix[n]; + } else if (val.which == ArgInfo::FIX) { + if (update->ntimestep % val.val.f->peratom_freq) + error->all(FLERR, "Fix {} used in compute {} not computed at compatible time", val.id, style); - if (flavor[m] == PERATOM) { - if (j == 0) { - double *fix_vector = fix->vector_atom; + if (val.flavor == PERATOM) { + if (aidx == 0) { + double *fix_vector = val.val.f->vector_atom; if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) @@ -167,43 +163,44 @@ double ComputeReduceRegion::compute_one(int m, int flag) } else one = fix_vector[flag]; } else { - double **fix_array = fix->array_atom; - int jm1 = j - 1; + double **fix_array = val.val.f->array_atom; + int aidxm1 = aidx - 1; if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) - combine(one, fix_array[i][jm1], i); + combine(one, fix_array[i][aidxm1], i); } else - one = fix_array[flag][jm1]; + one = fix_array[flag][aidxm1]; } - } else if (flavor[m] == LOCAL) { - if (j == 0) { - double *fix_vector = fix->vector_local; + } else if (val.flavor == LOCAL) { + if (aidx == 0) { + double *fix_vector = val.val.f->vector_local; if (flag < 0) - for (int i = 0; i < fix->size_local_rows; i++) combine(one, fix_vector[i], i); + for (int i = 0; i < val.val.f->size_local_rows; i++) combine(one, fix_vector[i], i); else one = fix_vector[flag]; } else { - double **fix_array = fix->array_local; - int jm1 = j - 1; + double **fix_array = val.val.f->array_local; + int aidxm1 = aidx - 1; if (flag < 0) - for (int i = 0; i < fix->size_local_rows; i++) combine(one, fix_array[i][jm1], i); + for (int i = 0; i < val.val.f->size_local_rows; i++) + combine(one, fix_array[i][aidxm1], i); else - one = fix_array[flag][jm1]; + one = fix_array[flag][aidxm1]; } } // evaluate atom-style variable - } else if (which[m] == ArgInfo::VARIABLE) { + } else if (val.which == ArgInfo::VARIABLE) { if (atom->nmax > maxatom) { maxatom = atom->nmax; memory->destroy(varatom); memory->create(varatom, maxatom, "reduce/region:varatom"); } - input->variable->compute_atom(n, igroup, varatom, 1, 0); + input->variable->compute_atom(val.val.v, igroup, varatom, 1, 0); if (flag < 0) { for (int i = 0; i < nlocal; i++) if (mask[i] & groupbit && region->match(x[i][0], x[i][1], x[i][2])) @@ -219,31 +216,29 @@ double ComputeReduceRegion::compute_one(int m, int flag) bigint ComputeReduceRegion::count(int m) { - int n = value2index[m]; + auto &val = values[m]; - if (which[m] == ArgInfo::X || which[m] == ArgInfo::V || which[m] == ArgInfo::F) + if (val.which == ArgInfo::X || val.which == ArgInfo::V || val.which == ArgInfo::F) return group->count(igroup, region); - else if (which[m] == ArgInfo::COMPUTE) { - Compute *compute = modify->compute[n]; - if (flavor[m] == PERATOM) { + else if (val.which == ArgInfo::COMPUTE) { + if (val.flavor == PERATOM) { return group->count(igroup, region); - } else if (flavor[m] == LOCAL) { - bigint ncount = compute->size_local_rows; + } else if (val.flavor == LOCAL) { + bigint ncount = val.val.c->size_local_rows; bigint ncountall; MPI_Allreduce(&ncount, &ncountall, 1, MPI_DOUBLE, MPI_SUM, world); return ncountall; } - } else if (which[m] == ArgInfo::FIX) { - Fix *fix = modify->fix[n]; - if (flavor[m] == PERATOM) { + } else if (val.which == ArgInfo::FIX) { + if (val.flavor == PERATOM) { return group->count(igroup, region); - } else if (flavor[m] == LOCAL) { - bigint ncount = fix->size_local_rows; + } else if (val.flavor == LOCAL) { + bigint ncount = val.val.f->size_local_rows; bigint ncountall; MPI_Allreduce(&ncount, &ncountall, 1, MPI_DOUBLE, MPI_SUM, world); return ncountall; } - } else if (which[m] == ArgInfo::VARIABLE) + } else if (val.which == ArgInfo::VARIABLE) return group->count(igroup, region); bigint dummy = 0;