python interface for per-atom data for fix external

This commit is contained in:
Axel Kohlmeyer
2021-07-22 15:27:51 -04:00
parent bb46dd7d1f
commit 324ae3181b
3 changed files with 109 additions and 4 deletions

View File

@ -1749,8 +1749,8 @@ class lammps(object):
- ntimestep is the current timestep - ntimestep is the current timestep
- nlocal is the number of local atoms on the current MPI process - nlocal is the number of local atoms on the current MPI process
- tag is a 1d NumPy array of integers representing the atom IDs of the local atoms - tag is a 1d NumPy array of integers representing the atom IDs of the local atoms
- x is a 2d NumPy array of floating point numbers of the coordinates of the local atoms - x is a 2d NumPy array of doubles of the coordinates of the local atoms
- f is a 2d NumPy array of floating point numbers of the forces on the local atoms that will be added - f is a 2d NumPy array of doubles of the forces on the local atoms that will be added
:param fix_id: Fix-ID of a fix external instance :param fix_id: Fix-ID of a fix external instance
:type: string :type: string
@ -1777,7 +1777,7 @@ class lammps(object):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def fix_external_get_force(self, fix_id): def fix_external_get_force(self, fix_id):
"""Get access to that array with per-atom forces of a fix external instance with a given fix ID. """Get access to the array with per-atom forces of a fix external instance with a given fix ID.
This is a wrapper around the :cpp:func:`lammps_fix_external_get_force` function This is a wrapper around the :cpp:func:`lammps_fix_external_get_force` function
of the C-library interface. of the C-library interface.
@ -1827,6 +1827,52 @@ class lammps(object):
with ExceptionCheck(self): with ExceptionCheck(self):
return self.lib.lammps_fix_external_set_virial_global(self.lmp, fix_id.encode(), cvirial) return self.lib.lammps_fix_external_set_virial_global(self.lmp, fix_id.encode(), cvirial)
# -------------------------------------------------------------------------
def fix_external_set_energy_peratom(self, fix_id, eatom):
"""Set the global energy contribution for a fix external instance with the given ID.
This is a wrapper around the :cpp:func:`lammps_fix_external_set_energy_peratom` function
of the C-library interface.
:param fix_id: Fix-ID of a fix external instance
:type: string
:param eatom: list of potential energy values for local atoms to be added by fix external
:type: float
"""
nlocal = self.extract_setting('nlocal')
ceatom = (nlocal*c_double)(*eatom)
with ExceptionCheck(self):
return self.lib.lammps_fix_external_set_energy_peratom(self.lmp, fix_id.encode(), ceatom)
# -------------------------------------------------------------------------
def fix_external_set_virial_peratom(self, fix_id, virial):
"""Set the global virial contribution for a fix external instance with the given ID.
This is a wrapper around the :cpp:func:`lammps_fix_external_set_virial_peratom` function
of the C-library interface.
:param fix_id: Fix-ID of a fix external instance
:type: string
:param eng: list of natoms lists with 6 floating point numbers to be added by fix external
:type: float
"""
# copy virial data to C compatible buffer
nlocal = self.extract_setting('nlocal')
vbuf = (c_double * 6)
vptr = POINTER(c_double)
cvirial = (vptr * nlocal)()
for i in range(nlocal):
cvirial[i] = vbuf()
for j in range(6):
cvirial[i][j] = virial[i][j]
with ExceptionCheck(self):
return self.lib.lammps_fix_external_set_virial_peratom(self.lmp, fix_id.encode(), cvirial)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def fix_external_set_vector_length(self, fix_id, length): def fix_external_set_vector_length(self, fix_id, length):
"""Set the vector length for a global vector stored with fix external for analysis """Set the vector length for a global vector stored with fix external for analysis

View File

@ -248,6 +248,26 @@ class numpy_wrapper:
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def fix_external_get_force(self, fix_id):
"""Get access to the array with per-atom forces of a fix external instance with a given fix ID.
This function is a wrapper around the
:py:meth:`lammps.fix_external_get_force() <lammps.lammps.fix_external_get_force()>`
method. It behaves the same as the original method, but returns a NumPy array instead
of a ``ctypes`` pointer.
:param fix_id: Fix-ID of a fix external instance
:type: string
:return: requested data
:rtype: numpy.array
"""
import numpy as np
nlocal = self.lmp.extract_setting('nlocal')
value = self.lmp.fix_external_get_force(fix_id)
return self.darray(value,nlocal,3)
# -------------------------------------------------------------------------
def get_neighlist(self, idx): def get_neighlist(self, idx):
"""Returns an instance of :class:`NumPyNeighList` which wraps access to the neighbor list with the given index """Returns an instance of :class:`NumPyNeighList` which wraps access to the neighbor list with the given index

View File

@ -20,6 +20,18 @@ def callback_one(lmp, ntimestep, nlocal, tag, x, f):
lmp.fix_external_set_vector("ext", 5, -1.0) lmp.fix_external_set_vector("ext", 5, -1.0)
lmp.fix_external_set_vector("ext", 6, 0.25) lmp.fix_external_set_vector("ext", 6, 0.25)
eatom = [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7]
vatom = [ [0.1,0.0,0.0,0.0,0.0,0.0],
[0.0,0.2,0.0,0.0,0.0,0.0],
[0.0,0.0,0.3,0.0,0.0,0.0],
[0.0,0.0,0.0,0.4,0.0,0.0],
[0.0,0.0,0.0,0.0,0.5,0.0],
[0.0,0.0,0.0,0.0,0.0,0.6],
[0.0,0.0,0.0,0.0,-7.0,0.0],
[0.0,-8.0,0.0,0.0,0.0,0.0] ]
lmp.fix_external_set_energy_peratom("ext",eatom)
lmp.fix_external_set_virial_peratom("ext",vatom)
class PythonExternal(unittest.TestCase): class PythonExternal(unittest.TestCase):
def testExternalCallback(self): def testExternalCallback(self):
"""Test fix external from Python with pf/callback""" """Test fix external from Python with pf/callback"""
@ -42,6 +54,10 @@ class PythonExternal(unittest.TestCase):
thermo 5 thermo 5
fix 1 all nve fix 1 all nve
fix ext all external pf/callback 5 1 fix ext all external pf/callback 5 1
compute eatm all pe/atom fix
compute vatm all stress/atom NULL fix
compute sum all reduce sum c_eatm c_vatm[*]
thermo_style custom step temp pe ke etotal press c_sum[*]
fix_modify ext energy yes virial yes fix_modify ext energy yes virial yes
""" """
lmp.commands_string(basic_system) lmp.commands_string(basic_system)
@ -51,6 +67,14 @@ class PythonExternal(unittest.TestCase):
self.assertAlmostEqual(lmp.get_thermo("temp"),1.0/30.0,14) self.assertAlmostEqual(lmp.get_thermo("temp"),1.0/30.0,14)
self.assertAlmostEqual(lmp.get_thermo("pe"),1.0/8.0,14) self.assertAlmostEqual(lmp.get_thermo("pe"),1.0/8.0,14)
self.assertAlmostEqual(lmp.get_thermo("press"),0.15416666666666667,14) self.assertAlmostEqual(lmp.get_thermo("press"),0.15416666666666667,14)
reduce = lmp.extract_compute("sum", LMP_STYLE_GLOBAL, LMP_TYPE_VECTOR)
self.assertAlmostEqual(reduce[0],2.8,14)
self.assertAlmostEqual(reduce[1],-0.1,14)
self.assertAlmostEqual(reduce[2],7.8,14)
self.assertAlmostEqual(reduce[3],-0.3,14)
self.assertAlmostEqual(reduce[4],-0.4,14)
self.assertAlmostEqual(reduce[5],6.5,14)
self.assertAlmostEqual(reduce[6],-0.6,14)
val = 0.0 val = 0.0
for i in range(0,6): for i in range(0,6):
val += lmp.extract_fix("ext",LMP_STYLE_GLOBAL,LMP_TYPE_VECTOR,nrow=i) val += lmp.extract_fix("ext",LMP_STYLE_GLOBAL,LMP_TYPE_VECTOR,nrow=i)
@ -59,6 +83,12 @@ class PythonExternal(unittest.TestCase):
def testExternalArray(self): def testExternalArray(self):
"""Test fix external from Python with pf/array""" """Test fix external from Python with pf/array"""
try:
import numpy
NUMPY_INSTALLED = True
except ImportError:
NUMPY_INSTALLED = False
machine=None machine=None
if 'LAMMPS_MACHINE_NAME' in os.environ: if 'LAMMPS_MACHINE_NAME' in os.environ:
machine=os.environ['LAMMPS_MACHINE_NAME'] machine=os.environ['LAMMPS_MACHINE_NAME']
@ -93,6 +123,11 @@ class PythonExternal(unittest.TestCase):
self.assertAlmostEqual(lmp.get_thermo("temp"),4.0/525.0,14) self.assertAlmostEqual(lmp.get_thermo("temp"),4.0/525.0,14)
self.assertAlmostEqual(lmp.get_thermo("pe"),1.0/16.0,14) self.assertAlmostEqual(lmp.get_thermo("pe"),1.0/16.0,14)
self.assertAlmostEqual(lmp.get_thermo("press"),0.06916666666666667,14) self.assertAlmostEqual(lmp.get_thermo("press"),0.06916666666666667,14)
if NUMPY_INSTALLED:
npforce = lmp.numpy.fix_external_get_force("ext")
self.assertEqual(len(npforce),8)
self.assertEqual(len(npforce[0]),3)
self.assertEqual(npforce[1][1],0.0)
force = lmp.fix_external_get_force("ext"); force = lmp.fix_external_get_force("ext");
nlocal = lmp.extract_setting("nlocal"); nlocal = lmp.extract_setting("nlocal");
@ -106,8 +141,12 @@ class PythonExternal(unittest.TestCase):
self.assertAlmostEqual(lmp.get_thermo("temp"),1.0/30.0,14) self.assertAlmostEqual(lmp.get_thermo("temp"),1.0/30.0,14)
self.assertAlmostEqual(lmp.get_thermo("pe"),1.0/8.0,14) self.assertAlmostEqual(lmp.get_thermo("pe"),1.0/8.0,14)
self.assertAlmostEqual(lmp.get_thermo("press"),0.15416666666666667,14) self.assertAlmostEqual(lmp.get_thermo("press"),0.15416666666666667,14)
if NUMPY_INSTALLED:
npforce = lmp.numpy.fix_external_get_force("ext")
self.assertEqual(npforce[0][0],6.0)
self.assertEqual(npforce[3][1],6.0)
self.assertEqual(npforce[7][2],6.0)
############################## ##############################
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()