apply more pylint recommendations

This commit is contained in:
Axel Kohlmeyer
2025-06-21 23:22:34 -04:00
parent 679806206d
commit 9b382dac41
6 changed files with 207 additions and 205 deletions

View File

@ -264,7 +264,7 @@ max-bool-expr = 5
max-branches = 50 max-branches = 50
# Maximum number of locals for function / method body. # Maximum number of locals for function / method body.
max-locals = 20 max-locals = 25
# Maximum number of parents for a class (see R0901). # Maximum number of parents for a class (see R0901).
max-parents = 7 max-parents = 7
@ -282,7 +282,7 @@ max-returns = 15
max-statements = 500 max-statements = 500
# Minimum number of public methods for a class (see R0903). # Minimum number of public methods for a class (see R0903).
min-public-methods = 2 min-public-methods = 0
[tool.pylint.exceptions] [tool.pylint.exceptions]
# Exceptions that will emit a warning when caught. # Exceptions that will emit a warning when caught.

View File

@ -11,8 +11,10 @@
# See the README file in the top-level LAMMPS directory. # See the README file in the top-level LAMMPS directory.
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# various symbolic constants to be used """
# in certain calls to select data formats various symbolic constants to be used
in certain calls to select data formats
"""
# these must be kept in sync with the enums in src/library.h, src/lmptype.h, # these must be kept in sync with the enums in src/library.h, src/lmptype.h,
# tools/swig/lammps.i, examples/COUPLE/plugin/liblammpsplugin.h, # tools/swig/lammps.i, examples/COUPLE/plugin/liblammpsplugin.h,
@ -55,9 +57,11 @@ LMP_BUFSIZE = 1024
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def get_ctypes_int(size): def get_ctypes_int(size):
"""return ctypes type matching the configured C/C++ integer size in LAMMPS"""
# pylint: disable=C0415
from ctypes import c_int, c_int32, c_int64 from ctypes import c_int, c_int32, c_int64
if size == 4: if size == 4:
return c_int32 return c_int32
elif size == 8: if size == 8:
return c_int64 return c_int64
return c_int return c_int

View File

@ -10,7 +10,9 @@
# #
# See the README file in the top-level LAMMPS directory. # See the README file in the top-level LAMMPS directory.
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# Python wrapper for the LAMMPS library via ctypes """
Python module wrapping the LAMMPS library via ctypes
"""
# avoid pylint warnings about naming conventions # avoid pylint warnings about naming conventions
# pylint: disable=C0103 # pylint: disable=C0103

View File

@ -11,82 +11,83 @@
# See the README file in the top-level LAMMPS directory. # See the README file in the top-level LAMMPS directory.
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
################################################################################ """
# LAMMPS data structures Data structures for LAMMPS Python module
# Written by Richard Berger <richard.berger@temple.edu> Written by Richard Berger <richard.berger@temple.edu>
################################################################################ """
class NeighList(object): class NeighList:
"""This is a wrapper class that exposes the contents of a neighbor list. """This is a wrapper class that exposes the contents of a neighbor list.
It can be used like a regular Python list. Each element is a tuple of: It can be used like a regular Python list. Each element is a tuple of:
* the atom local index * the atom local index
* its number of neighbors * its number of neighbors
* and a pointer to an c_int array containing local atom indices of its * and a pointer to an c_int array containing local atom indices of its
neighbors neighbors
Internally it uses the lower-level LAMMPS C-library interface. Internally it uses the lower-level LAMMPS C-library interface.
:param lmp: reference to instance of :py:class:`lammps` :param lmp: reference to instance of :py:class:`lammps`
:type lmp: lammps :type lmp: lammps
:param idx: neighbor list index :param idx: neighbor list index
:type idx: int :type idx: int
"""
def __init__(self, lmp, idx):
self.lmp = lmp
self.idx = idx
def __str__(self):
# pylint: disable=C0209
return "Neighbor List ({} atoms)".format(self.size)
def __repr__(self):
return self.__str__()
@property
def size(self):
""" """
def __init__(self, lmp, idx): :return: number of elements in neighbor list
self.lmp = lmp """
self.idx = idx return self.lmp.get_neighlist_size(self.idx)
def __str__(self): def get(self, element):
return "Neighbor List ({} atoms)".format(self.size) """
Access a specific neighbor list entry. "element" must be a number from 0 to the size-1 of the list
def __repr__(self): :return: tuple with atom local index, number of neighbors and ctypes pointer to neighbor's local atom indices
return self.__str__() :rtype: (int, int, ctypes.POINTER(c_int))
"""
iatom, numneigh, neighbors = self.lmp.get_neighlist_element_neighbors(self.idx, element)
return iatom, numneigh, neighbors
@property # the methods below implement the iterator interface, so NeighList can be used like a regular Python list
def size(self):
"""
:return: number of elements in neighbor list
"""
return self.lmp.get_neighlist_size(self.idx)
def get(self, element): def __getitem__(self, element):
""" return self.get(element)
Access a specific neighbor list entry. "element" must be a number from 0 to the size-1 of the list
:return: tuple with atom local index, number of neighbors and ctypes pointer to neighbor's local atom indices def __len__(self):
:rtype: (int, int, ctypes.POINTER(c_int)) return self.size
"""
iatom, numneigh, neighbors = self.lmp.get_neighlist_element_neighbors(self.idx, element)
return iatom, numneigh, neighbors
# the methods below implement the iterator interface, so NeighList can be used like a regular Python list def __iter__(self):
inum = self.size
def __getitem__(self, element): for ii in range(inum):
return self.get(element) yield self.get(ii)
def __len__(self): def find(self, iatom):
return self.size """
Find the neighbor list for a specific (local) atom iatom.
If there is no list for iatom, (-1, None) is returned.
def __iter__(self): :return: tuple with number of neighbors and ctypes pointer to neighbor's local atom indices
inum = self.size :rtype: (int, ctypes.POINTER(c_int))
"""
for ii in range(inum): inum = self.size
yield self.get(ii) for ii in range(inum):
idx, numneigh, neighbors = self.get(ii)
if idx == iatom:
return numneigh, neighbors
def find(self, iatom): return -1, None
"""
Find the neighbor list for a specific (local) atom iatom.
If there is no list for iatom, (-1, None) is returned.
:return: tuple with number of neighbors and ctypes pointer to neighbor's local atom indices
:rtype: (int, ctypes.POINTER(c_int))
"""
inum = self.size
for ii in range(inum):
idx, numneigh, neighbors = self.get(ii)
if idx == iatom:
return numneigh, neighbors
return -1, None

View File

@ -11,14 +11,15 @@
# See the README file in the top-level LAMMPS directory. # See the README file in the top-level LAMMPS directory.
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
################################################################################ """
# LAMMPS output formats Output formats for LAMMPS python module
# Written by Richard Berger <richard.berger@temple.edu> Written by Richard Berger <richard.berger@temple.edu>
# and Axel Kohlmeyer <akohlmey@gmail.com> and Axel Kohlmeyer <akohlmey@gmail.com>
################################################################################ """
import re import re
# pylint: disable=C0103
has_yaml = False has_yaml = False
try: try:
import yaml import yaml
@ -32,6 +33,7 @@ except ImportError:
pass pass
class LogFile: class LogFile:
# pylint: disable=R0903
"""Reads LAMMPS log files and extracts the thermo information """Reads LAMMPS log files and extracts the thermo information
It supports the line, multi, and yaml thermo output styles. It supports the line, multi, and yaml thermo output styles.
@ -55,73 +57,73 @@ class LogFile:
yamllog = "" yamllog = ""
self.runs = [] self.runs = []
self.errors = [] self.errors = []
with open(filename, 'rt') as f: with open(filename, 'rt', encoding='utf-8') as f:
in_thermo = False in_thermo = False
in_data_section = False in_data_section = False
for line in f: for line in f:
if "ERROR" in line or "exited on signal" in line: if "ERROR" in line or "exited on signal" in line:
self.errors.append(line) self.errors.append(line)
elif re.match(r'^ *Step ', line): elif re.match(r'^ *Step ', line):
in_thermo = True in_thermo = True
in_data_section = True in_data_section = True
keys = line.split() keys = line.split()
current_run = {} current_run = {}
for k in keys: for k in keys:
current_run[k] = [] current_run[k] = []
elif re.match(r'^(keywords:.*$|data:$|---$| - \[.*\]$)', line): elif re.match(r'^(keywords:.*$|data:$|---$| - \[.*\]$)', line):
if not has_yaml: if not has_yaml:
raise Exception('Cannot process YAML format logs without the PyYAML Python module') raise RuntimeError('Cannot process YAML format logs without the PyYAML Python module')
style = LogFile.STYLE_YAML style = LogFile.STYLE_YAML
yamllog += line; yamllog += line
current_run = {} current_run = {}
elif re.match(r'^\.\.\.$', line): elif re.match(r'^\.\.\.$', line):
thermo = yaml.load(yamllog, Loader=Loader) thermo = yaml.load(yamllog, Loader=Loader)
for k in thermo['keywords']: for k in thermo['keywords']:
current_run[k] = [] current_run[k] = []
for step in thermo['data']: for step in thermo['data']:
icol = 0 icol = 0
for k in thermo['keywords']: for k in thermo['keywords']:
current_run[k].append(step[icol]) current_run[k].append(step[icol])
icol += 1 icol += 1
self.runs.append(current_run) self.runs.append(current_run)
yamllog = "" yamllog = ""
elif re.match(r'^------* Step ', line): elif re.match(r'^------* Step ', line):
if not in_thermo: if not in_thermo:
current_run = {'Step': [], 'CPU': []} current_run = {'Step': [], 'CPU': []}
in_thermo = True in_thermo = True
in_data_section = True in_data_section = True
style = LogFile.STYLE_MULTI style = LogFile.STYLE_MULTI
str_step, str_cpu = line.strip('-\n').split('-----') str_step, str_cpu = line.strip('-\n').split('-----')
step = float(str_step.split()[1]) step = float(str_step.split()[1])
cpu = float(str_cpu.split('=')[1].split()[0]) cpu = float(str_cpu.split('=')[1].split()[0])
current_run["Step"].append(step) current_run["Step"].append(step)
current_run["CPU"].append(cpu) current_run["CPU"].append(cpu)
elif line.startswith('Loop time of'): elif line.startswith('Loop time of'):
in_thermo = False in_thermo = False
if style != LogFile.STYLE_YAML: if style != LogFile.STYLE_YAML:
self.runs.append(current_run) self.runs.append(current_run)
elif in_thermo and in_data_section: elif in_thermo and in_data_section:
if style == LogFile.STYLE_DEFAULT: if style == LogFile.STYLE_DEFAULT:
if alpha.search(line): if alpha.search(line):
continue continue
for k, v in zip(keys, map(float, line.split())): for k, v in zip(keys, map(float, line.split())):
current_run[k].append(v) current_run[k].append(v)
elif style == LogFile.STYLE_MULTI: elif style == LogFile.STYLE_MULTI:
if '=' not in line: if '=' not in line:
in_data_section = False in_data_section = False
continue continue
for k,v in kvpairs.findall(line): for k,v in kvpairs.findall(line):
if k not in current_run: if k not in current_run:
current_run[k] = [float(v)] current_run[k] = [float(v)]
else: else:
current_run[k].append(float(v)) current_run[k].append(float(v))
class AvgChunkFile: class AvgChunkFile:
"""Reads files generated by fix ave/chunk """Reads files generated by fix ave/chunk
@ -134,9 +136,13 @@ class AvgChunkFile:
:ivar chunks: List of chunks. Each chunk is a dictionary containing its ID, the coordinates, and the averaged quantities :ivar chunks: List of chunks. Each chunk is a dictionary containing its ID, the coordinates, and the averaged quantities
""" """
def __init__(self, filename): def __init__(self, filename):
with open(filename, 'rt') as f: with open(filename, 'rt', encoding='utf-8') as f:
timestep = None timestep = None
chunks_read = 0 chunks_read = 0
compress = False
coord_start = None
coord_end = None
data_start = None
self.timesteps = [] self.timesteps = []
self.total_count = [] self.total_count = []
@ -145,24 +151,24 @@ class AvgChunkFile:
for lineno, line in enumerate(f): for lineno, line in enumerate(f):
if lineno == 0: if lineno == 0:
if not line.startswith("# Chunk-averaged data for fix"): if not line.startswith("# Chunk-averaged data for fix"):
raise Exception("Chunk data reader only supports default avg/chunk headers!") raise RuntimeError("Chunk data reader only supports default avg/chunk headers!")
parts = line.split() parts = line.split()
self.fix_name = parts[5] self.fix_name = parts[5]
self.group_name = parts[8] self.group_name = parts[8]
continue continue
elif lineno == 1: if lineno == 1:
if not line.startswith("# Timestep Number-of-chunks Total-count"): if not line.startswith("# Timestep Number-of-chunks Total-count"):
raise Exception("Chunk data reader only supports default avg/chunk headers!") raise RuntimeError("Chunk data reader only supports default avg/chunk headers!")
continue continue
elif lineno == 2: if lineno == 2:
if not line.startswith("#"): if not line.startswith("#"):
raise Exception("Chunk data reader only supports default avg/chunk headers!") raise RuntimeError("Chunk data reader only supports default avg/chunk headers!")
columns = line.split()[1:] columns = line.split()[1:]
ndim = line.count("Coord") ndim = line.count("Coord")
compress = 'OrigID' in line compress = 'OrigID' in line
if ndim > 0: if ndim > 0:
coord_start = columns.index("Coord1") coord_start = columns.index("Coord1")
coord_end = columns.index("Coord%d" % ndim) coord_end = columns.index(f"Coord{ndim}")
ncount_start = coord_end + 1 ncount_start = coord_end + 1
data_start = ncount_start + 1 data_start = ncount_start + 1
else: else:
@ -216,8 +222,8 @@ class AvgChunkFile:
assert chunk == chunks_read assert chunk == chunks_read
else: else:
# do not support changing number of chunks # do not support changing number of chunks
if not (num_chunks == int(parts[1])): if not num_chunks == int(parts[1]):
raise Exception("Currently, changing numbers of chunks are not supported.") raise RuntimeError("Currently, changing numbers of chunks are not supported.")
timestep = int(parts[0]) timestep = int(parts[0])
total_count = float(parts[2]) total_count = float(parts[2])

View File

@ -11,12 +11,13 @@
# See the README file in the top-level LAMMPS directory. # See the README file in the top-level LAMMPS directory.
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
################################################################################ """
# NumPy additions NumPy additions to the LAMMPS Python module
# Written by Richard Berger <richard.berger@temple.edu> Written by Richard Berger <richard.berger@temple.edu>
################################################################################ """
from ctypes import POINTER, c_void_p, c_char_p, c_double, c_int, c_int32, c_int64, cast from ctypes import POINTER, c_void_p, c_char_p, c_double, c_int, c_int32, c_int64, cast
import numpy as np
from .constants import LAMMPS_AUTODETECT, LAMMPS_INT, LAMMPS_INT_2D, LAMMPS_DOUBLE, \ from .constants import LAMMPS_AUTODETECT, LAMMPS_INT, LAMMPS_INT_2D, LAMMPS_DOUBLE, \
LAMMPS_DOUBLE_2D, LAMMPS_INT64, LAMMPS_INT64_2D, LMP_STYLE_GLOBAL, LMP_STYLE_ATOM, \ LAMMPS_DOUBLE_2D, LAMMPS_INT64, LAMMPS_INT64_2D, LMP_STYLE_GLOBAL, LMP_STYLE_ATOM, \
@ -26,6 +27,7 @@ from .constants import LAMMPS_AUTODETECT, LAMMPS_INT, LAMMPS_INT_2D, LAMMPS_DOU
from .data import NeighList from .data import NeighList
class numpy_wrapper: class numpy_wrapper:
# pylint: disable=C0103
"""lammps API NumPy Wrapper """lammps API NumPy Wrapper
This is a wrapper class that provides additional methods on top of an This is a wrapper class that provides additional methods on top of an
@ -46,10 +48,9 @@ class numpy_wrapper:
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def _ctype_to_numpy_int(self, ctype_int): def _ctype_to_numpy_int(self, ctype_int):
import numpy as np
if ctype_int == c_int32: if ctype_int == c_int32:
return np.int32 return np.int32
elif ctype_int == c_int64: if ctype_int == c_int64:
return np.int64 return np.int64
return np.intc return np.intc
@ -102,9 +103,9 @@ class numpy_wrapper:
if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE_2D): if dtype in (LAMMPS_DOUBLE, LAMMPS_DOUBLE_2D):
return self.darray(raw_ptr, nelem, dim) return self.darray(raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT, LAMMPS_INT_2D): if dtype in (LAMMPS_INT, LAMMPS_INT_2D):
return self.iarray(c_int32, raw_ptr, nelem, dim) return self.iarray(c_int32, raw_ptr, nelem, dim)
elif dtype in (LAMMPS_INT64, LAMMPS_INT64_2D): if dtype in (LAMMPS_INT64, LAMMPS_INT64_2D):
return self.iarray(c_int64, raw_ptr, nelem, dim) return self.iarray(c_int64, raw_ptr, nelem, dim)
return raw_ptr return raw_ptr
@ -133,7 +134,7 @@ class numpy_wrapper:
if ctype == LMP_TYPE_VECTOR: if ctype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_VECTOR) nrows = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_VECTOR)
return self.darray(value, nrows) return self.darray(value, nrows)
elif ctype == LMP_TYPE_ARRAY: if ctype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_ROWS) nrows = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_ROWS)
ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS) ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS)
return self.darray(value, nrows, ncols) return self.darray(value, nrows, ncols)
@ -142,13 +143,12 @@ class numpy_wrapper:
ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS) ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS)
if ncols == 0: if ncols == 0:
return self.darray(value, nrows) return self.darray(value, nrows)
else: return self.darray(value, nrows, ncols)
return self.darray(value, nrows, ncols)
elif cstyle == LMP_STYLE_ATOM: elif cstyle == LMP_STYLE_ATOM:
if ctype == LMP_TYPE_VECTOR: if ctype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal") nlocal = self.lmp.extract_global("nlocal")
return self.darray(value, nlocal) return self.darray(value, nlocal)
elif ctype == LMP_TYPE_ARRAY: if ctype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal") nlocal = self.lmp.extract_global("nlocal")
ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS) ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS)
return self.darray(value, nlocal, ncols) return self.darray(value, nlocal, ncols)
@ -189,7 +189,7 @@ class numpy_wrapper:
if ftype == LMP_TYPE_VECTOR: if ftype == LMP_TYPE_VECTOR:
nlocal = self.lmp.extract_global("nlocal") nlocal = self.lmp.extract_global("nlocal")
return self.darray(value, nlocal) return self.darray(value, nlocal)
elif ftype == LMP_TYPE_ARRAY: if ftype == LMP_TYPE_ARRAY:
nlocal = self.lmp.extract_global("nlocal") nlocal = self.lmp.extract_global("nlocal")
ncols = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_COLS, 0, 0) ncols = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nlocal, ncols) return self.darray(value, nlocal, ncols)
@ -197,7 +197,7 @@ class numpy_wrapper:
if ftype == LMP_TYPE_VECTOR: if ftype == LMP_TYPE_VECTOR:
nrows = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_ROWS, 0, 0) nrows = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_ROWS, 0, 0)
return self.darray(value, nrows) return self.darray(value, nrows)
elif ftype == LMP_TYPE_ARRAY: if ftype == LMP_TYPE_ARRAY:
nrows = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_ROWS, 0, 0) nrows = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_ROWS, 0, 0)
ncols = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_COLS, 0, 0) ncols = self.lmp.extract_fix(fid, fstyle, LMP_SIZE_COLS, 0, 0)
return self.darray(value, nrows, ncols) return self.darray(value, nrows, ncols)
@ -222,7 +222,6 @@ class numpy_wrapper:
:return: the requested data or None :return: the requested data or None
:rtype: c_double, numpy.array, or NoneType :rtype: c_double, numpy.array, or NoneType
""" """
import numpy as np
value = self.lmp.extract_variable(name, group, vartype) value = self.lmp.extract_variable(name, group, vartype)
if vartype == LMP_VAR_ATOM: if vartype == LMP_VAR_ATOM:
return np.ctypeslib.as_array(value) return np.ctypeslib.as_array(value)
@ -242,7 +241,6 @@ class numpy_wrapper:
:return: the requested data as a 2d-integer numpy array :return: the requested data as a 2d-integer numpy array
:rtype: numpy.array(nbonds,3) :rtype: numpy.array(nbonds,3)
""" """
import numpy as np
nbonds, value = self.lmp.gather_bonds() nbonds, value = self.lmp.gather_bonds()
return np.ctypeslib.as_array(value).reshape(nbonds,3) return np.ctypeslib.as_array(value).reshape(nbonds,3)
@ -260,7 +258,6 @@ class numpy_wrapper:
:return: the requested data as a 2d-integer numpy array :return: the requested data as a 2d-integer numpy array
:rtype: numpy.array(nangles,4) :rtype: numpy.array(nangles,4)
""" """
import numpy as np
nangles, value = self.lmp.gather_angles() nangles, value = self.lmp.gather_angles()
return np.ctypeslib.as_array(value).reshape(nangles,4) return np.ctypeslib.as_array(value).reshape(nangles,4)
@ -278,7 +275,6 @@ class numpy_wrapper:
:return: the requested data as a 2d-integer numpy array :return: the requested data as a 2d-integer numpy array
:rtype: numpy.array(ndihedrals,5) :rtype: numpy.array(ndihedrals,5)
""" """
import numpy as np
ndihedrals, value = self.lmp.gather_dihedrals() ndihedrals, value = self.lmp.gather_dihedrals()
return np.ctypeslib.as_array(value).reshape(ndihedrals,5) return np.ctypeslib.as_array(value).reshape(ndihedrals,5)
@ -296,7 +292,6 @@ class numpy_wrapper:
:return: the requested data as a 2d-integer numpy array :return: the requested data as a 2d-integer numpy array
:rtype: numpy.array(nimpropers,5) :rtype: numpy.array(nimpropers,5)
""" """
import numpy as np
nimpropers, value = self.lmp.gather_impropers() nimpropers, value = self.lmp.gather_impropers()
return np.ctypeslib.as_array(value).reshape(nimpropers,5) return np.ctypeslib.as_array(value).reshape(nimpropers,5)
@ -317,7 +312,6 @@ class numpy_wrapper:
:return: requested data :return: requested data
:rtype: numpy.array :rtype: numpy.array
""" """
import numpy as np
nlocal = self.lmp.extract_setting('nlocal') nlocal = self.lmp.extract_setting('nlocal')
value = self.lmp.fix_external_get_force(fix_id) value = self.lmp.fix_external_get_force(fix_id)
return self.darray(value,nlocal,3) return self.darray(value,nlocal,3)
@ -339,10 +333,9 @@ class numpy_wrapper:
:param eatom: per-atom potential energy :param eatom: per-atom potential energy
:type: numpy.array :type: numpy.array
""" """
import numpy as np
nlocal = self.lmp.extract_setting('nlocal') nlocal = self.lmp.extract_setting('nlocal')
if len(eatom) < nlocal: if len(eatom) < nlocal:
raise Exception('per-atom energy dimension must be at least nlocal') raise RuntimeError('per-atom energy dimension must be at least nlocal')
c_double_p = POINTER(c_double) c_double_p = POINTER(c_double)
value = eatom.astype(np.double) value = eatom.astype(np.double)
@ -366,12 +359,11 @@ class numpy_wrapper:
:param eatom: per-atom potential energy :param eatom: per-atom potential energy
:type: numpy.array :type: numpy.array
""" """
import numpy as np
nlocal = self.lmp.extract_setting('nlocal') nlocal = self.lmp.extract_setting('nlocal')
if len(vatom) < nlocal: if len(vatom) < nlocal:
raise Exception('per-atom virial first dimension must be at least nlocal') raise RuntimeError('per-atom virial first dimension must be at least nlocal')
if len(vatom[0]) != 6: if len(vatom[0]) != 6:
raise Exception('per-atom virial second dimension must be 6') raise RuntimeError('per-atom virial second dimension must be 6')
c_double_pp = np.ctypeslib.ndpointer(dtype=np.uintp, ndim=1, flags='C') c_double_pp = np.ctypeslib.ndpointer(dtype=np.uintp, ndim=1, flags='C')
@ -395,7 +387,7 @@ class numpy_wrapper:
:rtype: NumPyNeighList :rtype: NumPyNeighList
""" """
if idx < 0: if idx < 0:
return None return None
return NumPyNeighList(self.lmp, idx) return NumPyNeighList(self.lmp, idx)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -422,8 +414,8 @@ class numpy_wrapper:
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def iarray(self, c_int_type, raw_ptr, nelem, dim=1): def iarray(self, c_int_type, raw_ptr, nelem, dim=1):
# pylint: disable=C0116
if raw_ptr and nelem >= 0 and dim >= 0: if raw_ptr and nelem >= 0 and dim >= 0:
import numpy as np
np_int_type = self._ctype_to_numpy_int(c_int_type) np_int_type = self._ctype_to_numpy_int(c_int_type)
ptr = None ptr = None
@ -440,15 +432,15 @@ class numpy_wrapper:
if dim > 1: if dim > 1:
a.shape = (nelem, dim) a.shape = (nelem, dim)
else: else:
a.shape = (nelem) a.shape = nelem
return a return a
return None return None
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def darray(self, raw_ptr, nelem, dim=1): def darray(self, raw_ptr, nelem, dim=1):
# pylint: disable=C0116
if raw_ptr and nelem >= 0 and dim >= 0: if raw_ptr and nelem >= 0 and dim >= 0:
import numpy as np
ptr = None ptr = None
if dim == 1: if dim == 1:
@ -464,51 +456,48 @@ class numpy_wrapper:
if dim > 1: if dim > 1:
a.shape = (nelem, dim) a.shape = (nelem, dim)
else: else:
a.shape = (nelem) a.shape = nelem
return a return a
return None return None
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
class NumPyNeighList(NeighList): class NumPyNeighList(NeighList):
"""This is a wrapper class that exposes the contents of a neighbor list. """This is a wrapper class that exposes the contents of a neighbor list.
It can be used like a regular Python list. Each element is a tuple of: It can be used like a regular Python list. Each element is a tuple of:
* the atom local index * the atom local index
* a NumPy array containing the local atom indices of its neighbors * a NumPy array containing the local atom indices of its neighbors
Internally it uses the lower-level LAMMPS C-library interface. Internally it uses the lower-level LAMMPS C-library interface.
:param lmp: reference to instance of :py:class:`lammps` :param lmp: reference to instance of :py:class:`lammps`
:type lmp: lammps :type lmp: lammps
:param idx: neighbor list index :param idx: neighbor list index
:type idx: int :type idx: int
"""
def get(self, element):
""" """
def __init__(self, lmp, idx): Access a specific neighbor list entry. "element" must be a number from 0 to the size-1 of the list
super(NumPyNeighList, self).__init__(lmp, idx)
def get(self, element): :return: tuple with atom local index, numpy array of neighbor local atom indices
""" :rtype: (int, numpy.array)
Access a specific neighbor list entry. "element" must be a number from 0 to the size-1 of the list """
iatom, neighbors = self.lmp.numpy.get_neighlist_element_neighbors(self.idx, element)
return iatom, neighbors
:return: tuple with atom local index, numpy array of neighbor local atom indices def find(self, iatom):
:rtype: (int, numpy.array) """
""" Find the neighbor list for a specific (local) atom iatom.
iatom, neighbors = self.lmp.numpy.get_neighlist_element_neighbors(self.idx, element) If there is no list for iatom, None is returned.
return iatom, neighbors
def find(self, iatom): :return: numpy array of neighbor local atom indices
""" :rtype: numpy.array or None
Find the neighbor list for a specific (local) atom iatom. """
If there is no list for iatom, None is returned. inum = self.size
for ii in range(inum):
:return: numpy array of neighbor local atom indices idx, neighbors = self.get(ii)
:rtype: numpy.array or None if idx == iatom:
""" return neighbors
inum = self.size return None
for ii in range(inum):
idx, neighbors = self.get(ii)
if idx == iatom:
return neighbors
return None