implement numpy wrapper for setting per-atom energy. virial not yet implemented

This commit is contained in:
Axel Kohlmeyer
2021-07-22 16:59:04 -04:00
parent 324ae3181b
commit bf8bde5b03
4 changed files with 60 additions and 10 deletions

View File

@ -1830,7 +1830,7 @@ class lammps(object):
# -------------------------------------------------------------------------
def fix_external_set_energy_peratom(self, fix_id, eatom):
"""Set the global energy contribution for a fix external instance with the given ID.
"""Set the per-atom energy contribution for a fix external instance with the given ID.
This is a wrapper around the :cpp:func:`lammps_fix_external_set_energy_peratom` function
of the C-library interface.
@ -1848,15 +1848,15 @@ class lammps(object):
# -------------------------------------------------------------------------
def fix_external_set_virial_peratom(self, fix_id, virial):
"""Set the global virial contribution for a fix external instance with the given ID.
def fix_external_set_virial_peratom(self, fix_id, vatom):
"""Set the per-atom virial contribution for a fix external instance with the given ID.
This is a wrapper around the :cpp:func:`lammps_fix_external_set_virial_peratom` function
of the C-library interface.
:param fix_id: Fix-ID of a fix external instance
:type: string
:param eng: list of natoms lists with 6 floating point numbers to be added by fix external
:param vatom: list of natoms lists with 6 floating point numbers to be added by fix external
:type: float
"""
@ -1864,14 +1864,14 @@ class lammps(object):
nlocal = self.extract_setting('nlocal')
vbuf = (c_double * 6)
vptr = POINTER(c_double)
cvirial = (vptr * nlocal)()
c_virial = (vptr * nlocal)()
for i in range(nlocal):
cvirial[i] = vbuf()
c_virial[i] = vbuf()
for j in range(6):
cvirial[i][j] = virial[i][j]
c_virial[i][j] = vatom[i][j]
with ExceptionCheck(self):
return self.lib.lammps_fix_external_set_virial_peratom(self.lmp, fix_id.encode(), cvirial)
return self.lib.lammps_fix_external_set_virial_peratom(self.lmp, fix_id.encode(), c_virial)
# -------------------------------------------------------------------------
def fix_external_set_vector_length(self, fix_id, length):

View File

@ -268,6 +268,49 @@ class numpy_wrapper:
# -------------------------------------------------------------------------
def fix_external_set_energy_peratom(self, fix_id, eatom):
"""Set the per-atom energy contribution for a fix external instance with the given ID.
This function is an alternative to
:py:meth:`lammps.fix_external_set_energy_peratom() <lammps.lammps.fix_external_set_energy_peratom()>`
method. It behaves the same as the original method, but accepts a NumPy array
instead of a list as argument.
:param fix_id: Fix-ID of a fix external instance
:type: string
:param eatom: per-atom potential energy
:type: numpy.array
"""
import numpy as np
nlocal = self.lmp.extract_setting('nlocal')
c_double_p = POINTER(c_double)
value = eatom.astype(np.double)
return self.lmp.lib.lammps_fix_external_set_energy_peratom(self.lmp.lmp, fix_id.encode(),
value.ctypes.data_as(c_double_p))
# -------------------------------------------------------------------------
def fix_external_set_virial_peratom(self, fix_id, vatom):
"""Set the per-atom virial contribution for a fix external instance with the given ID.
This function is an alternative to
:py:meth:`lammps.fix_external_set_virial_peratom() <lammps.lammps.fix_external_set_virial_peratom()>`
method. It behaves the same as the original method, but accepts a NumPy array
instead of a list as argument.
.. note::
This function is not yet implemented.
:param fix_id: Fix-ID of a fix external instance
:type: string
:param eatom: per-atom potential energy
:type: numpy.array
"""
raise Exception('fix_external_set_virial_peratom() not yet implemented for NumPy arrays')
# -------------------------------------------------------------------------
def get_neighlist(self, idx):
"""Returns an instance of :class:`NumPyNeighList` which wraps access to the neighbor list with the given index

View File

@ -32,6 +32,13 @@ def callback_one(lmp, ntimestep, nlocal, tag, x, f):
lmp.fix_external_set_energy_peratom("ext",eatom)
lmp.fix_external_set_virial_peratom("ext",vatom)
#import numpy as np
#eng = np.array(eatom)
#vir = np.array(vatom)
#lmp.numpy.fix_external_set_energy_peratom("ext",eng)
#lmp.numpy.fix_external_set_virial_peratom("ext",vir) # not yet implemented
class PythonExternal(unittest.TestCase):
def testExternalCallback(self):
"""Test fix external from Python with pf/callback"""