check energy and virial per atom arrays for correct size

This commit is contained in:
Axel Kohlmeyer
2021-07-22 22:49:16 -04:00
parent bf8bde5b03
commit a078d1ba53
2 changed files with 16 additions and 1 deletions

View File

@ -1842,6 +1842,8 @@ class lammps(object):
""" """
nlocal = self.extract_setting('nlocal') nlocal = self.extract_setting('nlocal')
if len(eatom) < nlocal:
raise Exception('per-atom energy list length must be at least nlocal')
ceatom = (nlocal*c_double)(*eatom) ceatom = (nlocal*c_double)(*eatom)
with ExceptionCheck(self): with ExceptionCheck(self):
return self.lib.lammps_fix_external_set_energy_peratom(self.lmp, fix_id.encode(), ceatom) return self.lib.lammps_fix_external_set_energy_peratom(self.lmp, fix_id.encode(), ceatom)
@ -1862,6 +1864,10 @@ class lammps(object):
# copy virial data to C compatible buffer # copy virial data to C compatible buffer
nlocal = self.extract_setting('nlocal') nlocal = self.extract_setting('nlocal')
if len(vatom) < nlocal:
raise Exception('per-atom virial first dimension must be at least nlocal')
if len(vatom[0]) != 6:
raise Exception('per-atom virial second dimension must be 6')
vbuf = (c_double * 6) vbuf = (c_double * 6)
vptr = POINTER(c_double) vptr = POINTER(c_double)
c_virial = (vptr * nlocal)() c_virial = (vptr * nlocal)()

View File

@ -283,6 +283,9 @@ class numpy_wrapper:
""" """
import numpy as np import numpy as np
nlocal = self.lmp.extract_setting('nlocal') nlocal = self.lmp.extract_setting('nlocal')
if len(eatom) < nlocal:
raise Exception('per-atom energy dimension must be at least nlocal')
c_double_p = POINTER(c_double) c_double_p = POINTER(c_double)
value = eatom.astype(np.double) value = eatom.astype(np.double)
return self.lmp.lib.lammps_fix_external_set_energy_peratom(self.lmp.lmp, fix_id.encode(), return self.lmp.lib.lammps_fix_external_set_energy_peratom(self.lmp.lmp, fix_id.encode(),
@ -307,7 +310,13 @@ class numpy_wrapper:
:param eatom: per-atom potential energy :param eatom: per-atom potential energy
:type: numpy.array :type: numpy.array
""" """
raise Exception('fix_external_set_virial_peratom() not yet implemented for NumPy arrays') import numpy as np
nlocal = self.lmp.extract_setting('nlocal')
if len(vatom) < nlocal:
raise Exception('per-atom virial first dimension must be at least nlocal')
if len(vatom[0]) != 6:
raise Exception('per-atom virial second dimension must be 6')
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------