Simplify extract_atom and extract_global in Python interface

Both extract methods now can auto-detect the datatype of both global
and per-atom properties. Callers can still enforce different types
if needed by specifying the now optional dtype argument.

The numpy wrapper now has a new extract_atom function method, which
replace the extract_atom_darray and extract_atom_iarray method and
autodetects both type and size. All parameters can still be forced
to use different values if needed.
This commit is contained in:
Richard Berger
2020-09-17 16:16:17 -04:00
parent d88810f13a
commit b81ad54baa
2 changed files with 162 additions and 31 deletions

View File

@ -19,6 +19,7 @@ from __future__ import print_function
# imports for simple LAMMPS python wrapper module "lammps"
import sys,traceback,types
import warnings
from ctypes import *
from os.path import dirname,abspath,join
from inspect import getsourcefile
@ -33,6 +34,7 @@ import sys
# various symbolic constants to be used
# in certain calls to select data formats
LAMMPS_AUTODETECT = None
LAMMPS_INT = 0
LAMMPS_INT2D = 1
LAMMPS_DOUBLE = 2
@ -314,6 +316,8 @@ class lammps(object):
self.lib.lammps_get_last_error_message.restype = c_int
self.lib.lammps_extract_global.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_global_datatype.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_global_datatype.restype = c_int
self.lib.lammps_extract_compute.argtypes = [c_void_p, c_char_p, c_int, c_int]
self.lib.lammps_get_thermo.argtypes = [c_void_p, c_char_p]
@ -338,6 +342,8 @@ class lammps(object):
self.lib.lammps_decode_image_flags.argtypes = [self.c_imageint, POINTER(c_int*3)]
self.lib.lammps_extract_atom.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_atom_datatype.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_atom_datatype.restype = c_int
self.lib.lammps_extract_fix.argtypes = [c_void_p, c_char_p, c_int, c_int, c_int, c_int]
@ -474,7 +480,36 @@ class lammps(object):
return np.int64
return np.intc
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT):
if dtype == LAMMPS_AUTODETECT:
dtype = self.lmp.extract_atom_datatype(name)
if nelem == LAMMPS_AUTODETECT:
if name == "mass":
nelem = self.lmp.extract_global("ntypes") + 1
else:
nelem = self.lmp.extract_global("nlocal")
if dim == LAMMPS_AUTODETECT:
if dtype in (LAMMPS_INT2D, LAMMPS_DOUBLE2D, LAMMPS_TAGINT2D):
dim = 2
else:
dim = 1
raw_ptr = self.lmp.extract_atom(name, dtype)
if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE2D):
return self.darray(raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT, LAMMPS_INT2D):
return self.iarray(c_int, raw_ptr, nelem, dim)
elif dtype in (LAMMPS_TAGINT, LAMMPS_TAGINT2D):
return self.iarray(self.lmp.c_tagint, raw_ptr, nelem, dim)
elif dtype == LAMMPS_BIGINT:
return self.iarray(self.lmp.c_bigint, raw_ptr, nelem, dim)
return raw_ptr
def extract_atom_iarray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if name in ['id', 'molecule']:
c_int_type = self.lmp.c_tagint
elif name in ['image']:
@ -490,6 +525,8 @@ class lammps(object):
return self.iarray(c_int_type, raw_ptr, nelem, dim)
def extract_atom_darray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if dim == 1:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE)
else:
@ -802,10 +839,34 @@ class lammps(object):
else: return None
return int(self.lib.lammps_extract_setting(self.lmp,name))
# -------------------------------------------------------------------------
# extract global info datatype
def extract_global_datatype(self, name):
"""Retrieve global property datatype from LAMMPS
This is a wrapper around the :cpp:func:`lammps_extract_global_datatype`
function of the C-library interface. Its documentation includes a
list of the supported keywords.
This function returns ``None`` if the keyword is not
recognized. Otherwise it will return a positive integer value that
corresponds to one of the contants define in the :py:mod:`lammps` module:
``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``,``LAMMPS_DOUBLE2D``,
``LAMMPS_BIGINT``, ``LAMMPS_TAGINT``, ``LAMMPS_TAGINT2D``, and ``LAMMPS_STRING``.
:param name: name of the property
:type name: string
:return: datatype of global property
:rtype: int
"""
if name: name = name.encode()
else: return None
return self.lib.lammps_extract_global_datatype(self.lmp, name)
# -------------------------------------------------------------------------
# extract global info
def extract_global(self, name, type):
def extract_global(self, name, dtype=LAMMPS_AUTODETECT):
"""Query LAMMPS about global settings of different types.
This is a wrapper around the :cpp:func:`lammps_extract_global`
@ -815,55 +876,84 @@ class lammps(object):
of values. The :cpp:func:`lammps_extract_global` documentation
includes a list of the supported keywords and their data types.
Since Python needs to know the data type to be able to interpret
the result, the type has to be provided as an argument. For
the result, by default, this function will try to auto-detect the datatype
by asking the library. You can also force a specific data type. For
that purpose the :py:mod:`lammps` module contains the constants
``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``,
``LAMMPS_TAGINT``, and ``LAMMPS_STRING``.
This function returns ``None`` if either the keyword is not
recognized, or an invalid data type constant is used.
:param name: name of the setting
:param name: name of the property
:type name: string
:param type: type of the returned data
:type type: int
:return: value of the setting
:param dtype: type of the returned data
:type dtype: int, optional
:return: value of the property
:rtype: integer or double or string or None
"""
if dtype == LAMMPS_AUTODETECT:
dtype = self.extract_global_datatype(name)
if name: name = name.encode()
else: return None
if type == LAMMPS_INT:
if dtype == LAMMPS_INT:
self.lib.lammps_extract_global.restype = POINTER(c_int)
elif type == LAMMPS_DOUBLE:
elif dtype == LAMMPS_DOUBLE:
self.lib.lammps_extract_global.restype = POINTER(c_double)
elif type == LAMMPS_BIGINT:
elif dtype == LAMMPS_BIGINT:
self.lib.lammps_extract_global.restype = POINTER(self.c_bigint)
elif type == LAMMPS_TAGINT:
elif dtype == LAMMPS_TAGINT:
self.lib.lammps_extract_global.restype = POINTER(self.c_tagint)
elif type == LAMMPS_STRING:
elif dtype == LAMMPS_STRING:
self.lib.lammps_extract_global.restype = c_char_p
ptr = self.lib.lammps_extract_global(self.lmp, name)
return str(ptr,'ascii')
else: return None
ptr = self.lib.lammps_extract_global(self.lmp, name)
if ptr: return ptr[0]
else: return None
# -------------------------------------------------------------------------
# extract per-atom info
# NOTE: need to insure are converting to/from correct Python type
# e.g. for Python list or NumPy or ctypes
# extract per-atom info datatype
def extract_atom(self,name,type):
def extract_atom_datatype(self, name):
"""Retrieve per-atom property datatype from LAMMPS
This is a wrapper around the :cpp:func:`lammps_extract_atom_datatype`
function of the C-library interface. Its documentation includes a
list of the supported keywords.
This function returns ``None`` if the keyword is not
recognized. Otherwise it will return a positive integer value that
corresponds to one of the contants define in the :py:mod:`lammps` module:
``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``,``LAMMPS_DOUBLE2D``,
``LAMMPS_BIGINT``, ``LAMMPS_TAGINT``, ``LAMMPS_TAGINT2D``, and ``LAMMPS_STRING``.
:param name: name of the property
:type name: string
:return: datatype of per-atom property
:rtype: int
"""
if name: name = name.encode()
else: return None
return self.lib.lammps_extract_atom_datatype(self.lmp, name)
# -------------------------------------------------------------------------
# extract per-atom info
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT):
"""Retrieve per-atom properties from LAMMPS
This is a wrapper around the :cpp:func:`lammps_extract_atom`
function of the C-library interface. Its documentation includes a
list of the supported keywords and their data types.
Since Python needs to know the data type to be able to interpret
the result, the type has to be provided as an argument. For
the result, by default, this function will try to auto-detect the datatype
by asking the library. You can also force a specific data type. For
that purpose the :py:mod:`lammps` module contains the constants
``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``,
and ``LAMMPS_DOUBLE2D``.
``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``,
``LAMMPS_TAGINT``, and ``LAMMPS_STRING``.
This function returns ``None`` if either the keyword is not
recognized, or an invalid data type constant is used.
@ -876,25 +966,33 @@ class lammps(object):
atoms. In some cases, this depends on a LAMMPS setting, see
for example :doc:`comm_modify vel yes <comm_modify>`.
:param name: name of the setting
:param name: name of the property
:type name: string
:param type: type of the returned data
:type type: int
:param dtype: type of the returned data
:type dtype: int, optional
:return: requested data
:rtype: pointer to integer or double or None
"""
ntypes = int(self.extract_setting('ntypes'))
nmax = int(self.extract_setting('nmax'))
if dtype == LAMMPS_AUTODETECT:
dtype = self.lib.lammps_extract_atom_datatype(self.lmp, name)
if name: name = name.encode()
else: return None
if type == LAMMPS_INT:
if dtype == LAMMPS_INT:
self.lib.lammps_extract_atom.restype = POINTER(c_int)
elif type == LAMMPS_INT2D:
elif dtype == LAMMPS_INT2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
elif type == LAMMPS_DOUBLE:
elif dtype == LAMMPS_DOUBLE:
self.lib.lammps_extract_atom.restype = POINTER(c_double)
elif type == LAMMPS_DOUBLE2D:
elif dtype == LAMMPS_DOUBLE2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
elif dtype == LAMMPS_TAGINT:
self.lib.lammps_extract_atom.restype = POINTER(self.c_tagint)
elif dtype == LAMMPS_TAGINT2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(self.c_tagint))
elif dtype == LAMMPS_BIGINT:
self.lib.lammps_extract_atom.restype = POINTER(self.c_bigint)
else: return None
ptr = self.lib.lammps_extract_atom(self.lmp, name)
if ptr: return ptr

View File

@ -73,7 +73,7 @@ class PythonNumpy(unittest.TestCase):
# TODO
pass
def testExtractAtom(self):
def testExtractAtomDeprecated(self):
self.lmp.command("units lj")
self.lmp.command("atom_style atomic")
self.lmp.command("atom_modify map array")
@ -100,6 +100,39 @@ class PythonNumpy(unittest.TestCase):
x = self.lmp.numpy.extract_atom_darray("x", nlocal, dim=3)
v = self.lmp.numpy.extract_atom_darray("v", nlocal, dim=3)
self.assertEqual(len(x), 2)
self.assertTrue((x[0] == (1.0, 1.0, 1.0)).all())
self.assertTrue((x[1] == (1.0, 1.0, 1.5)).all())
self.assertEqual(len(v), 2)
def testExtractAtom(self):
self.lmp.command("units lj")
self.lmp.command("atom_style atomic")
self.lmp.command("atom_modify map array")
self.lmp.command("region box block 0 2 0 2 0 2")
self.lmp.command("create_box 1 box")
x = [
1.0, 1.0, 1.0,
1.0, 1.0, 1.5
]
types = [1, 1]
self.assertEqual(self.lmp.create_atoms(2, id=None, type=types, x=x), 2)
nlocal = self.lmp.extract_global("nlocal")
self.assertEqual(nlocal, 2)
ident = self.lmp.numpy.extract_atom("id")
self.assertEqual(len(ident), 2)
ntypes = self.lmp.extract_global("ntypes")
self.assertEqual(ntypes, 1)
x = self.lmp.numpy.extract_atom("x", dim=3)
v = self.lmp.numpy.extract_atom("v", dim=3)
self.assertEqual(len(x), 2)
self.assertTrue((x[0] == (1.0, 1.0, 1.0)).all())
self.assertTrue((x[1] == (1.0, 1.0, 1.5)).all())
self.assertEqual(len(v), 2)
if __name__ == "__main__":