diff --git a/python/lammps.py b/python/lammps.py index 1b0b34691c..a39d96c69d 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -481,6 +481,40 @@ class lammps(object): return np.intc def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT): + """Retrieve per-atom properties from LAMMPS as NumPy arrays + + This is a wrapper around the :cpp:func:`lammps_extract_atom` + function of the C-library interface. Its documentation includes a + list of the supported keywords and their data types. + Since Python needs to know the data type to be able to interpret + the result, by default, this function will try to auto-detect the datatype + by asking the library. You can also force a specific datatype. For + that purpose the :py:mod:`lammps` module contains the constants + ``LAMMPS_INT``, ``LAMMPS_DOUBLE``, ``LAMMPS_BIGINT``, + ``LAMMPS_TAGINT``, and ``LAMMPS_STRING``. + This function returns ``None`` if either the keyword is not + recognized, or an invalid data type constant is used. + + .. note:: + + While the returned arrays of per-atom data are dimensioned + for the range [0:nmax] - as is the underlying storage - + the data is usually only valid for the range of [0:nlocal], + unless the property of interest is also updated for ghost + atoms. In some cases, this depends on a LAMMPS setting, see + for example :doc:`comm_modify vel yes `. + + :param name: name of the property + :type name: string + :param dtype: type of the returned data + :type dtype: int, optional + :param nelem: number of elements in array + :type nelem: int, optional + :param dim: dimension of each element + :type dim: int, optional + :return: requested data as NumPy array with direct access to C data + :rtype: numpy.array + """ if dtype == LAMMPS_AUTODETECT: dtype = self.lmp.extract_atom_datatype(name) @@ -491,7 +525,11 @@ class lammps(object): nelem = self.lmp.extract_global("nlocal") if dim == LAMMPS_AUTODETECT: if dtype in (LAMMPS_INT2D, LAMMPS_DOUBLE2D, LAMMPS_TAGINT2D): - dim = 2 + # TODO add other fields + if name in ("x", "v", "f", "angmom", "torque", "csforce", "vforce"): + dim = 3 + else: + dim = 2 else: dim = 1 diff --git a/unittest/python/python-numpy.py b/unittest/python/python-numpy.py index 3e223af877..3c8ff9f512 100644 --- a/unittest/python/python-numpy.py +++ b/unittest/python/python-numpy.py @@ -128,8 +128,8 @@ class PythonNumpy(unittest.TestCase): ntypes = self.lmp.extract_global("ntypes") self.assertEqual(ntypes, 1) - x = self.lmp.numpy.extract_atom("x", dim=3) - v = self.lmp.numpy.extract_atom("v", dim=3) + x = self.lmp.numpy.extract_atom("x") + v = self.lmp.numpy.extract_atom("v") self.assertEqual(len(x), 2) self.assertTrue((x[0] == (1.0, 1.0, 1.0)).all()) self.assertTrue((x[1] == (1.0, 1.0, 1.5)).all())