Refactor LammpsNumpyWrapper to numpy_wrapper

LammpsNumpyWrapper was a class that was defined inside of the
lammps.numpy property when it was first accessed. This made it hard
to document the methods of this class.

This commit extracts this utility class into the lammps module and
renames it to 'numpy_wrapper'.
This commit is contained in:
Richard Berger
2020-10-02 17:28:25 -04:00
parent 0089a35d95
commit d91d8de76d
2 changed files with 205 additions and 174 deletions

View File

@ -44,6 +44,9 @@ functions. Below is a detailed documentation of the API.
.. autoclass:: lammps.lammps .. autoclass:: lammps.lammps
:members: :members:
.. autoclass:: lammps.numpy_wrapper
:members:
---------- ----------
The ``PyLammps`` class API The ``PyLammps`` class API

View File

@ -465,181 +465,17 @@ class lammps(object):
@property @property
def numpy(self): def numpy(self):
"Convert between ctypes arrays and numpy arrays" """ Return object to access numpy versions of API
It provides alternative implementations of API functions that
return numpy arrays instead of ctypes pointers. If numpy is not installed,
accessing this property will lead to an ImportError.
:return: instance of numpy wrapper object
:rtype: numpy_wrapper
"""
if not self._numpy: if not self._numpy:
import numpy as np self._numpy = numpy_wrapper(self)
class LammpsNumpyWrapper:
def __init__(self, lmp):
self.lmp = lmp
def _ctype_to_numpy_int(self, ctype_int):
if ctype_int == c_int32:
return np.int32
elif ctype_int == c_int64:
return np.int64
return np.intc
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT):
"""Retrieve per-atom properties from LAMMPS as NumPy arrays
This is a wrapper around the :cpp:func:`lammps_extract_atom`
function of the C-library interface. Its documentation includes a
list of the supported keywords and their data types.
Since Python needs to know the data type to be able to interpret
the result, by default, this function will try to auto-detect the data
type by asking the library. You can also force a specific data type.
For that purpose the :py:mod:`lammps` module contains the constants
``LAMMPS_INT``, ``LAMMPS_INT_2D``, ``LAMMPS_DOUBLE``,
``LAMMPS_DOUBLE_2D``, ``LAMMPS_INT64``, ``LAMMPS_INT64_2D``, and
``LAMMPS_STRING``.
This function returns ``None`` if either the keyword is not
recognized, or an invalid data type constant is used.
.. note::
While the returned arrays of per-atom data are dimensioned
for the range [0:nmax] - as is the underlying storage -
the data is usually only valid for the range of [0:nlocal],
unless the property of interest is also updated for ghost
atoms. In some cases, this depends on a LAMMPS setting, see
for example :doc:`comm_modify vel yes <comm_modify>`.
:param name: name of the property
:type name: string
:param dtype: type of the returned data (see :ref:`py_data_constants`)
:type dtype: int, optional
:param nelem: number of elements in array
:type nelem: int, optional
:param dim: dimension of each element
:type dim: int, optional
:return: requested data as NumPy array with direct access to C data
:rtype: numpy.array
"""
if dtype == LAMMPS_AUTODETECT:
dtype = self.lmp.extract_atom_datatype(name)
if nelem == LAMMPS_AUTODETECT:
if name == "mass":
nelem = self.lmp.extract_global("ntypes") + 1
else:
nelem = self.lmp.extract_global("nlocal")
if dim == LAMMPS_AUTODETECT:
if dtype in (LAMMPS_INT_2D, LAMMPS_DOUBLE_2D, LAMMPS_INT64_2D):
# TODO add other fields
if name in ("x", "v", "f", "angmom", "torque", "csforce", "vforce"):
dim = 3
else:
dim = 2
else:
dim = 1
raw_ptr = self.lmp.extract_atom(name, dtype)
if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE_2D):
return self.darray(raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT, LAMMPS_INT_2D):
return self.iarray(c_int32, raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT64, LAMMPS_INT64_2D):
return self.iarray(c_int64, raw_ptr, nelem, dim)
return raw_ptr
def extract_atom_iarray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if name in ['id', 'molecule']:
c_int_type = self.lmp.c_tagint
elif name in ['image']:
c_int_type = self.lmp.c_imageint
else:
c_int_type = c_int
if dim == 1:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT)
else:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT_2D)
return self.iarray(c_int_type, raw_ptr, nelem, dim)
def extract_atom_darray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if dim == 1:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE)
else:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE_2D)
return self.darray(raw_ptr, nelem, dim)
def extract_compute(self, cid, style, datatype):
value = self.lmp.extract_compute(cid, style, datatype)
if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL):
if datatype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR)
return self.darray(value, nrows)
elif datatype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS)
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
return self.darray(value, nrows, ncols)
elif style == LMP_STYLE_ATOM:
if datatype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
return self.darray(value, nlocal)
elif datatype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
return self.darray(value, nlocal, ncols)
return value
def extract_fix(self, fid, style, datatype, nrow=0, ncol=0):
value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol)
if style == LMP_STYLE_ATOM:
if datatype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
return self.darray(value, nlocal)
elif datatype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nlocal, ncols)
elif style == LMP_STYLE_LOCAL:
if datatype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
return self.darray(value, nrows)
elif datatype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nrows, ncols)
return value
def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL):
value = self.lmp.extract_variable(name, group, datatype)
if datatype == LMP_VAR_ATOM:
return np.ctypeslib.as_array(value)
return value
def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
np_int_type = self._ctype_to_numpy_int(c_int_type)
if dim == 1:
ptr = cast(raw_ptr, POINTER(c_int_type * nelem))
else:
ptr = cast(raw_ptr[0], POINTER(c_int_type * nelem * dim))
a = np.frombuffer(ptr.contents, dtype=np_int_type)
a.shape = (nelem, dim)
return a
def darray(self, raw_ptr, nelem, dim=1):
if dim == 1:
ptr = cast(raw_ptr, POINTER(c_double * nelem))
else:
ptr = cast(raw_ptr[0], POINTER(c_double * nelem * dim))
a = np.frombuffer(ptr.contents)
a.shape = (nelem, dim)
return a
self._numpy = LammpsNumpyWrapper(self)
return self._numpy return self._numpy
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -1825,6 +1661,198 @@ class lammps(object):
neighbors = self.numpy.iarray(c_int, c_neighbors, c_numneigh.value, 1) neighbors = self.numpy.iarray(c_int, c_neighbors, c_numneigh.value, 1)
return c_iatom.value, c_numneigh.value, neighbors return c_iatom.value, c_numneigh.value, neighbors
# -------------------------------------------------------------------------
class numpy_wrapper:
"""lammps API NumPy Wrapper
This is a wrapper class that provides additional methods on top of an
existing :py:class:`lammps` instance. The methods transform raw ctypes
pointers into NumPy arrays, which give direct access to the
original data while protecting against out-of-bounds accesses.
There is no need to explicitly instantiate this class. Each instance
of :py:class:`lammps` has a :py:attr:`numpy <lammps.numpy>` property
that returns an instance.
:param lmp: instance of the :py:class:`lammps` class
:type lmp: lammps
"""
def __init__(self, lmp):
self.lmp = lmp
def _ctype_to_numpy_int(self, ctype_int):
import numpy as np
if ctype_int == c_int32:
return np.int32
elif ctype_int == c_int64:
return np.int64
return np.intc
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT):
"""Retrieve per-atom properties from LAMMPS as NumPy arrays
This is a wrapper around the :cpp:func:`lammps_extract_atom`
function of the C-library interface. Its documentation includes a
list of the supported keywords and their data types.
Since Python needs to know the data type to be able to interpret
the result, by default, this function will try to auto-detect the data
type by asking the library. You can also force a specific data type.
For that purpose the :py:mod:`lammps` module contains the constants
``LAMMPS_INT``, ``LAMMPS_INT_2D``, ``LAMMPS_DOUBLE``,
``LAMMPS_DOUBLE_2D``, ``LAMMPS_INT64``, ``LAMMPS_INT64_2D``, and
``LAMMPS_STRING``.
This function returns ``None`` if either the keyword is not
recognized, or an invalid data type constant is used.
.. note::
While the returned arrays of per-atom data are dimensioned
for the range [0:nmax] - as is the underlying storage -
the data is usually only valid for the range of [0:nlocal],
unless the property of interest is also updated for ghost
atoms. In some cases, this depends on a LAMMPS setting, see
for example :doc:`comm_modify vel yes <comm_modify>`.
:param name: name of the property
:type name: string
:param dtype: type of the returned data (see :ref:`py_data_constants`)
:type dtype: int, optional
:param nelem: number of elements in array
:type nelem: int, optional
:param dim: dimension of each element
:type dim: int, optional
:return: requested data as NumPy array with direct access to C data
:rtype: numpy.array
"""
if dtype == LAMMPS_AUTODETECT:
dtype = self.lmp.extract_atom_datatype(name)
if nelem == LAMMPS_AUTODETECT:
if name == "mass":
nelem = self.lmp.extract_global("ntypes") + 1
else:
nelem = self.lmp.extract_global("nlocal")
if dim == LAMMPS_AUTODETECT:
if dtype in (LAMMPS_INT_2D, LAMMPS_DOUBLE_2D, LAMMPS_INT64_2D):
# TODO add other fields
if name in ("x", "v", "f", "angmom", "torque", "csforce", "vforce"):
dim = 3
else:
dim = 2
else:
dim = 1
raw_ptr = self.lmp.extract_atom(name, dtype)
if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE_2D):
return self.darray(raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT, LAMMPS_INT_2D):
return self.iarray(c_int32, raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT64, LAMMPS_INT64_2D):
return self.iarray(c_int64, raw_ptr, nelem, dim)
return raw_ptr
def extract_atom_iarray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if name in ['id', 'molecule']:
c_int_type = self.lmp.c_tagint
elif name in ['image']:
c_int_type = self.lmp.c_imageint
else:
c_int_type = c_int
if dim == 1:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT)
else:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT_2D)
return self.iarray(c_int_type, raw_ptr, nelem, dim)
def extract_atom_darray(self, name, nelem, dim=1):
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
if dim == 1:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE)
else:
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE_2D)
return self.darray(raw_ptr, nelem, dim)
def extract_compute(self, cid, style, datatype):
value = self.lmp.extract_compute(cid, style, datatype)
if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL):
if datatype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR)
return self.darray(value, nrows)
elif datatype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS)
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
return self.darray(value, nrows, ncols)
elif style == LMP_STYLE_ATOM:
if datatype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
return self.darray(value, nlocal)
elif datatype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
return self.darray(value, nlocal, ncols)
return value
def extract_fix(self, fid, style, datatype, nrow=0, ncol=0):
value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol)
if style == LMP_STYLE_ATOM:
if datatype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
return self.darray(value, nlocal)
elif datatype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nlocal, ncols)
elif style == LMP_STYLE_LOCAL:
if datatype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
return self.darray(value, nrows)
elif datatype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nrows, ncols)
return value
def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL):
import numpy as np
value = self.lmp.extract_variable(name, group, datatype)
if datatype == LMP_VAR_ATOM:
return np.ctypeslib.as_array(value)
return value
def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
import numpy as np
np_int_type = self._ctype_to_numpy_int(c_int_type)
if dim == 1:
ptr = cast(raw_ptr, POINTER(c_int_type * nelem))
else:
ptr = cast(raw_ptr[0], POINTER(c_int_type * nelem * dim))
a = np.frombuffer(ptr.contents, dtype=np_int_type)
a.shape = (nelem, dim)
return a
def darray(self, raw_ptr, nelem, dim=1):
import numpy as np
if dim == 1:
ptr = cast(raw_ptr, POINTER(c_double * nelem))
else:
ptr = cast(raw_ptr[0], POINTER(c_double * nelem * dim))
a = np.frombuffer(ptr.contents)
a.shape = (nelem, dim)
return a
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------