diff --git a/python/lammps/numpy_wrapper.py b/python/lammps/numpy_wrapper.py index 3619728081..ce0cb35e47 100644 --- a/python/lammps/numpy_wrapper.py +++ b/python/lammps/numpy_wrapper.py @@ -165,7 +165,7 @@ class numpy_wrapper: """ value = self.lmp.extract_compute(cid, cstyle, ctype) - if cstyle in (LMP_STYLE_GLOBAL, LMP_STYLE_LOCAL): + if cstyle == LMP_STYLE_GLOBAL: if ctype == LMP_TYPE_VECTOR: nrows = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_VECTOR) return self.darray(value, nrows) @@ -173,6 +173,13 @@ class numpy_wrapper: nrows = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_ROWS) ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS) return self.darray(value, nrows, ncols) + elif cstyle == LMP_STYLE_LOCAL: + nrows = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_ROWS) + ncols = self.lmp.extract_compute(cid, cstyle, LMP_SIZE_COLS) + if ncols == 0: + return self.darray(value, nrows) + else: + return self.darray(value, nrows, ncols) elif cstyle == LMP_STYLE_ATOM: if ctype == LMP_TYPE_VECTOR: nlocal = self.lmp.extract_global("nlocal")