diff --git a/python/lammps.py b/python/lammps.py index a729e706e7..565b89cfd9 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -435,6 +435,54 @@ class lammps(object): return self.darray(raw_ptr, nelem, dim) + def extract_compute(self, cid, style, datatype): + value = self.lmp.extract_compute(cid, style, datatype) + + if style in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL): + if datatype == LMP_TYPE_VECTOR: + nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_VECTOR) + print("NROWS", nrows) + return self.darray(value, nrows) + elif datatype == LMP_TYPE_ARRAY: + nrows = self.lmp.extract_compute(cid, style, LMP_SIZE_ROWS) + ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS) + return self.darray(value, nrows, ncols) + elif style == LMP_STYLE_ATOM: + if datatype == LMP_TYPE_VECTOR: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + return self.darray(value, nlocal) + elif datatype == LMP_TYPE_ARRAY: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + ncols = self.lmp.extract_compute(cid, style, LMP_SIZE_COLS) + return self.darray(value, nlocal, ncols) + return value + + def extract_fix(self, fid, style, datatype, nrow=0, ncol=0): + value = self.lmp.extract_fix(fid, style, datatype, nrow, ncol) + if style == LMP_STYLE_ATOM: + if datatype == LMP_TYPE_VECTOR: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + return self.darray(value, nlocal) + elif datatype == LMP_TYPE_ARRAY: + nlocal = self.lmp.extract_global("nlocal", LAMMPS_INT) + ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0) + return self.darray(value, nlocal, ncols) + elif style == LMP_STYLE_LOCAL: + if datatype == LMP_TYPE_VECTOR: + nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0) + return self.darray(value, nrows) + elif datatype == LMP_TYPE_ARRAY: + nrows = self.lmp.extract_fix(fid, style, LMP_SIZE_ROWS, 0, 0) + ncols = self.lmp.extract_fix(fid, style, LMP_SIZE_COLS, 0, 0) + return self.darray(value, nrows, ncols) + return value + + def extract_variable(self, name, group=None, datatype=LMP_VAR_EQUAL): + value = self.lmp.extract_variable(name, group, datatype) + if datatype == LMP_VAR_ATOM: + return np.ctypeslib.as_array(value) + return value + def iarray(self, c_int_type, raw_ptr, nelem, dim=1): np_int_type = self._ctype_to_numpy_int(c_int_type)