Files
lammps/src/EXTRA-COMPUTE/compute_slcsa_atom.cpp
2025-03-20 22:45:01 -04:00

421 lines
16 KiB
C++

/* ----------------------------------------------------------------------
LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
https://www.lammps.org/, Sandia National Laboratories
LAMMPS development team: developers@lammps.org
Copyright (2003) Sandia Corporation. Under the terms of Contract
DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
certain rights in this software. This software is distributed under
the GNU General Public License.
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
/* ----------------------------------------------------------------------
Contributing author: Paul Lafourcade (CEA-DAM-DIF, Arpajon, France)
------------------------------------------------------------------------- */
#include "compute_slcsa_atom.h"
#include "arg_info.h"
#include "atom.h"
#include "citeme.h"
#include "comm.h"
#include "error.h"
#include "memory.h"
#include "modify.h"
#include "neigh_list.h"
#include "potential_file_reader.h"
#include "update.h"
#include <cmath>
#include <cstring>
using namespace LAMMPS_NS;
// citation reminder printed once per run when the "citeme" option is active;
// credits the SL-CSA (supervised-learning crystal structure analysis) method
// this compute style implements
static const char cite_compute_slcsa_atom_c[] =
    "compute slcsa/atom command: doi:10.1016/j.commatsci.2023.112534\n\n"
    "@Article{Lafourcade2023,\n"
    " author = {P. Lafourcade and J.-B. Maillet and C. Denoual and E. Duval and A. Allera and A. "
    "M. Goryaeva and M.-C. Marinica},\n"
    " title = {Robust crystal structure identification at extreme conditions using a "
    "density-independent spectral descriptor and supervised learning},\n"
    " journal = {Computational Materials Science},\n"
    " year = 2023,\n"
    " volume = 230,\n"
    " pages = 112534\n"
    "}\n\n";
/* ---------------------------------------------------------------------- */
ComputeSLCSAAtom::ComputeSLCSAAtom(LAMMPS *lmp, int narg, char **arg) :
    Compute(lmp, narg, arg), list(nullptr), lda_scalings(nullptr),
    database_mean_descriptor(nullptr), lr_bias(nullptr), lr_decision(nullptr), icov_list(nullptr),
    mean_projected_descriptors(nullptr), maha_thresholds(nullptr), full_descriptor(nullptr),
    projected_descriptor(nullptr), scores(nullptr), probas(nullptr), prodright(nullptr),
    dmaha(nullptr), classification(nullptr)
{
  // command : compute c1 all slcsa/atom twojmax nclasses <5 data files> c_ID[*]
  // example : compute c1 all slcsa/atom 8 4 database_mean_descriptor.dat lda_scalings.dat
  //           lr_decision.dat lr_bias.dat mahalanobis_data.dat c_b1[*]
  //
  // classification pipeline applied per atom (see compute_peratom()):
  //   1. bs = bs - xbar                          center on database mean descriptor
  //   2. dred = coefs_lda * bs                   LDA dimension reduction
  //   3. scores = decision_lr * dred + lr_bias   logistic regression scores
  //   4. probas = exp(scores)/sum(exp(scores))   softmax
  //   5. cs = argmax(probas)                     predicted class
  //
  // expected data file contents:
  //   database mean descriptor : vector, ncomps rows x 1 col
  //   LDA scalings             : matrix, ncomps rows x (nclasses-1) cols
  //   LR decision              : matrix, nclasses rows x (nclasses-1) cols
  //   LR bias                  : vector, 1 row x nclasses cols
  //   mahalanobis stats        : per class: threshold, mean projected descriptor
  //                              ((nclasses-1) values), inverse covariance matrix
  //                              ((nclasses-1)^2 values)

  if (lmp->citeme) lmp->citeme->add(cite_compute_slcsa_atom_c);
  if (narg != 11) utils::missing_cmd_args(FLERR, "compute slcsa/atom", error);

  int twojmax = utils::inumeric(FLERR, arg[3], false, lmp);
  if (twojmax < 0)
    error->all(FLERR, 3, "Illegal compute slcsa/atom command: twojmax must be >= 0");
  ncomps = compute_ncomps(twojmax);

  nclasses = utils::inumeric(FLERR, arg[4], false, lmp);
  if (nclasses < 2)
    error->all(FLERR, 4, "Illegal compute slcsa/atom command: nclasses must be greater than 1");

  database_mean_descriptor_file = arg[5];
  lda_scalings_file = arg[6];
  lr_decision_file = arg[7];
  lr_bias_file = arg[8];
  maha_file = arg[9];

  if (comm->me == 0) {
    auto mesg = fmt::format(
        "Files used:\n {:24}: {}\n {:24}: {}\n {:24}: {}\n {:24}: {}\n {:24}: {}\n",
        "database mean descriptor", database_mean_descriptor_file, "lda scalings",
        lda_scalings_file, "lr decision", lr_decision_file, "lr bias", lr_bias_file, "maha stats",
        maha_file);
    utils::logmesg(lmp, mesg);
  }

  // expand a possible wildcard (e.g. c_b1[*]) in the descriptor argument
  int expand = 0;
  char **earg;
  int nvalues = utils::expand_args(FLERR, narg - 10, &arg[10], 1, earg, lmp);
  if (earg != &arg[10]) expand = 1;
  arg = earg;

  ArgInfo argi(arg[0]);
  value_t val;
  val.id = "";
  val.val.c = nullptr;
  val.which = argi.get_type();
  val.argindex = argi.get_index1();
  val.id = argi.get_name();
  // only a per-atom compute reference is a valid descriptor source
  if ((val.which == ArgInfo::FIX) || (val.which == ArgInfo::VARIABLE) ||
      (val.which == ArgInfo::UNKNOWN) || (val.which == ArgInfo::NONE) || (argi.get_dim() > 1))
    error->all(FLERR, 10, "Invalid compute slcsa/atom argument: {}", arg[0]);

  // if wildcard expansion occurred, free earg memory from expand_args()
  if (expand) {
    for (int i = 0; i < nvalues; i++) delete[] earg[i];
    memory->sfree(earg);
  }

  val.val.c = modify->get_compute_by_id(val.id);
  // was: "... for fix slcsa/atom ..." -- this style is a compute, not a fix
  if (!val.val.c)
    error->all(FLERR, 10, "Compute ID {} for compute slcsa/atom does not exist", val.id);
  if (val.val.c->peratom_flag == 0)
    error->all(FLERR, 10, "Compute slcsa/atom compute {} does not calculate per-atom values",
               val.id);
  if (val.argindex == 0 && val.val.c->size_peratom_cols != 0)
    error->all(FLERR, 10, "Compute slcsa/atom compute {} does not calculate a per-atom vector",
               val.id);
  if (val.argindex && val.val.c->size_peratom_cols == 0)
    error->all(FLERR, 10, "Compute slcsa/atom compute {} does not calculate a per-atom array",
               val.id);
  if (val.argindex && val.argindex > val.val.c->size_peratom_cols)
    error->all(FLERR, 10, "Compute slcsa/atom compute {} array is accessed out-of-range{}", val.id,
               utils::errorurl(20));
  descriptorval = val;

  // model parameter storage (sizes follow the file layout documented above)
  memory->create(database_mean_descriptor, ncomps, "slcsa/atom:database_mean_descriptor");
  memory->create(lda_scalings, ncomps, nclasses - 1, "slcsa/atom:lda_scalings");
  memory->create(lr_decision, nclasses, nclasses - 1, "slcsa/atom:lr_decision");
  memory->create(lr_bias, nclasses, "slcsa/atom:lr_bias");
  memory->create(maha_thresholds, nclasses, "slcsa/atom:maha_thresholds");
  memory->create(icov_list, nclasses, nclasses - 1, nclasses - 1, "slcsa/atom:icov_list");
  memory->create(mean_projected_descriptors, nclasses, nclasses - 1,
                 "slcsa/atom:mean_projected_descriptors");

  // read all parameter files on MPI rank 0 only, then broadcast below
  if (comm->me == 0) {
    if (strcmp(database_mean_descriptor_file, "NULL") == 0) {
      // was missing the second {} so the system error text was never printed
      error->one(FLERR, Error::NOLASTLINE,
                 "Cannot open database mean descriptor file {}: {}",
                 database_mean_descriptor_file, utils::getsyserror());
    } else {
      PotentialFileReader reader(lmp, database_mean_descriptor_file,
                                 "database mean descriptor file");
      int nread = 0;
      while (nread < ncomps) {
        auto values = reader.next_values(0);
        database_mean_descriptor[nread] = values.next_double();
        nread++;
      }
    }
    if (strcmp(lda_scalings_file, "NULL") == 0) {
      error->one(FLERR, Error::NOLASTLINE, "Cannot open database linear discriminant analysis "
                 "scalings file {}: {}", lda_scalings_file, utils::getsyserror());
    } else {
      PotentialFileReader reader(lmp, lda_scalings_file, "lda scalings file");
      int nread = 0;
      while (nread < ncomps) {
        auto values = reader.next_values(nclasses - 1);
        // was hardcoded to columns 0..2, i.e. assumed nclasses == 4
        for (int j = 0; j < nclasses - 1; j++) lda_scalings[nread][j] = values.next_double();
        nread++;
      }
    }
    if (strcmp(lr_decision_file, "NULL") == 0) {
      error->one(FLERR, Error::NOLASTLINE, "Cannot open logistic regression decision file {}: {}",
                 lr_decision_file, utils::getsyserror());
    } else {
      PotentialFileReader reader(lmp, lr_decision_file, "lr decision file");
      int nread = 0;
      while (nread < nclasses) {
        auto values = reader.next_values(nclasses - 1);
        // was hardcoded to columns 0..2, i.e. assumed nclasses == 4
        for (int j = 0; j < nclasses - 1; j++) lr_decision[nread][j] = values.next_double();
        nread++;
      }
    }
    if (strcmp(lr_bias_file, "NULL") == 0) {
      error->one(FLERR, Error::NOLASTLINE, "Cannot open logistic regression bias file {}: {}",
                 lr_bias_file, utils::getsyserror());
    } else {
      PotentialFileReader reader(lmp, lr_bias_file, "lr bias file");
      auto values = reader.next_values(nclasses);
      // was hardcoded to 4 entries, i.e. assumed nclasses == 4
      for (int j = 0; j < nclasses; j++) lr_bias[j] = values.next_double();
    }
    if (strcmp(maha_file, "NULL") == 0) {
      error->one(FLERR, Error::NOLASTLINE, "Cannot open mahalanobis stats file {}: {}", maha_file,
                 utils::getsyserror());
    } else {
      PotentialFileReader reader(lmp, maha_file, "mahalanobis stats file");
      int nvalues = nclasses * ((nclasses - 1) * (nclasses - 1) + nclasses);
      auto values = reader.next_values(nvalues);
      for (int i = 0; i < nclasses; i++) {
        maha_thresholds[i] = values.next_double();
        for (int j = 0; j < nclasses - 1; j++)
          mean_projected_descriptors[i][j] = values.next_double();
        for (int k = 0; k < nclasses - 1; k++)
          for (int l = 0; l < nclasses - 1; l++) icov_list[i][k][l] = values.next_double();
      }
      // echo the per-class statistics to the log for visual verification
      for (int i = 0; i < nclasses; i++) {
        auto mesg = fmt::format("For class {} maha threshold = {:.6}\n", i, maha_thresholds[i]);
        mesg += " mean B:\n";
        for (int j = 0; j < nclasses - 1; j++)
          mesg += fmt::format(" {:11.6}\n", mean_projected_descriptors[i][j]);
        mesg += " icov:\n";
        for (int j = 0; j < nclasses - 1; j++) {
          // was a hardcoded 3-column format line, i.e. assumed nclasses == 4
          for (int k = 0; k < nclasses - 1; k++)
            mesg += fmt::format(" {:11.6}", icov_list[i][j][k]);
          mesg += "\n";
        }
        utils::logmesg(lmp, mesg);
      }
    }
  }

  // distribute the parameters read on rank 0 to all other ranks
  MPI_Bcast(&database_mean_descriptor[0], ncomps, MPI_DOUBLE, 0, world);
  MPI_Bcast(&lda_scalings[0][0], ncomps * (nclasses - 1), MPI_DOUBLE, 0, world);
  MPI_Bcast(&lr_decision[0][0], nclasses * (nclasses - 1), MPI_DOUBLE, 0, world);
  MPI_Bcast(&lr_bias[0], nclasses, MPI_DOUBLE, 0, world);
  MPI_Bcast(&maha_thresholds[0], nclasses, MPI_DOUBLE, 0, world);
  MPI_Bcast(&mean_projected_descriptors[0][0], nclasses * (nclasses - 1), MPI_DOUBLE, 0, world);
  MPI_Bcast(&icov_list[0][0][0], nclasses * (nclasses - 1) * (nclasses - 1), MPI_DOUBLE, 0, world);

  // per-atom output: nclasses mahalanobis distances + 1 predicted-class column
  peratom_flag = 1;
  size_peratom_cols = nclasses + 1;
  ncols = nclasses + 1;
  nmax = 0;
}
/* ---------------------------------------------------------------------- */
ComputeSLCSAAtom::~ComputeSLCSAAtom()
{
  // per-atom output array
  memory->destroy(classification);

  // model parameters read from the parameter files
  memory->destroy(database_mean_descriptor);
  memory->destroy(lda_scalings);
  memory->destroy(lr_decision);
  memory->destroy(lr_bias);
  memory->destroy(maha_thresholds);
  memory->destroy(mean_projected_descriptors);
  memory->destroy(icov_list);

  // per-invocation scratch buffers (memory->destroy() is a no-op on nullptr)
  memory->destroy(full_descriptor);
  memory->destroy(projected_descriptor);
  memory->destroy(scores);
  memory->destroy(probas);
  memory->destroy(prodright);
  memory->destroy(dmaha);
}
/* ---------------------------------------------------------------------- */
void ComputeSLCSAAtom::init()
{
  // duplicate instances of this style waste work; warn once from rank 0
  const auto instances = modify->get_compute_by_style(style);
  if (instances.size() > 1) {
    if (comm->me == 0) error->warning(FLERR, "More than one compute {}", style);
  }
}
/* ---------------------------------------------------------------------- */
// store the pointer to the neighbor list assigned to this compute
// NOTE(review): `list` is not referenced anywhere in the visible code of this
// file — confirm whether a neighbor list request is actually needed
void ComputeSLCSAAtom::init_list(int /*id*/, NeighList *ptr)
{
  list = ptr;
}
/* ---------------------------------------------------------------------- */
void ComputeSLCSAAtom::compute_peratom()
{
invoked_peratom = update->ntimestep;
// grow per-atom if necessary
if (atom->nmax > nmax) {
memory->destroy(classification);
nmax = atom->nmax;
memory->create(classification, nmax, ncols, "slcsa/atom:classification");
array_atom = classification;
}
int *mask = atom->mask;
int nlocal = atom->nlocal;
if (descriptorval.which == ArgInfo::COMPUTE) {
if (!(descriptorval.val.c->invoked_flag & Compute::INVOKED_PERATOM)) {
descriptorval.val.c->compute_peratom();
descriptorval.val.c->invoked_flag |= Compute::INVOKED_PERATOM;
}
double **compute_array = descriptorval.val.c->array_atom;
memory->create(full_descriptor, ncomps, "slcsa/atom:local descriptor");
memory->create(projected_descriptor, nclasses - 1, "slcsa/atom:reduced descriptor");
memory->create(scores, nclasses, "slcsa/atom:scores");
memory->create(probas, nclasses, "slcsa/atom:probas");
memory->create(prodright, nclasses - 1, "slcsa/atom:prodright");
memory->create(dmaha, nclasses, "slcsa/atom:prodright");
for (int i = 0; i < nlocal; i++) {
if (mask[i] & groupbit) {
for (int j = 0; j < ncomps; j++) full_descriptor[j] = compute_array[i][j];
// Here comes the LDA + LR process
// 1st step : Retrieve mean database descriptor
for (int j = 0; j < ncomps; j++) full_descriptor[j] -= database_mean_descriptor[j];
// 2nd step : Matrix multiplication to go from ncompsx1 -> (nclasses-1)*1
for (int j = 0; j < nclasses - 1; j++) {
projected_descriptor[j] = 0.0;
for (int k = 0; k < ncomps; k++) {
projected_descriptor[j] += full_descriptor[k] * lda_scalings[k][j];
}
}
// 3rd step : Matrix multiplication
for (int j = 0; j < nclasses; j++) {
scores[j] = lr_bias[j];
for (int k = 0; k < nclasses - 1; k++) {
scores[j] += lr_decision[j][k] * projected_descriptor[k];
}
}
// 4th step : Matrix multiplication
double sumexpscores = 0.0;
for (int j = 0; j < nclasses; j++) sumexpscores += exp(scores[j]);
for (int j = 0; j < nclasses; j++) probas[j] = exp(scores[j]) / sumexpscores;
classification[i][nclasses] = argmax(probas, nclasses);
// 5th step : Mahalanobis distance
for (int j = 0; j < nclasses; j++) {
prodright[0] = 0.0;
prodright[1] = 0.0;
prodright[2] = 0.0;
for (int k = 0; k < nclasses - 1; k++) {
for (int l = 0; l < nclasses - 1; l++) {
prodright[k] += (icov_list[j][k][l] *
(projected_descriptor[k] - mean_projected_descriptors[j][k]));
}
}
double prodleft = 0.0;
for (int k = 0; k < nclasses - 1; k++) {
prodleft +=
(prodright[k] * (projected_descriptor[k] - mean_projected_descriptors[j][k]));
}
classification[i][j] = sqrt(prodleft);
}
// 6th step : Sanity check
int locclass = classification[i][nclasses];
if (classification[i][locclass] > maha_thresholds[locclass]) {
classification[i][nclasses] = -1.0;
}
} else {
for (int j = 0; j < ncols; j++) classification[i][j] = -1.0;
}
}
memory->destroy(full_descriptor);
memory->destroy(projected_descriptor);
memory->destroy(scores);
memory->destroy(probas);
memory->destroy(prodright);
memory->destroy(dmaha);
}
}
/* ----------------------------------------------------------------------
   number of bispectrum components for a given 2*jmax: count the triples
   (j1, j2, j) with 0 <= j2 <= j1 <= twojmax, |j1-j2| <= j <= min(twojmax,
   j1+j2) in steps of 2, restricted to j >= j1
------------------------------------------------------------------------- */

int ComputeSLCSAAtom::compute_ncomps(int twojmax)
{
  int total = 0;
  for (int j1 = 0; j1 <= twojmax; ++j1) {
    for (int j2 = 0; j2 <= j1; ++j2) {
      const int jmax = MIN(twojmax, j1 + j2);
      for (int j = j1 - j2; j <= jmax; j += 2) {
        if (j >= j1) ++total;
      }
    }
  }
  return total;
}
/* ----------------------------------------------------------------------
   index of the largest element of arr[0..size-1]; ties go to the
   earliest index (strict > comparison)
------------------------------------------------------------------------- */

int ComputeSLCSAAtom::argmax(double arr[], int size)
{
  int best = 0;
  for (int i = 1; i < size; ++i) {
    if (arr[i] > arr[best]) best = i;
  }
  return best;
}