From 0b8136a38b7834ebe59c294831215da801c25c8e Mon Sep 17 00:00:00 2001 From: Richard Berger Date: Thu, 27 Aug 2020 16:15:59 -0400 Subject: [PATCH] Add extract_compute, extract_fix, and extract_variable to lammps.numpy --- python/lammps.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/python/lammps.py b/python/lammps.py index a729e706e7..565b89cfd9 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -435,6 +435,54 @@ class lammps(object): return self.darray(raw_ptr, nelem, dim) + def extract_compute(self, cid, style, datatype): + value = self.lmp.extract_compute(cid, style, datatype) + + if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL): + if datatype == LMP_TYPE_VECTOR: + nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR) + print("NROWS", nrows) + return self.darray(value, nrows) + elif datatype == LMP_TYPE_ARRAY: + nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS) + ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS) + return self.darray(value, nrows, ncols) + elif style == LMP_STYLE_ATOM: + if datatype == LMP_TYPE_VECTOR: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + return self.darray(value, nlocal) + elif datatype == LMP_TYPE_ARRAY: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS) + return self.darray(value, nlocal, ncols) + return value + + def extract_fix(self, fid, style, datatype, nrow=0, ncol=0): + value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol) + if style == LMP_STYLE_ATOM: + if datatype == LMP_TYPE_VECTOR: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + return self.darray(value, nlocal) + elif datatype == LMP_TYPE_ARRAY: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0) + return self.darray(value, nlocal, ncols) + elif style == LMP_STYLE_LOCAL: + if datatype == LMP_TYPE_VECTOR: + nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0) + return self.darray(value, nrows) + elif datatype == LMP_TYPE_ARRAY: + nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0) + ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0) + return self.darray(value, nrows, ncols) + return value + + def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL): + value = self.lmp.extract_variable(name, group, datatype) + if datatype == LMP_VAR_ATOM: + return np.ctypeslib.as_array(value) + return value + def iarray(self, c_int_type, raw_ptr, nelem, dim=1): np_int_type = self._ctype_to_numpy_int(c_int_type)