diff --git a/python/lammps.py b/python/lammps.py index 36cf2d2fdd..e7062ba514 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -221,8 +221,8 @@ class lammps(object): # add way to insert Python callback for fix external self.callback = {} - self.FIX_EXTERNAL_CALLBACK_FUNC = CFUNCTYPE(None, c_void_p, self.c_bigint, c_int, POINTER(self.c_tagint), POINTER(POINTER(c_double)), POINTER(POINTER(c_double))) - self.lib.lammps_set_fix_external_callback.argtypes = [c_void_p, c_char_p, self.FIX_EXTERNAL_CALLBACK_FUNC, c_void_p] + self.FIX_EXTERNAL_CALLBACK_FUNC = CFUNCTYPE(None, py_object, self.c_bigint, c_int, POINTER(self.c_tagint), POINTER(POINTER(c_double)), POINTER(POINTER(c_double))) + self.lib.lammps_set_fix_external_callback.argtypes = [c_void_p, c_char_p, self.FIX_EXTERNAL_CALLBACK_FUNC, py_object] self.lib.lammps_set_fix_external_callback.restype = None # shut-down LAMMPS instance @@ -357,26 +357,38 @@ class lammps(object): else: c_int_type = c_int + if dim == 1: + raw_ptr = self.lmp.extract_atom(name, 0) + else: + raw_ptr = self.lmp.extract_atom(name, 1) + + return self.iarray(c_int_type, raw_ptr, nelem, dim) + + def extract_atom_darray(self, name, nelem, dim=1): + if dim == 1: + raw_ptr = self.lmp.extract_atom(name, 2) + else: + raw_ptr = self.lmp.extract_atom(name, 3) + + return self.darray(raw_ptr, nelem, dim) + + def iarray(self, c_int_type, raw_ptr, nelem, dim=1): np_int_type = self._ctype_to_numpy_int(c_int_type) if dim == 1: - tmp = self.lmp.extract_atom(name, 0) - ptr = cast(tmp, POINTER(c_int_type * nelem)) + ptr = cast(raw_ptr, POINTER(c_int_type * nelem)) else: - tmp = self.lmp.extract_atom(name, 1) - ptr = cast(tmp[0], POINTER(c_int_type * nelem * dim)) + ptr = cast(raw_ptr[0], POINTER(c_int_type * nelem * dim)) a = np.frombuffer(ptr.contents, dtype=np_int_type) a.shape = (nelem, dim) return a - def extract_atom_darray(self, name, nelem, dim=1): + def darray(self, raw_ptr, nelem, dim=1): if dim == 1: - tmp = self.lmp.extract_atom(name, 2) - ptr = cast(tmp, POINTER(c_double * nelem)) + ptr = cast(raw_ptr, POINTER(c_double * nelem)) else: - tmp = self.lmp.extract_atom(name, 3) - ptr = cast(tmp[0], POINTER(c_double * nelem * dim)) + ptr = cast(raw_ptr[0], POINTER(c_double * nelem * dim)) a = np.frombuffer(ptr.contents) a.shape = (nelem, dim) @@ -617,28 +629,14 @@ class lammps(object): return np.int64 return np.intc - def callback_wrapper(caller_ptr, ntimestep, nlocal, tag_ptr, x_ptr, fext_ptr): - if cast(caller_ptr,POINTER(py_object)).contents: - pyCallerObj = cast(caller_ptr,POINTER(py_object)).contents.value - else: - pyCallerObj = None - - tptr = cast(tag_ptr, POINTER(self.c_tagint * nlocal)) - tag = np.frombuffer(tptr.contents, dtype=_ctype_to_numpy_int(self.c_tagint)) - tag.shape = (nlocal) - - xptr = cast(x_ptr[0], POINTER(c_double * nlocal * 3)) - x = np.frombuffer(xptr.contents) - x.shape = (nlocal, 3) - - fptr = cast(fext_ptr[0], POINTER(c_double * nlocal * 3)) - f = np.frombuffer(fptr.contents) - f.shape = (nlocal, 3) - - callback(pyCallerObj, ntimestep, nlocal, tag, x, f) + def callback_wrapper(caller, ntimestep, nlocal, tag_ptr, x_ptr, fext_ptr): + tag = self.numpy.iarray(self.c_tagint, tag_ptr, nlocal, 1) + x = self.numpy.darray(x_ptr, nlocal, 3) + f = self.numpy.darray(fext_ptr, nlocal, 3) + callback(caller, ntimestep, nlocal, tag, x, f) cFunc = self.FIX_EXTERNAL_CALLBACK_FUNC(callback_wrapper) - cCaller = cast(pointer(py_object(caller)), c_void_p) + cCaller = caller self.callback[fix_name] = { 'function': cFunc, 'caller': caller }