have compute_reduce require either peratom or local inputs

This commit is contained in:
Steve Plimpton
2023-08-17 16:12:14 -06:00
parent 0d739439c7
commit 299eda8ca3
5 changed files with 149 additions and 106 deletions

View File

@ -31,12 +31,16 @@
using namespace LAMMPS_NS;
enum{UNDECIDED,PERATOM,LOCAL}; // same as in ComputeReduceRegion
#define BIG 1.0e20
//----------------------------------------------------------------
void abs_max(void *in, void *inout, int * /*len*/, MPI_Datatype * /*type*/)
{
// r is the already reduced value, n is the new value
double n = std::fabs(*(double *) in), r = *(double *) inout;
double m;
@ -47,9 +51,11 @@ void abs_max(void *in, void *inout, int * /*len*/, MPI_Datatype * /*type*/)
}
*(double *) inout = m;
}
void abs_min(void *in, void *inout, int * /*len*/, MPI_Datatype * /*type*/)
{
// r is the already reduced value, n is the new value
double n = std::fabs(*(double *) in), r = *(double *) inout;
double m;
@ -68,6 +74,7 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) :
owner(nullptr), idregion(nullptr), region(nullptr), varatom(nullptr)
{
int iarg = 0;
if (strcmp(style, "reduce") == 0) {
if (narg < 5) utils::missing_cmd_args(FLERR, "compute reduce", error);
iarg = 3;
@ -128,42 +135,52 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) :
// parse values
input_mode = UNDECIDED;
values.clear();
nvalues = 0;
for (int iarg = 0; iarg < nargnew; ++iarg) {
value_t val;
val.id = "";
val.flavor = 0;
val.val.c = nullptr;
if (strcmp(arg[iarg], "x") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::X;
val.argindex = 0;
} else if (strcmp(arg[iarg], "y") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::X;
val.argindex = 1;
} else if (strcmp(arg[iarg], "z") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::X;
val.argindex = 2;
} else if (strcmp(arg[iarg], "vx") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::V;
val.argindex = 0;
} else if (strcmp(arg[iarg], "vy") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::V;
val.argindex = 1;
} else if (strcmp(arg[iarg], "vz") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::V;
val.argindex = 2;
} else if (strcmp(arg[iarg], "fx") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::F;
val.argindex = 0;
} else if (strcmp(arg[iarg], "fy") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::F;
val.argindex = 1;
} else if (strcmp(arg[iarg], "fz") == 0) {
input_mode = PERATOM;
val.which = ArgInfo::F;
val.argindex = 2;
@ -207,6 +224,14 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) :
error->all(FLERR, "Compute {} replace column already used for another replacement");
replace[col1] = col2;
iarg += 2;
} else if (strcmp(arg[iarg], "inputs") == 0) {
if (iarg + 2 > narg) utils::missing_cmd_args(FLERR, mycmd + " inputs", error);
if (strcmp(arg[iarg+1], "peratom") == 0) input_mode = PERATOM;
else if (strcmp(arg[iarg+1], "local") == 0) {
if (input_mode == PERATOM)
error->all(FLERR,"Compute {} inputs must be all peratom or all local");
input_mode = LOCAL;
}
} else
error->all(FLERR, "Unknown compute {} keyword: {}", style, arg[iarg]);
}
@ -231,66 +256,64 @@ ComputeReduce::ComputeReduce(LAMMPS *lmp, int narg, char **arg) :
// setup and error check
for (auto &val : values) {
if (val.which == ArgInfo::X || val.which == ArgInfo::V || val.which == ArgInfo::F)
val.flavor = PERATOM;
else if (val.which == ArgInfo::COMPUTE) {
if (val.which == ArgInfo::COMPUTE) {
val.val.c = modify->get_compute_by_id(val.id);
if (!val.val.c)
error->all(FLERR, "Compute ID {} for compute {} does not exist", val.id, style);
if (val.val.c->peratom_flag) {
val.flavor = PERATOM;
if (input_mode == PERATOM) {
if (!val.val.c->peratom_flag)
error->all(FLERR, "Compute {} compute {} does not calculate per-atom values", style, val.id);
if (val.argindex == 0 && val.val.c->size_peratom_cols != 0)
error->all(FLERR, "Compute {} compute {} does not calculate a per-atom vector", style,
val.id);
error->all(FLERR, "Compute {} compute {} does not calculate a per-atom vector", style, val.id);
if (val.argindex && val.val.c->size_peratom_cols == 0)
error->all(FLERR, "Compute {} compute {} does not calculate a per-atom array", style,
val.id);
error->all(FLERR, "Compute {} compute {} does not calculate a per-atom array", style, val.id);
if (val.argindex && val.argindex > val.val.c->size_peratom_cols)
error->all(FLERR, "Compute {} compute {} array is accessed out-of-range", style, val.id);
} else if (val.val.c->local_flag) {
val.flavor = LOCAL;
} else if (input_mode == LOCAL) {
if (!val.val.c->peratom_flag)
error->all(FLERR, "Compute {} compute {} does not calculate local values", style, val.id);
if (val.argindex == 0 && val.val.c->size_local_cols != 0)
error->all(FLERR, "Compute {} compute {} does not calculate a local vector", style,
val.id);
error->all(FLERR, "Compute {} compute {} does not calculate a local vector", style, val.id);
if (val.argindex && val.val.c->size_local_cols == 0)
error->all(FLERR, "Compute {} compute {} does not calculate a local array", style,
val.id);
error->all(FLERR, "Compute {} compute {} does not calculate a local array", style, val.id);
if (val.argindex && val.argindex > val.val.c->size_local_cols)
error->all(FLERR, "Compute {} compute {} array is accessed out-of-range", style, val.id);
} else
error->all(FLERR, "Compute {} compute {} calculates global values", style, val.id);
}
} else if (val.which == ArgInfo::FIX) {
val.val.f = modify->get_fix_by_id(val.id);
if (!val.val.f) error->all(FLERR, "Fix ID {} for compute {} does not exist", val.id, style);
if (val.val.f->peratom_flag) {
val.flavor = PERATOM;
if (input_mode == PERATOM) {
if (!val.val.f->peratom_flag)
error->all(FLERR, "Compute {} fix {} does not calculate per-atom values", style, val.id);
if (val.argindex == 0 && (val.val.f->size_peratom_cols != 0))
error->all(FLERR, "Compute {} fix {} does not calculate a per-atom vector", style,
val.id);
error->all(FLERR, "Compute {} fix {} does not calculate a per-atom vector", style, val.id);
if (val.argindex && (val.val.f->size_peratom_cols == 0))
error->all(FLERR, "Compute {} fix {} does not calculate a per-atom array", style, val.id);
if (val.argindex && (val.argindex > val.val.f->size_peratom_cols))
error->all(FLERR, "Compute {} fix {} array is accessed out-of-range", style, val.id);
} else if (val.val.f->local_flag) {
val.flavor = LOCAL;
} else if (input_mode == LOCAL) {
if (!val.val.f->local_flag)
error->all(FLERR, "Compute {} fix {} does not calculate local values", style, val.id);
if (val.argindex == 0 && (val.val.f->size_local_cols != 0))
error->all(FLERR, "Compute {} fix {} does not calculate a local vector", style, val.id);
if (val.argindex && (val.val.f->size_local_cols == 0))
error->all(FLERR, "Compute {} fix {} does not calculate a local array", style, val.id);
if (val.argindex && (val.argindex > val.val.f->size_local_cols))
error->all(FLERR, "Compute {} fix {} array is accessed out-of-range", style, val.id);
} else
error->all(FLERR, "Compute {} fix {} calculates global values", style, val.id);
}
} else if (val.which == ArgInfo::VARIABLE) {
if (input_mode == LOCAL) error->all(FLERR,"Compute {} inputs must be all local");
val.val.v = input->variable->find(val.id.c_str());
if (val.val.v < 0)
error->all(FLERR, "Variable name {} for compute {} does not exist", val.id, style);
if (input->variable->atomstyle(val.val.v) == 0)
error->all(FLERR, "Compute {} variable {} is not atom-style variable", style, val.id);
val.flavor = PERATOM;
}
}
@ -512,7 +535,7 @@ double ComputeReduce::compute_one(int m, int flag)
} else if (val.which == ArgInfo::COMPUTE) {
if (val.flavor == PERATOM) {
if (input_mode == PERATOM) {
if (!(val.val.c->invoked_flag & Compute::INVOKED_PERATOM)) {
val.val.c->compute_peratom();
val.val.c->invoked_flag |= Compute::INVOKED_PERATOM;
@ -537,7 +560,7 @@ double ComputeReduce::compute_one(int m, int flag)
one = carray_atom[flag][aidxm1];
}
} else if (val.flavor == LOCAL) {
} else if (input_mode == LOCAL) {
if (!(val.val.c->invoked_flag & Compute::INVOKED_LOCAL)) {
val.val.c->compute_local();
val.val.c->invoked_flag |= Compute::INVOKED_LOCAL;
@ -567,7 +590,7 @@ double ComputeReduce::compute_one(int m, int flag)
if (update->ntimestep % val.val.f->peratom_freq)
error->all(FLERR, "Fix {} used in compute {} not computed at compatible time", val.id, style);
if (val.flavor == PERATOM) {
if (input_mode == PERATOM) {
if (aidx == 0) {
double *fix_vector = val.val.f->vector_atom;
if (flag < 0) {
@ -585,7 +608,7 @@ double ComputeReduce::compute_one(int m, int flag)
one = fix_array[flag][aidxm1];
}
} else if (val.flavor == LOCAL) {
} else if (input_mode == LOCAL) {
if (aidx == 0) {
double *fix_vector = val.val.f->vector_local;
int n = val.val.f->size_local_rows;
@ -632,18 +655,18 @@ bigint ComputeReduce::count(int m)
if ((val.which == ArgInfo::X) || (val.which == ArgInfo::V) || (val.which == ArgInfo::F))
return group->count(igroup);
else if (val.which == ArgInfo::COMPUTE) {
if (val.flavor == PERATOM) {
if (input_mode == PERATOM) {
return group->count(igroup);
} else if (val.flavor == LOCAL) {
} else if (input_mode == LOCAL) {
bigint ncount = val.val.c->size_local_rows;
bigint ncountall;
MPI_Allreduce(&ncount, &ncountall, 1, MPI_LMP_BIGINT, MPI_SUM, world);
return ncountall;
}
} else if (val.which == ArgInfo::FIX) {
if (val.flavor == PERATOM) {
if (input_mode == PERATOM) {
return group->count(igroup);
} else if (val.flavor == LOCAL) {
} else if (input_mode == LOCAL) {
bigint ncount = val.val.f->size_local_rows;
bigint ncountall;
MPI_Allreduce(&ncount, &ncountall, 1, MPI_LMP_BIGINT, MPI_SUM, world);