Simplify extract_atom and extract_global in Python interface

Both extract methods now can auto-detect the datatype of both global
and per-atom properties. Callers can still enforce different types
if needed by specifying the now optional dtype argument.

The numpy wrapper now has a new extract_atom function method, which
replace the extract_atom_darray and extract_atom_iarray method and
autodetects both type and size. All parameters can still be forced
to use different values if needed.
This commit is contained in:
Richard Berger
2020-09-17 16:16:17 -04:00
parent d88810f13a
commit b81ad54baa
2 changed files with 162 additions and 31 deletions

View File

@ -19,6 +19,7 @@ from __future__ import print_function
# imports for simple LAMMPS python wrapper module "lammps" # imports for simple LAMMPS python wrapper module "lammps"
import sys,traceback,types import sys,traceback,types
import warnings
from ctypes import * from ctypes import *
from os.path import dirname,abspath,join from os.path import dirname,abspath,join
from inspect import getsourcefile from inspect import getsourcefile
@ -33,6 +34,7 @@ import sys
# various symbolic constants to be used # various symbolic constants to be used
# in certain calls to select data formats # in certain calls to select data formats
LAMMPS_AUTODETECT = None
LAMMPS_INT = 0 LAMMPS_INT = 0
LAMMPS_INT2D = 1 LAMMPS_INT2D = 1
LAMMPS_DOUBLE = 2 LAMMPS_DOUBLE = 2
@ -314,6 +316,8 @@ class lammps(object):
self.lib.lammps_get_last_error_message.restype = c_int self.lib.lammps_get_last_error_message.restype = c_int
self.lib.lammps_extract_global.argtypes = [c_void_p, c_char_p] self.lib.lammps_extract_global.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_global_datatype.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_global_datatype.restype = c_int
self.lib.lammps_extract_compute.argtypes = [c_void_p, c_char_p, c_int, c_int] self.lib.lammps_extract_compute.argtypes = [c_void_p, c_char_p, c_int, c_int]
self.lib.lammps_get_thermo.argtypes = [c_void_p, c_char_p] self.lib.lammps_get_thermo.argtypes = [c_void_p, c_char_p]
@ -338,6 +342,8 @@ class lammps(object):
self.lib.lammps_decode_image_flags.argtypes = [self.c_imageint, POINTER(c_int*3)] self.lib.lammps_decode_image_flags.argtypes = [self.c_imageint, POINTER(c_int*3)]
self.lib.lammps_extract_atom.argtypes = [c_void_p, c_char_p] self.lib.lammps_extract_atom.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_atom_datatype.argtypes = [c_void_p, c_char_p]
self.lib.lammps_extract_atom_datatype.restype = c_int
self.lib.lammps_extract_fix.argtypes = [c_void_p, c_char_p, c_int, c_int, c_int, c_int] self.lib.lammps_extract_fix.argtypes = [c_void_p, c_char_p, c_int, c_int, c_int, c_int]
@ -474,7 +480,36 @@ class lammps(object):
return np.int64 return np.int64
return np.intc return np.intc
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT):
if dtype == LAMMPS_AUTODETECT:
dtype = self.lmp.extract_atom_datatype(name)
if nelem == LAMMPS_AUTODETECT:
if name == "mass":
nelem = self.lmp.extract_global("ntypes") + 1
else:
nelem = self.lmp.extract_global("nlocal")
if dim == LAMMPS_AUTODETECT:
if dtype in (LAMMPS_INT2D, LAMMPS_DOUBLE2D, LAMMPS_TAGINT2D):
dim = 2
else:
dim = 1
raw_ptr = self.lmp.extract_atom(name, dtype)
if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE2D):
return self.darray(raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT, LAMMPS_INT2D):
return self.iarray(c_int, raw_ptr, nelem, dim)
elif dtype in (LAMMPS_TAGINT, LAMMPS_TAGINT2D):
return self.iarray(self.lmp.c_tagint, raw_ptr, nelem, dim)
elif dtype == LAMMPS_BIGINT:
return self.iarray(self.lmp.c_bigint, raw_ptr, nelem, dim)
return raw_ptr
def extract_atom_iarray(self, name, nelem, dim=1): def extract_atom_iarray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if name in ['id', 'molecule']: if name in ['id', 'molecule']:
c_int_type = self.lmp.c_tagint c_int_type = self.lmp.c_tagint
elif name in ['image']: elif name in ['image']:
@ -490,6 +525,8 @@ class lammps(object):
return self.iarray(c_int_type, raw_ptr, nelem, dim) return self.iarray(c_int_type, raw_ptr, nelem, dim)
def extract_atom_darray(self, name, nelem, dim=1): def extract_atom_darray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if dim == 1: if dim == 1:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE) raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE)
else: else:
@ -802,10 +839,34 @@ class lammps(object):
else: return None else: return None
return int(self.lib.lammps_extract_setting(self.lmp,name)) return int(self.lib.lammps_extract_setting(self.lmp,name))
# -------------------------------------------------------------------------
# extract global info datatype
def extract_global_datatype(self, name):
"""Retrieve global property datatype from LAMMPS
This is a wrapper around the :cpp:func:`lammps_extract_global_datatype`
function of the C-library interface. Its documentation includes a
list of the supported keywords.
This function returns ``None`` if the keyword is not
recognized. Otherwise it will return a positive integer value that
corresponds to one of the contants define in the :py:mod:`lammps` module:
``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``,``LAMMPS_DOUBLE2D``,
``LAMMPS_BIGINT``, ``LAMMPS_TAGINT``, ``LAMMPS_TAGINT2D``, and ``LAMMPS_STRING``.
:param name: name of the property
:type name: string
:return: datatype of global property
:rtype: int
"""
if name: name = name.encode()
else: return None
return self.lib.lammps_extract_global_datatype(self.lmp, name)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# extract global info # extract global info
def extract_global(self, name, type): def extract_global(self, name, dtype=LAMMPS_AUTODETECT):
"""Query LAMMPS about global settings of different types. """Query LAMMPS about global settings of different types.
This is a wrapper around the :cpp:func:`lammps_extract_global` This is a wrapper around the :cpp:func:`lammps_extract_global`
@ -815,55 +876,84 @@ class lammps(object):
of values. The :cpp:func:`lammps_extract_global` documentation of values. The :cpp:func:`lammps_extract_global` documentation
includes a list of the supported keywords and their data types. includes a list of the supported keywords and their data types.
Since Python needs to know the data type to be able to interpret Since Python needs to know the data type to be able to interpret
the result, the type has to be provided as an argument. For the result, by default, this function will try to auto-detect the datatype
by asking the library. You can also force a specific data type. For
that purpose the :py:mod:`lammps` module contains the constants that purpose the :py:mod:`lammps` module contains the constants
``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``, ``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``,
``LAMMPS_TAGINT``, and ``LAMMPS_STRING``. ``LAMMPS_TAGINT``, and ``LAMMPS_STRING``.
This function returns ``None`` if either the keyword is not This function returns ``None`` if either the keyword is not
recognized, or an invalid data type constant is used. recognized, or an invalid data type constant is used.
:param name: name of the setting :param name: name of the property
:type name: string :type name: string
:param type: type of the returned data :param dtype: type of the returned data
:type type: int :type dtype: int, optional
:return: value of the setting :return: value of the property
:rtype: integer or double or string or None :rtype: integer or double or string or None
""" """
if dtype == LAMMPS_AUTODETECT:
dtype = self.extract_global_datatype(name)
if name: name = name.encode() if name: name = name.encode()
else: return None else: return None
if type == LAMMPS_INT:
if dtype == LAMMPS_INT:
self.lib.lammps_extract_global.restype = POINTER(c_int) self.lib.lammps_extract_global.restype = POINTER(c_int)
elif type == LAMMPS_DOUBLE: elif dtype == LAMMPS_DOUBLE:
self.lib.lammps_extract_global.restype = POINTER(c_double) self.lib.lammps_extract_global.restype = POINTER(c_double)
elif type == LAMMPS_BIGINT: elif dtype == LAMMPS_BIGINT:
self.lib.lammps_extract_global.restype = POINTER(self.c_bigint) self.lib.lammps_extract_global.restype = POINTER(self.c_bigint)
elif type == LAMMPS_TAGINT: elif dtype == LAMMPS_TAGINT:
self.lib.lammps_extract_global.restype = POINTER(self.c_tagint) self.lib.lammps_extract_global.restype = POINTER(self.c_tagint)
elif type == LAMMPS_STRING: elif dtype == LAMMPS_STRING:
self.lib.lammps_extract_global.restype = c_char_p self.lib.lammps_extract_global.restype = c_char_p
ptr = self.lib.lammps_extract_global(self.lmp,name) ptr = self.lib.lammps_extract_global(self.lmp, name)
return str(ptr,'ascii') return str(ptr,'ascii')
else: return None else: return None
ptr = self.lib.lammps_extract_global(self.lmp,name)
ptr = self.lib.lammps_extract_global(self.lmp, name)
if ptr: return ptr[0] if ptr: return ptr[0]
else: return None else: return None
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# extract per-atom info # extract per-atom info datatype
# NOTE: need to insure are converting to/from correct Python type
# e.g. for Python list or NumPy or ctypes
def extract_atom(self,name,type): def extract_atom_datatype(self, name):
"""Retrieve per-atom property datatype from LAMMPS
This is a wrapper around the :cpp:func:`lammps_extract_atom_datatype`
function of the C-library interface. Its documentation includes a
list of the supported keywords.
This function returns ``None`` if the keyword is not
recognized. Otherwise it will return a positive integer value that
corresponds to one of the contants define in the :py:mod:`lammps` module:
``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``,``LAMMPS_DOUBLE2D``,
``LAMMPS_BIGINT``, ``LAMMPS_TAGINT``, ``LAMMPS_TAGINT2D``, and ``LAMMPS_STRING``.
:param name: name of the property
:type name: string
:return: datatype of per-atom property
:rtype: int
"""
if name: name = name.encode()
else: return None
return self.lib.lammps_extract_atom_datatype(self.lmp, name)
# -------------------------------------------------------------------------
# extract per-atom info
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT):
"""Retrieve per-atom properties from LAMMPS """Retrieve per-atom properties from LAMMPS
This is a wrapper around the :cpp:func:`lammps_extract_atom` This is a wrapper around the :cpp:func:`lammps_extract_atom`
function of the C-library interface. Its documentation includes a function of the C-library interface. Its documentation includes a
list of the supported keywords and their data types. list of the supported keywords and their data types.
Since Python needs to know the data type to be able to interpret Since Python needs to know the data type to be able to interpret
the result, the type has to be provided as an argument. For the result, by default, this function will try to auto-detect the datatype
by asking the library. You can also force a specific data type. For
that purpose the :py:mod:`lammps` module contains the constants that purpose the :py:mod:`lammps` module contains the constants
``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``, ``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``,
and ``LAMMPS_DOUBLE2D``. ``LAMMPS_TAGINT``, and ``LAMMPS_STRING``.
This function returns ``None`` if either the keyword is not This function returns ``None`` if either the keyword is not
recognized, or an invalid data type constant is used. recognized, or an invalid data type constant is used.
@ -876,27 +966,35 @@ class lammps(object):
atoms. In some cases, this depends on a LAMMPS setting, see atoms. In some cases, this depends on a LAMMPS setting, see
for example :doc:`comm_modify vel yes <comm_modify>`. for example :doc:`comm_modify vel yes <comm_modify>`.
:param name: name of the setting :param name: name of the property
:type name: string :type name: string
:param type: type of the returned data :param dtype: type of the returned data
:type type: int :type dtype: int, optional
:return: requested data :return: requested data
:rtype: pointer to integer or double or None :rtype: pointer to integer or double or None
""" """
ntypes = int(self.extract_setting('ntypes')) if dtype == LAMMPS_AUTODETECT:
nmax = int(self.extract_setting('nmax')) dtype = self.lib.lammps_extract_atom_datatype(self.lmp, name)
if name: name = name.encode() if name: name = name.encode()
else: return None else: return None
if type == LAMMPS_INT:
if dtype == LAMMPS_INT:
self.lib.lammps_extract_atom.restype = POINTER(c_int) self.lib.lammps_extract_atom.restype = POINTER(c_int)
elif type == LAMMPS_INT2D: elif dtype == LAMMPS_INT2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int)) self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int))
elif type == LAMMPS_DOUBLE: elif dtype == LAMMPS_DOUBLE:
self.lib.lammps_extract_atom.restype = POINTER(c_double) self.lib.lammps_extract_atom.restype = POINTER(c_double)
elif type == LAMMPS_DOUBLE2D: elif dtype == LAMMPS_DOUBLE2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double)) self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double))
elif dtype == LAMMPS_TAGINT:
self.lib.lammps_extract_atom.restype = POINTER(self.c_tagint)
elif dtype == LAMMPS_TAGINT2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(self.c_tagint))
elif dtype == LAMMPS_BIGINT:
self.lib.lammps_extract_atom.restype = POINTER(self.c_bigint)
else: return None else: return None
ptr = self.lib.lammps_extract_atom(self.lmp,name) ptr = self.lib.lammps_extract_atom(self.lmp, name)
if ptr: return ptr if ptr: return ptr
else: return None else: return None

View File

@ -73,7 +73,7 @@ class PythonNumpy(unittest.TestCase):
# TODO # TODO
pass pass
def testExtractAtom(self): def testExtractAtomDeprecated(self):
self.lmp.command("units lj") self.lmp.command("units lj")
self.lmp.command("atom_style atomic") self.lmp.command("atom_style atomic")
self.lmp.command("atom_modify map array") self.lmp.command("atom_modify map array")
@ -100,6 +100,39 @@ class PythonNumpy(unittest.TestCase):
x = self.lmp.numpy.extract_atom_darray("x", nlocal, dim=3) x = self.lmp.numpy.extract_atom_darray("x", nlocal, dim=3)
v = self.lmp.numpy.extract_atom_darray("v", nlocal, dim=3) v = self.lmp.numpy.extract_atom_darray("v", nlocal, dim=3)
self.assertEqual(len(x), 2) self.assertEqual(len(x), 2)
self.assertTrue((x[0] == (1.0, 1.0, 1.0)).all())
self.assertTrue((x[1] == (1.0, 1.0, 1.5)).all())
self.assertEqual(len(v), 2)
def testExtractAtom(self):
self.lmp.command("units lj")
self.lmp.command("atom_style atomic")
self.lmp.command("atom_modify map array")
self.lmp.command("region box block 0 2 0 2 0 2")
self.lmp.command("create_box 1 box")
x = [
1.0, 1.0, 1.0,
1.0, 1.0, 1.5
]
types = [1, 1]
self.assertEqual(self.lmp.create_atoms(2, id=None, type=types, x=x), 2)
nlocal = self.lmp.extract_global("nlocal")
self.assertEqual(nlocal, 2)
ident = self.lmp.numpy.extract_atom("id")
self.assertEqual(len(ident), 2)
ntypes = self.lmp.extract_global("ntypes")
self.assertEqual(ntypes, 1)
x = self.lmp.numpy.extract_atom("x", dim=3)
v = self.lmp.numpy.extract_atom("v", dim=3)
self.assertEqual(len(x), 2)
self.assertTrue((x[0] == (1.0, 1.0, 1.0)).all())
self.assertTrue((x[1] == (1.0, 1.0, 1.5)).all())
self.assertEqual(len(v), 2) self.assertEqual(len(v), 2)
if __name__ == "__main__": if __name__ == "__main__":