Refactor LammpsNumpyWrapper to numpy_wrapper
LammpsNumpyWrapper was a class that was defined inside of the lammps.numpy property when it was first accessed. This made it hard to document the methods of this class. This commit extracts this utility class into the lammps module and renames it to 'numpy_wrapper'.
This commit is contained in:
@ -44,6 +44,9 @@ functions. Below is a detailed documentation of the API.
|
||||
.. autoclass:: lammps.lammps
|
||||
:members:
|
||||
|
||||
.. autoclass:: lammps.numpy_wrapper
|
||||
:members:
|
||||
|
||||
----------
|
||||
|
||||
The ``PyLammps`` class API
|
||||
|
||||
376
python/lammps.py
376
python/lammps.py
@ -465,181 +465,17 @@ class lammps(object):
|
||||
|
||||
@property
|
||||
def numpy(self):
|
||||
"Convert between ctypes arrays and numpy arrays"
|
||||
""" Return object to access numpy versions of API
|
||||
|
||||
It provides alternative implementations of API functions that
|
||||
return numpy arrays instead of ctypes pointers. If numpy is not installed,
|
||||
accessing this property will lead to an ImportError.
|
||||
|
||||
:return: instance of numpy wrapper object
|
||||
:rtype: numpy_wrapper
|
||||
"""
|
||||
if not self._numpy:
|
||||
import numpy as np
|
||||
class LammpsNumpyWrapper:
|
||||
def __init__(self, lmp):
|
||||
self.lmp = lmp
|
||||
|
||||
def _ctype_to_numpy_int(self, ctype_int):
|
||||
if ctype_int == c_int32:
|
||||
return np.int32
|
||||
elif ctype_int == c_int64:
|
||||
return np.int64
|
||||
return np.intc
|
||||
|
||||
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT):
|
||||
"""Retrieve per-atom properties from LAMMPS as NumPy arrays
|
||||
|
||||
This is a wrapper around the :cpp:func:`lammps_extract_atom`
|
||||
function of the C-library interface. Its documentation includes a
|
||||
list of the supported keywords and their data types.
|
||||
Since Python needs to know the data type to be able to interpret
|
||||
the result, by default, this function will try to auto-detect the data
|
||||
type by asking the library. You can also force a specific data type.
|
||||
For that purpose the :py:mod:`lammps` module contains the constants
|
||||
``LAMMPS_INT``, ``LAMMPS_INT_2D``, ``LAMMPS_DOUBLE``,
|
||||
``LAMMPS_DOUBLE_2D``, ``LAMMPS_INT64``, ``LAMMPS_INT64_2D``, and
|
||||
``LAMMPS_STRING``.
|
||||
This function returns ``None`` if either the keyword is not
|
||||
recognized, or an invalid data type constant is used.
|
||||
|
||||
.. note::
|
||||
|
||||
While the returned arrays of per-atom data are dimensioned
|
||||
for the range [0:nmax] - as is the underlying storage -
|
||||
the data is usually only valid for the range of [0:nlocal],
|
||||
unless the property of interest is also updated for ghost
|
||||
atoms. In some cases, this depends on a LAMMPS setting, see
|
||||
for example :doc:`comm_modify vel yes <comm_modify>`.
|
||||
|
||||
:param name: name of the property
|
||||
:type name: string
|
||||
:param dtype: type of the returned data (see :ref:`py_data_constants`)
|
||||
:type dtype: int, optional
|
||||
:param nelem: number of elements in array
|
||||
:type nelem: int, optional
|
||||
:param dim: dimension of each element
|
||||
:type dim: int, optional
|
||||
:return: requested data as NumPy array with direct access to C data
|
||||
:rtype: numpy.array
|
||||
"""
|
||||
if dtype == LAMMPS_AUTODETECT:
|
||||
dtype = self.lmp.extract_atom_datatype(name)
|
||||
|
||||
if nelem == LAMMPS_AUTODETECT:
|
||||
if name == "mass":
|
||||
nelem = self.lmp.extract_global("ntypes") + 1
|
||||
else:
|
||||
nelem = self.lmp.extract_global("nlocal")
|
||||
if dim == LAMMPS_AUTODETECT:
|
||||
if dtype in (LAMMPS_INT_2D, LAMMPS_DOUBLE_2D, LAMMPS_INT64_2D):
|
||||
# TODO add other fields
|
||||
if name in ("x", "v", "f", "angmom", "torque", "csforce", "vforce"):
|
||||
dim = 3
|
||||
else:
|
||||
dim = 2
|
||||
else:
|
||||
dim = 1
|
||||
|
||||
raw_ptr = self.lmp.extract_atom(name, dtype)
|
||||
|
||||
if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE_2D):
|
||||
return self.darray(raw_ptr, nelem, dim)
|
||||
elif dtype in (LAMMPS_INT, LAMMPS_INT_2D):
|
||||
return self.iarray(c_int32, raw_ptr, nelem, dim)
|
||||
elif dtype in (LAMMPS_INT64, LAMMPS_INT64_2D):
|
||||
return self.iarray(c_int64, raw_ptr, nelem, dim)
|
||||
return raw_ptr
|
||||
|
||||
def extract_atom_iarray(self, name, nelem, dim=1):
|
||||
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
|
||||
|
||||
if name in ['id', 'molecule']:
|
||||
c_int_type = self.lmp.c_tagint
|
||||
elif name in ['image']:
|
||||
c_int_type = self.lmp.c_imageint
|
||||
else:
|
||||
c_int_type = c_int
|
||||
|
||||
if dim == 1:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT)
|
||||
else:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT_2D)
|
||||
|
||||
return self.iarray(c_int_type, raw_ptr, nelem, dim)
|
||||
|
||||
def extract_atom_darray(self, name, nelem, dim=1):
|
||||
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
|
||||
|
||||
if dim == 1:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE)
|
||||
else:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE_2D)
|
||||
|
||||
return self.darray(raw_ptr, nelem, dim)
|
||||
|
||||
def extract_compute(self, cid, style, datatype):
|
||||
value = self.lmp.extract_compute(cid, style, datatype)
|
||||
|
||||
if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL):
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR)
|
||||
return self.darray(value, nrows)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS)
|
||||
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
|
||||
return self.darray(value, nrows, ncols)
|
||||
elif style == LMP_STYLE_ATOM:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
return self.darray(value, nlocal)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
|
||||
return self.darray(value, nlocal, ncols)
|
||||
return value
|
||||
|
||||
def extract_fix(self, fid, style, datatype, nrow=0, ncol=0):
|
||||
value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol)
|
||||
if style == LMP_STYLE_ATOM:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
return self.darray(value, nlocal)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
|
||||
return self.darray(value, nlocal, ncols)
|
||||
elif style == LMP_STYLE_LOCAL:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
|
||||
return self.darray(value, nrows)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
|
||||
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
|
||||
return self.darray(value, nrows, ncols)
|
||||
return value
|
||||
|
||||
def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL):
|
||||
value = self.lmp.extract_variable(name, group, datatype)
|
||||
if datatype == LMP_VAR_ATOM:
|
||||
return np.ctypeslib.as_array(value)
|
||||
return value
|
||||
|
||||
def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
|
||||
np_int_type = self._ctype_to_numpy_int(c_int_type)
|
||||
|
||||
if dim == 1:
|
||||
ptr = cast(raw_ptr, POINTER(c_int_type * nelem))
|
||||
else:
|
||||
ptr = cast(raw_ptr[0], POINTER(c_int_type * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents, dtype=np_int_type)
|
||||
a.shape = (nelem, dim)
|
||||
return a
|
||||
|
||||
def darray(self, raw_ptr, nelem, dim=1):
|
||||
if dim == 1:
|
||||
ptr = cast(raw_ptr, POINTER(c_double * nelem))
|
||||
else:
|
||||
ptr = cast(raw_ptr[0], POINTER(c_double * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents)
|
||||
a.shape = (nelem, dim)
|
||||
return a
|
||||
|
||||
self._numpy = LammpsNumpyWrapper(self)
|
||||
self._numpy = numpy_wrapper(self)
|
||||
return self._numpy
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
@ -1825,6 +1661,198 @@ class lammps(object):
|
||||
neighbors = self.numpy.iarray(c_int, c_neighbors, c_numneigh.value, 1)
|
||||
return c_iatom.value, c_numneigh.value, neighbors
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
class numpy_wrapper:
|
||||
"""lammps API NumPy Wrapper
|
||||
|
||||
This is a wrapper class that provides additional methods on top of an
|
||||
existing :py:class:`lammps` instance. The methods transform raw ctypes
|
||||
pointers into NumPy arrays, which give direct access to the
|
||||
original data while protecting against out-of-bounds accesses.
|
||||
|
||||
There is no need to explicitly instantiate this class. Each instance
|
||||
of :py:class:`lammps` has a :py:attr:`numpy <lammps.numpy>` property
|
||||
that returns an instance.
|
||||
|
||||
:param lmp: instance of the :py:class:`lammps` class
|
||||
:type lmp: lammps
|
||||
"""
|
||||
def __init__(self, lmp):
|
||||
self.lmp = lmp
|
||||
|
||||
def _ctype_to_numpy_int(self, ctype_int):
|
||||
import numpy as np
|
||||
if ctype_int == c_int32:
|
||||
return np.int32
|
||||
elif ctype_int == c_int64:
|
||||
return np.int64
|
||||
return np.intc
|
||||
|
||||
def extract_atom(self, name, dtype=LAMMPS_AUTODETECT, nelem=LAMMPS_AUTODETECT, dim=LAMMPS_AUTODETECT):
|
||||
"""Retrieve per-atom properties from LAMMPS as NumPy arrays
|
||||
|
||||
This is a wrapper around the :cpp:func:`lammps_extract_atom`
|
||||
function of the C-library interface. Its documentation includes a
|
||||
list of the supported keywords and their data types.
|
||||
Since Python needs to know the data type to be able to interpret
|
||||
the result, by default, this function will try to auto-detect the data
|
||||
type by asking the library. You can also force a specific data type.
|
||||
For that purpose the :py:mod:`lammps` module contains the constants
|
||||
``LAMMPS_INT``, ``LAMMPS_INT_2D``, ``LAMMPS_DOUBLE``,
|
||||
``LAMMPS_DOUBLE_2D``, ``LAMMPS_INT64``, ``LAMMPS_INT64_2D``, and
|
||||
``LAMMPS_STRING``.
|
||||
This function returns ``None`` if either the keyword is not
|
||||
recognized, or an invalid data type constant is used.
|
||||
|
||||
.. note::
|
||||
|
||||
While the returned arrays of per-atom data are dimensioned
|
||||
for the range [0:nmax] - as is the underlying storage -
|
||||
the data is usually only valid for the range of [0:nlocal],
|
||||
unless the property of interest is also updated for ghost
|
||||
atoms. In some cases, this depends on a LAMMPS setting, see
|
||||
for example :doc:`comm_modify vel yes <comm_modify>`.
|
||||
|
||||
:param name: name of the property
|
||||
:type name: string
|
||||
:param dtype: type of the returned data (see :ref:`py_data_constants`)
|
||||
:type dtype: int, optional
|
||||
:param nelem: number of elements in array
|
||||
:type nelem: int, optional
|
||||
:param dim: dimension of each element
|
||||
:type dim: int, optional
|
||||
:return: requested data as NumPy array with direct access to C data
|
||||
:rtype: numpy.array
|
||||
"""
|
||||
if dtype == LAMMPS_AUTODETECT:
|
||||
dtype = self.lmp.extract_atom_datatype(name)
|
||||
|
||||
if nelem == LAMMPS_AUTODETECT:
|
||||
if name == "mass":
|
||||
nelem = self.lmp.extract_global("ntypes") + 1
|
||||
else:
|
||||
nelem = self.lmp.extract_global("nlocal")
|
||||
if dim == LAMMPS_AUTODETECT:
|
||||
if dtype in (LAMMPS_INT_2D, LAMMPS_DOUBLE_2D, LAMMPS_INT64_2D):
|
||||
# TODO add other fields
|
||||
if name in ("x", "v", "f", "angmom", "torque", "csforce", "vforce"):
|
||||
dim = 3
|
||||
else:
|
||||
dim = 2
|
||||
else:
|
||||
dim = 1
|
||||
|
||||
raw_ptr = self.lmp.extract_atom(name, dtype)
|
||||
|
||||
if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE_2D):
|
||||
return self.darray(raw_ptr, nelem, dim)
|
||||
elif dtype in (LAMMPS_INT, LAMMPS_INT_2D):
|
||||
return self.iarray(c_int32, raw_ptr, nelem, dim)
|
||||
elif dtype in (LAMMPS_INT64, LAMMPS_INT64_2D):
|
||||
return self.iarray(c_int64, raw_ptr, nelem, dim)
|
||||
return raw_ptr
|
||||
|
||||
def extract_atom_iarray(self, name, nelem, dim=1):
|
||||
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
|
||||
|
||||
if name in ['id', 'molecule']:
|
||||
c_int_type = self.lmp.c_tagint
|
||||
elif name in ['image']:
|
||||
c_int_type = self.lmp.c_imageint
|
||||
else:
|
||||
c_int_type = c_int
|
||||
|
||||
if dim == 1:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT)
|
||||
else:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_INT_2D)
|
||||
|
||||
return self.iarray(c_int_type, raw_ptr, nelem, dim)
|
||||
|
||||
def extract_atom_darray(self, name, nelem, dim=1):
|
||||
warnings.warn("deprecated, use extract_atom instead", DeprecationWarning)
|
||||
|
||||
if dim == 1:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE)
|
||||
else:
|
||||
raw_ptr = self.lmp.extract_atom(name, LAMMPS_DOUBLE_2D)
|
||||
|
||||
return self.darray(raw_ptr, nelem, dim)
|
||||
|
||||
def extract_compute(self, cid, style, datatype):
|
||||
value = self.lmp.extract_compute(cid, style, datatype)
|
||||
|
||||
if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL):
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR)
|
||||
return self.darray(value, nrows)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS)
|
||||
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
|
||||
return self.darray(value, nrows, ncols)
|
||||
elif style == LMP_STYLE_ATOM:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
return self.darray(value, nlocal)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS)
|
||||
return self.darray(value, nlocal, ncols)
|
||||
return value
|
||||
|
||||
def extract_fix(self, fid, style, datatype, nrow=0, ncol=0):
|
||||
value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol)
|
||||
if style == LMP_STYLE_ATOM:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
return self.darray(value, nlocal)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT)
|
||||
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
|
||||
return self.darray(value, nlocal, ncols)
|
||||
elif style == LMP_STYLE_LOCAL:
|
||||
if datatype == LMP_TYPE_VECTOR:
|
||||
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
|
||||
return self.darray(value, nrows)
|
||||
elif datatype == LMP_TYPE_ARRAY:
|
||||
nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0)
|
||||
ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0)
|
||||
return self.darray(value, nrows, ncols)
|
||||
return value
|
||||
|
||||
def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL):
|
||||
import numpy as np
|
||||
value = self.lmp.extract_variable(name, group, datatype)
|
||||
if datatype == LMP_VAR_ATOM:
|
||||
return np.ctypeslib.as_array(value)
|
||||
return value
|
||||
|
||||
def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
|
||||
import numpy as np
|
||||
np_int_type = self._ctype_to_numpy_int(c_int_type)
|
||||
|
||||
if dim == 1:
|
||||
ptr = cast(raw_ptr, POINTER(c_int_type * nelem))
|
||||
else:
|
||||
ptr = cast(raw_ptr[0], POINTER(c_int_type * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents, dtype=np_int_type)
|
||||
a.shape = (nelem, dim)
|
||||
return a
|
||||
|
||||
def darray(self, raw_ptr, nelem, dim=1):
|
||||
import numpy as np
|
||||
if dim == 1:
|
||||
ptr = cast(raw_ptr, POINTER(c_double * nelem))
|
||||
else:
|
||||
ptr = cast(raw_ptr[0], POINTER(c_double * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents)
|
||||
a.shape = (nelem, dim)
|
||||
return a
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user