Merge pull request #4080 from rbberger/compute_reaxff_atom_overflow_fix

Fix buffer overflow in compute reaxff/atom
This commit is contained in:
Axel Kohlmeyer
2024-02-21 10:51:37 -05:00
committed by GitHub
3 changed files with 24 additions and 15 deletions

View File

@ -67,10 +67,10 @@ void ComputeReaxFFAtomKokkos<DeviceType>::init()
template<class DeviceType>
void ComputeReaxFFAtomKokkos<DeviceType>::compute_bonds()
{
if (atom->nlocal > nlocal) {
if (atom->nmax > nmax) {
memory->destroy(array_atom);
nlocal = atom->nlocal;
memory->create(array_atom, nlocal, 3, "reaxff/atom:array_atom");
nmax = atom->nmax;
memory->create(array_atom, nmax, 3, "reaxff/atom:array_atom");
}
// retrieve bond information from kokkos pair style. the data potentially
@ -85,6 +85,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_bonds()
else
host_pair()->FindBond(maxnumbonds, groupbit);
const int nlocal = atom->nlocal;
nbuf = ((store_bonds ? maxnumbonds*2 : 0) + 3)*nlocal;
if (!buf || ((int)k_buf.extent(0) < nbuf)) {
@ -135,6 +136,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_local()
int b = 0;
int j = 0;
auto tag = atom->tag;
const int nlocal = atom->nlocal;
for (int i = 0; i < nlocal; ++i) {
const int numbonds = static_cast<int>(buf[j+2]);
@ -161,6 +163,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_peratom()
compute_bonds();
// extract peratom bond information from buffer
const int nlocal = atom->nlocal;
int j = 0;
for (int i = 0; i < nlocal; ++i) {
@ -180,7 +183,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_peratom()
template<class DeviceType>
double ComputeReaxFFAtomKokkos<DeviceType>::memory_usage()
{
double bytes = (double)(nlocal*3) * sizeof(double);
double bytes = (double)(nmax*3) * sizeof(double);
if (store_bonds)
bytes += (double)(nbonds*3) * sizeof(double);
bytes += (double)(nbuf > 0 ? nbuf * sizeof(double) : 0);

View File

@ -43,7 +43,7 @@ ComputeReaxFFAtom::ComputeReaxFFAtom(LAMMPS *lmp, int narg, char **arg) :
// initialize output
nlocal = -1;
nmax = -1;
nbonds = 0;
prev_nbonds = -1;
@ -162,20 +162,22 @@ void ComputeReaxFFAtom::compute_bonds()
{
invoked_bonds = update->ntimestep;
if (atom->nlocal > nlocal) {
if (atom->nmax > nmax) {
memory->destroy(abo);
memory->destroy(neighid);
memory->destroy(bondcount);
memory->destroy(array_atom);
nlocal = atom->nlocal;
nmax = atom->nmax;
if (store_bonds) {
memory->create(abo, nlocal, MAXREAXBOND, "reaxff/atom:abo");
memory->create(neighid, nlocal, MAXREAXBOND, "reaxff/atom:neighid");
memory->create(abo, nmax, MAXREAXBOND, "reaxff/atom:abo");
memory->create(neighid, nmax, MAXREAXBOND, "reaxff/atom:neighid");
}
memory->create(bondcount, nlocal, "reaxff/atom:bondcount");
memory->create(array_atom, nlocal, 3, "reaxff/atom:array_atom");
memory->create(bondcount, nmax, "reaxff/atom:bondcount");
memory->create(array_atom, nmax, 3, "reaxff/atom:array_atom");
}
const int nlocal = atom->nlocal;
for (int i = 0; i < nlocal; i++) {
bondcount[i] = 0;
for (int j = 0; store_bonds && j < MAXREAXBOND; j++) {
@ -208,6 +210,8 @@ void ComputeReaxFFAtom::compute_local()
int b = 0;
const int nlocal = atom->nlocal;
for (int i = 0; i < nlocal; ++i) {
const int numbonds = bondcount[i];
@ -230,6 +234,8 @@ void ComputeReaxFFAtom::compute_peratom()
compute_bonds();
}
const int nlocal = atom->nlocal;
for (int i = 0; i < nlocal; ++i) {
auto ptr = array_atom[i];
ptr[0] = reaxff->api->workspace->total_bond_order[i];
@ -244,10 +250,10 @@ void ComputeReaxFFAtom::compute_peratom()
double ComputeReaxFFAtom::memory_usage()
{
double bytes = (double)(nlocal*3) * sizeof(double);
bytes += (double)(nlocal) * sizeof(int);
double bytes = (double)(nmax*3) * sizeof(double);
bytes += (double)(nmax) * sizeof(int);
if (store_bonds) {
bytes += (double)(2*nlocal*MAXREAXBOND) * sizeof(double);
bytes += (double)(2*nmax*MAXREAXBOND) * sizeof(double);
bytes += (double)(nbonds*3) * sizeof(double);
}
return bytes;

View File

@ -40,7 +40,7 @@ class ComputeReaxFFAtom : public Compute {
protected:
bigint invoked_bonds; // last timestep on which compute_bonds() was invoked
int nlocal;
int nmax;
int nbonds;
int prev_nbonds;
int nsub;