implement setting per-atom virial from numpy array (thanks to stackoverflow)

This commit is contained in:
Axel Kohlmeyer
2021-07-22 22:50:05 -04:00
parent a078d1ba53
commit c8cc5ecb9f
2 changed files with 41 additions and 18 deletions

View File

@ -17,7 +17,7 @@
################################################################################
import warnings
from ctypes import POINTER, c_double, c_int, c_int32, c_int64, cast
from ctypes import POINTER, c_void_p, c_char_p, c_double, c_int, c_int32, c_int64, cast
from .constants import * # lgtm [py/polluting-import]
@ -301,10 +301,6 @@ class numpy_wrapper:
method. It behaves the same as the original method, but accepts a NumPy array
instead of a list as argument.
.. note::
This function is not yet implemented.
:param fix_id: Fix-ID of a fix external instance
:type: string
:param eatom: per-atom potential energy
@ -317,6 +313,16 @@ class numpy_wrapper:
if len(vatom[0]) != 6:
raise Exception('per-atom virial second dimension must be 6')
c_double_pp = np.ctypeslib.ndpointer(dtype=np.uintp, ndim=1, flags='C')
# recast numpy array to be compatible with library interface
value = (vatom.__array_interface__['data'][0]
+ np.arange(vatom.shape[0])*vatom.strides[0]).astype(np.uintp)
# change prototype to our custom type
self.lmp.lib.lammps_fix_external_set_virial_peratom.argtypes = [ c_void_p, c_char_p, c_double_pp ]
self.lmp.lib.lammps_fix_external_set_virial_peratom(self.lmp.lmp, fix_id.encode(), value)
# -------------------------------------------------------------------------

View File

@ -2,6 +2,12 @@ import sys,os,unittest
from ctypes import *
from lammps import lammps, LMP_STYLE_GLOBAL, LMP_TYPE_VECTOR
try:
import numpy
NUMPY_INSTALLED = True
except ImportError:
NUMPY_INSTALLED = False
# add timestep dependent force
def callback_one(lmp, ntimestep, nlocal, tag, x, f):
lmp.fix_external_set_virial_global("ext",[1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
@ -29,17 +35,21 @@ def callback_one(lmp, ntimestep, nlocal, tag, x, f):
[0.0,0.0,0.0,0.0,0.0,0.6],
[0.0,0.0,0.0,0.0,-7.0,0.0],
[0.0,-8.0,0.0,0.0,0.0,0.0] ]
lmp.fix_external_set_energy_peratom("ext",eatom)
lmp.fix_external_set_virial_peratom("ext",vatom)
if ntimestep < 5:
lmp.fix_external_set_energy_peratom("ext",eatom)
lmp.fix_external_set_virial_peratom("ext",vatom)
else:
import numpy as np
eng = np.array(eatom)
vir = np.array(vatom)
#import numpy as np
#eng = np.array(eatom)
#vir = np.array(vatom)
lmp.numpy.fix_external_set_energy_peratom("ext",eng)
lmp.numpy.fix_external_set_virial_peratom("ext",vir)
#lmp.numpy.fix_external_set_energy_peratom("ext",eng)
#lmp.numpy.fix_external_set_virial_peratom("ext",vir) # not yet implemented
# ------------------------------------------------------------------------
class PythonExternal(unittest.TestCase):
@unittest.skipIf(not NUMPY_INSTALLED, "NumPy is not available")
def testExternalCallback(self):
"""Test fix external from Python with pf/callback"""
@ -70,10 +80,23 @@ class PythonExternal(unittest.TestCase):
lmp.commands_string(basic_system)
lmp.fix_external_set_vector_length("ext",6);
lmp.set_fix_external_callback("ext",callback_one,lmp)
# check setting per-atom data with python lists
lmp.command("run 0 post no")
reduce = lmp.extract_compute("sum", LMP_STYLE_GLOBAL, LMP_TYPE_VECTOR)
self.assertAlmostEqual(reduce[0],2.8,14)
self.assertAlmostEqual(reduce[1],-0.1,14)
self.assertAlmostEqual(reduce[2],7.8,14)
self.assertAlmostEqual(reduce[3],-0.3,14)
self.assertAlmostEqual(reduce[4],-0.4,14)
self.assertAlmostEqual(reduce[5],6.5,14)
self.assertAlmostEqual(reduce[6],-0.6,14)
lmp.command("run 10 post no")
self.assertAlmostEqual(lmp.get_thermo("temp"),1.0/30.0,14)
self.assertAlmostEqual(lmp.get_thermo("pe"),1.0/8.0,14)
self.assertAlmostEqual(lmp.get_thermo("press"),0.15416666666666667,14)
# check setting per-atom data numpy arrays
reduce = lmp.extract_compute("sum", LMP_STYLE_GLOBAL, LMP_TYPE_VECTOR)
self.assertAlmostEqual(reduce[0],2.8,14)
self.assertAlmostEqual(reduce[1],-0.1,14)
@ -90,12 +113,6 @@ class PythonExternal(unittest.TestCase):
def testExternalArray(self):
"""Test fix external from Python with pf/array"""
try:
import numpy
NUMPY_INSTALLED = True
except ImportError:
NUMPY_INSTALLED = False
machine=None
if 'LAMMPS_MACHINE_NAME' in os.environ:
machine=os.environ['LAMMPS_MACHINE_NAME']