Detect correct integer type in lammps python interface
This commit is contained in:
@ -37,7 +37,7 @@ def get_ctypes_int(size):
|
||||
return c_int32
|
||||
elif size == 8:
|
||||
return c_int64
|
||||
return c_int
|
||||
return c_int
|
||||
|
||||
class MPIAbortException(Exception):
|
||||
def __init__(self, message):
|
||||
@ -266,25 +266,41 @@ class lammps(object):
|
||||
def __init__(self, lmp):
|
||||
self.lmp = lmp
|
||||
|
||||
def extract_atom_iarray(self, name, nelem, dim=1):
|
||||
if dim == 1:
|
||||
tmp = self.lmp.extract_atom(name, 0)
|
||||
ptr = cast(tmp, POINTER(c_int * nelem))
|
||||
else:
|
||||
tmp = self.lmp.extract_atom(name, 1)
|
||||
ptr = cast(tmp[0], POINTER(c_int * nelem * dim))
|
||||
def _ctype_to_numpy_int(self, ctype_int):
|
||||
if ctype_int == c_int32:
|
||||
return np.int32
|
||||
elif ctype_int == c_int64:
|
||||
return np.int64
|
||||
return np.intc
|
||||
|
||||
a = np.frombuffer(ptr.contents, dtype=np.intc)
|
||||
def extract_atom_iarray(self, name, nelem, dim=1):
|
||||
if name in ['id', 'molecule']:
|
||||
c_int_type = self.lmp.c_tagint
|
||||
elif name in ['image']:
|
||||
c_int_type = self.lmp.c_imageint
|
||||
else:
|
||||
c_int_type = c_int
|
||||
|
||||
np_int_type = self._ctype_to_numpy_int(c_int_type)
|
||||
|
||||
if dim == 1:
|
||||
tmp = self.lmp.extract_atom(name, 0)
|
||||
ptr = cast(tmp, POINTER(c_int_type * nelem))
|
||||
else:
|
||||
tmp = self.lmp.extract_atom(name, 1)
|
||||
ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents, dtype=np_int_type)
|
||||
a.shape = (nelem, dim)
|
||||
return a
|
||||
|
||||
def extract_atom_darray(self, name, nelem, dim=1):
|
||||
if dim == 1:
|
||||
tmp = self.lmp.extract_atom(name, 2)
|
||||
ptr = cast(tmp, POINTER(c_double * nelem))
|
||||
tmp = self.lmp.extract_atom(name, 2)
|
||||
ptr = cast(tmp, POINTER(c_double * nelem))
|
||||
else:
|
||||
tmp = self.lmp.extract_atom(name, 3)
|
||||
ptr = cast(tmp[0], POINTER(c_double * nelem * dim))
|
||||
tmp = self.lmp.extract_atom(name, 3)
|
||||
ptr = cast(tmp[0], POINTER(c_double * nelem * dim))
|
||||
|
||||
a = np.frombuffer(ptr.contents)
|
||||
a.shape = (nelem, dim)
|
||||
|
||||
Reference in New Issue
Block a user