From b81ad54baabb35e19b8a2aa8321d97dd2fe2db42 Mon Sep 17 00:00:00 2001 From: Richard Berger Date: Thu, 17 Sep 2020 16:16:17 -0400 Subject: [PATCH] Simplify extract_atom and extract_global in Python interface Both extract methods now can auto-detect the datatype of both global and per-atom properties. Callers can still enforce different types if needed by specifying the now optional dtype argument. The numpy wrapper now has a new extract_atom function method, which replace the extract_atom_darray and extract_atom_iarray method and autodetects both type and size. All parameters can still be forced to use different values if needed. --- python/lammps.py | 158 ++++++++++++++++++++++++++------ unittest/python/python-numpy.py | 35 ++++++- 2 files changed, 162 insertions(+), 31 deletions(-) diff --git a/python/lammps.py b/python/lammps.py index 112fdc4108..1b0b34691c 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -19,6 +19,7 @@ from __future__ import print_function # imports for simple LAMMPS python wrapper module "lammps" import sys,traceback,types +import warnings from ctypes import * from os.path import dirname,abspath,join from inspect import getsourcefile @@ -33,6 +34,7 @@ import sys # various symbolic constants to be used # in certain calls to select data formats +LAMMPS_AUTODETECT = None LAMMPS_INT = 0 LAMMPS_INT2D = 1 LAMMPS_DOUBLE = 2 @@ -314,6 +316,8 @@ class lammps(object): self.lib.lammps_get_last_error_message.restype = c_int self.lib.lammps_extract_global.argtypes = [c_void_p, c_char_p] + self.lib.lammps_extract_global_datatype.argtypes = [c_void_p, c_char_p] + self.lib.lammps_extract_global_datatype.restype = c_int self.lib.lammps_extract_compute.argtypes = [c_void_p, c_char_p, c_int, c_int] self.lib.lammps_get_thermo.argtypes = [c_void_p, c_char_p] @@ -338,6 +342,8 @@ class lammps(object): self.lib.lammps_decode_image_flags.argtypes = [self.c_imageint, POINTER(c_int*3)] self.lib.lammps_extract_atom.argtypes = [c_void_p, c_char_p] + self.lib.lammps_extract_atom_datatype.argtypes = [c_void_p, c_char_p] + self.lib.lammps_extract_atom_datatype.restype = c_int self.lib.lammps_extract_fix.argtypes = [c_void_p, c_char_p, c_int, c_int, c_int, c_int] @@ -474,7 +480,36 @@ class lammps(object): return np.int64 return np.intc + def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT): + if dtype == LAMMPS_AUTODETECT: + dtype = self.lmp.extract_atom_datatype(name) + + if nelem == LAMMPS_AUTODETECT: + if name == "mass": + nelem = self.lmp.extract_global("ntypes") + 1 + else: + nelem = self.lmp.extract_global("nlocal") + if dim == LAMMPS_AUTODETECT: + if dtype in (LAMMPS_INT2D, LAMMPS_DOUBLE2D, LAMMPS_TAGINT2D): + dim = 2 + else: + dim = 1 + + raw_ptr = self.lmp.extract_atom(name, dtype) + + if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE2D): + return self.darray(raw_ptr, nelem, dim) + elif dtype in (LAMMPS_INT, LAMMPS_INT2D): + return self.iarray(c_int, raw_ptr, nelem, dim) + elif dtype in (LAMMPS_TAGINT, LAMMPS_TAGINT2D): + return self.iarray(self.lmp.c_tagint, raw_ptr, nelem, dim) + elif dtype == LAMMPS_BIGINT: + return self.iarray(self.lmp.c_bigint, raw_ptr, nelem, dim) + return raw_ptr + def extract_atom_iarray(self, name, nelem, dim=1): + warnings.warn("deprecated, use extract_atom instead", DeprecationWarning) + if name in ['id', 'molecule']: c_int_type = self.lmp.c_tagint elif name in ['image']: @@ -490,6 +525,8 @@ class lammps(object): return self.iarray(c_int_type, raw_ptr, nelem, dim) def extract_atom_darray(self, name, nelem, dim=1): + warnings.warn("deprecated, use extract_atom instead", DeprecationWarning) + if dim == 1: raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE) else: @@ -802,10 +839,34 @@ class lammps(object): else: return None return int(self.lib.lammps_extract_setting(self.lmp,name)) + # ------------------------------------------------------------------------- + # extract global info datatype + + def extract_global_datatype(self, name): + """Retrieve global property datatype from LAMMPS + + This is a wrapper around the :cpp:func:`lammps_extract_global_datatype` + function of the C-library interface. Its documentation includes a + list of the supported keywords. + This function returns ``None`` if the keyword is not + recognized. Otherwise it will return a positive integer value that + corresponds to one of the contants define in the :py:mod:`lammps` module: + ``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``,``LAMMPS_DOUBLE2D``, + ``LAMMPS_BIGINT``, ``LAMMPS_TAGINT``, ``LAMMPS_TAGINT2D``, and ``LAMMPS_STRING``. + + :param name: name of the property + :type name: string + :return: datatype of global property + :rtype: int + """ + if name: name = name.encode() + else: return None + return self.lib.lammps_extract_global_datatype(self.lmp, name) + # ------------------------------------------------------------------------- # extract global info - def extract_global(self, name, type): + def extract_global(self, name, dtype=LAMMPS_AUTODETECT): """Query LAMMPS about global settings of different types. This is a wrapper around the :cpp:func:`lammps_extract_global` @@ -815,55 +876,84 @@ class lammps(object): of values. The :cpp:func:`lammps_extract_global` documentation includes a list of the supported keywords and their data types. Since Python needs to know the data type to be able to interpret - the result, the type has to be provided as an argument. For + the result, by default, this function will try to auto-detect the datatype + by asking the library. You can also force a specific data type. For that purpose the :py:mod:`lammps` module contains the constants ``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``, ``LAMMPS_TAGINT``, and ``LAMMPS_STRING``. This function returns ``None`` if either the keyword is not recognized, or an invalid data type constant is used. - :param name: name of the setting + :param name: name of the property :type name: string - :param type: type of the returned data - :type type: int - :return: value of the setting + :param dtype: type of the returned data + :type dtype: int, optional + :return: value of the property :rtype: integer or double or string or None """ + if dtype == LAMMPS_AUTODETECT: + dtype = self.extract_global_datatype(name) + if name: name = name.encode() else: return None - if type == LAMMPS_INT: + + if dtype == LAMMPS_INT: self.lib.lammps_extract_global.restype = POINTER(c_int) - elif type == LAMMPS_DOUBLE: + elif dtype == LAMMPS_DOUBLE: self.lib.lammps_extract_global.restype = POINTER(c_double) - elif type == LAMMPS_BIGINT: + elif dtype == LAMMPS_BIGINT: self.lib.lammps_extract_global.restype = POINTER(self.c_bigint) - elif type == LAMMPS_TAGINT: + elif dtype == LAMMPS_TAGINT: self.lib.lammps_extract_global.restype = POINTER(self.c_tagint) - elif type == LAMMPS_STRING: + elif dtype == LAMMPS_STRING: self.lib.lammps_extract_global.restype = c_char_p - ptr = self.lib.lammps_extract_global(self.lmp,name) + ptr = self.lib.lammps_extract_global(self.lmp, name) return str(ptr,'ascii') else: return None - ptr = self.lib.lammps_extract_global(self.lmp,name) + + ptr = self.lib.lammps_extract_global(self.lmp, name) if ptr: return ptr[0] else: return None # ------------------------------------------------------------------------- - # extract per-atom info - # NOTE: need to insure are converting to/from correct Python type - # e.g. for Python list or NumPy or ctypes + # extract per-atom info datatype - def extract_atom(self,name,type): + def extract_atom_datatype(self, name): + """Retrieve per-atom property datatype from LAMMPS + + This is a wrapper around the :cpp:func:`lammps_extract_atom_datatype` + function of the C-library interface. Its documentation includes a + list of the supported keywords. + This function returns ``None`` if the keyword is not + recognized. Otherwise it will return a positive integer value that + corresponds to one of the contants define in the :py:mod:`lammps` module: + ``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``,``LAMMPS_DOUBLE2D``, + ``LAMMPS_BIGINT``, ``LAMMPS_TAGINT``, ``LAMMPS_TAGINT2D``, and ``LAMMPS_STRING``. + + :param name: name of the property + :type name: string + :return: datatype of per-atom property + :rtype: int + """ + if name: name = name.encode() + else: return None + return self.lib.lammps_extract_atom_datatype(self.lmp, name) + + # ------------------------------------------------------------------------- + # extract per-atom info + + def extract_atom(self, name, dtype=LAMMPS_AUTODETECT): """Retrieve per-atom properties from LAMMPS This is a wrapper around the :cpp:func:`lammps_extract_atom` function of the C-library interface. Its documentation includes a list of the supported keywords and their data types. Since Python needs to know the data type to be able to interpret - the result, the type has to be provided as an argument. For + the result, by default, this function will try to auto-detect the datatype + by asking the library. You can also force a specific data type. For that purpose the :py:mod:`lammps` module contains the constants - ``LAMMPS_INT``, ``LAMMPS_INT2D``, ``LAMMPS_DOUBLE``, - and ``LAMMPS_DOUBLE2D``. + ``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``, + ``LAMMPS_TAGINT``, and ``LAMMPS_STRING``. This function returns ``None`` if either the keyword is not recognized, or an invalid data type constant is used. @@ -876,27 +966,35 @@ class lammps(object): atoms. In some cases, this depends on a LAMMPS setting, see for example :doc:`comm_modify vel yes `. - :param name: name of the setting + :param name: name of the property :type name: string - :param type: type of the returned data - :type type: int + :param dtype: type of the returned data + :type dtype: int, optional :return: requested data :rtype: pointer to integer or double or None """ - ntypes = int(self.extract_setting('ntypes')) - nmax = int(self.extract_setting('nmax')) + if dtype == LAMMPS_AUTODETECT: + dtype = self.lib.lammps_extract_atom_datatype(self.lmp, name) + if name: name = name.encode() else: return None - if type == LAMMPS_INT: + + if dtype == LAMMPS_INT: self.lib.lammps_extract_atom.restype = POINTER(c_int) - elif type == LAMMPS_INT2D: + elif dtype == LAMMPS_INT2D: self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int)) - elif type == LAMMPS_DOUBLE: + elif dtype == LAMMPS_DOUBLE: self.lib.lammps_extract_atom.restype = POINTER(c_double) - elif type == LAMMPS_DOUBLE2D: + elif dtype == LAMMPS_DOUBLE2D: self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_double)) + elif dtype == LAMMPS_TAGINT: + self.lib.lammps_extract_atom.restype = POINTER(self.c_tagint) + elif dtype == LAMMPS_TAGINT2D: + self.lib.lammps_extract_atom.restype = POINTER(POINTER(self.c_tagint)) + elif dtype == LAMMPS_BIGINT: + self.lib.lammps_extract_atom.restype = POINTER(self.c_bigint) else: return None - ptr = self.lib.lammps_extract_atom(self.lmp,name) + ptr = self.lib.lammps_extract_atom(self.lmp, name) if ptr: return ptr else: return None diff --git a/unittest/python/python-numpy.py b/unittest/python/python-numpy.py index 2f1b3e4fcd..3e223af877 100644 --- a/unittest/python/python-numpy.py +++ b/unittest/python/python-numpy.py @@ -73,7 +73,7 @@ class PythonNumpy(unittest.TestCase): # TODO pass - def testExtractAtom(self): + def testExtractAtomDeprecated(self): self.lmp.command("units lj") self.lmp.command("atom_style atomic") self.lmp.command("atom_modify map array") @@ -100,6 +100,39 @@ class PythonNumpy(unittest.TestCase): x = self.lmp.numpy.extract_atom_darray("x", nlocal, dim=3) v = self.lmp.numpy.extract_atom_darray("v", nlocal, dim=3) self.assertEqual(len(x), 2) + self.assertTrue((x[0] == (1.0, 1.0, 1.0)).all()) + self.assertTrue((x[1] == (1.0, 1.0, 1.5)).all()) + self.assertEqual(len(v), 2) + + def testExtractAtom(self): + self.lmp.command("units lj") + self.lmp.command("atom_style atomic") + self.lmp.command("atom_modify map array") + self.lmp.command("region box block 0 2 0 2 0 2") + self.lmp.command("create_box 1 box") + + x = [ + 1.0, 1.0, 1.0, + 1.0, 1.0, 1.5 + ] + + types = [1, 1] + + self.assertEqual(self.lmp.create_atoms(2, id=None, type=types, x=x), 2) + nlocal = self.lmp.extract_global("nlocal") + self.assertEqual(nlocal, 2) + + ident = self.lmp.numpy.extract_atom("id") + self.assertEqual(len(ident), 2) + + ntypes = self.lmp.extract_global("ntypes") + self.assertEqual(ntypes, 1) + + x = self.lmp.numpy.extract_atom("x", dim=3) + v = self.lmp.numpy.extract_atom("v", dim=3) + self.assertEqual(len(x), 2) + self.assertTrue((x[0] == (1.0, 1.0, 1.0)).all()) + self.assertTrue((x[1] == (1.0, 1.0, 1.5)).all()) self.assertEqual(len(v), 2) if __name__ == "__main__":