diff --git a/python/lammps/core.py b/python/lammps/core.py index eaf78dfa0c..be026d5e10 100644 --- a/python/lammps/core.py +++ b/python/lammps/core.py @@ -39,6 +39,20 @@ class MPIAbortException(Exception): # ------------------------------------------------------------------------- +class ExceptionCheck: + """Utility class to rethrow LAMMPS C++ exceptions as Python exceptions""" + def __init__(self, lmp): + self.lmp = lmp + + def __enter__(self): + pass + + def __exit__(self, type, value, traceback): + if self.lmp.has_exceptions and self.lmp.lib.lammps_has_error(self.lmp.lmp): + raise self.lmp._lammps_exception + +# ------------------------------------------------------------------------- + class lammps(object): """Create an instance of the LAMMPS Python class. @@ -519,10 +533,9 @@ class lammps(object): """ if path: path = path.encode() else: return - self.lib.lammps_file(self.lmp, path) - if self.has_exceptions and self.lib.lammps_has_error(self.lmp): - raise self._lammps_exception + with ExceptionCheck(self): + self.lib.lammps_file(self.lmp, path) # ------------------------------------------------------------------------- @@ -537,10 +550,9 @@ class lammps(object): """ if cmd: cmd = cmd.encode() else: return - self.lib.lammps_command(self.lmp,cmd) - if self.has_exceptions and self.lib.lammps_has_error(self.lmp): - raise self._lammps_exception + with ExceptionCheck(self): + self.lib.lammps_command(self.lmp,cmd) # ------------------------------------------------------------------------- @@ -558,10 +570,9 @@ class lammps(object): narg = len(cmdlist) args = (c_char_p * narg)(*cmds) self.lib.lammps_commands_list.argtypes = [c_void_p, c_int, c_char_p * narg] - self.lib.lammps_commands_list(self.lmp,narg,args) - if self.has_exceptions and self.lib.lammps_has_error(self.lmp): - raise self._lammps_exception + with ExceptionCheck(self): + self.lib.lammps_commands_list(self.lmp,narg,args) # ------------------------------------------------------------------------- @@ -576,10 +587,9 @@ class lammps(object): :type multicmd: string """ if type(multicmd) is str: multicmd = multicmd.encode() - self.lib.lammps_commands_string(self.lmp,c_char_p(multicmd)) - if self.has_exceptions and self.lib.lammps_has_error(self.lmp): - raise self._lammps_exception + with ExceptionCheck(self): + self.lib.lammps_commands_string(self.lmp,c_char_p(multicmd)) # ------------------------------------------------------------------------- @@ -614,9 +624,10 @@ class lammps(object): periodicity = (3*c_int)() box_change = c_int() - self.lib.lammps_extract_box(self.lmp,boxlo,boxhi, - byref(xy),byref(yz),byref(xz), - periodicity,byref(box_change)) + with ExceptionCheck(self): + self.lib.lammps_extract_box(self.lmp,boxlo,boxhi, + byref(xy),byref(yz),byref(xz), + periodicity,byref(box_change)) boxlo = boxlo[:3] boxhi = boxhi[:3] @@ -649,7 +660,8 @@ class lammps(object): """ cboxlo = (3*c_double)(*boxlo) cboxhi = (3*c_double)(*boxhi) - self.lib.lammps_reset_box(self.lmp,cboxlo,cboxhi,xy,yz,xz) + with ExceptionCheck(self): + self.lib.lammps_reset_box(self.lmp,cboxlo,cboxhi,xy,yz,xz) # ------------------------------------------------------------------------- @@ -666,7 +678,9 @@ class lammps(object): """ if name: name = name.encode() else: return None - return self.lib.lammps_get_thermo(self.lmp,name) + + with ExceptionCheck(self): + return self.lib.lammps_get_thermo(self.lmp,name) # ------------------------------------------------------------------------- @@ -838,6 +852,7 @@ class lammps(object): elif dtype == LAMMPS_INT64_2D: self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int64)) else: return None + ptr = self.lib.lammps_extract_atom(self.lmp, name) if ptr: return ptr else: return None @@ -872,39 +887,44 @@ class lammps(object): if type == LMP_TYPE_SCALAR: if style == LMP_STYLE_GLOBAL: self.lib.lammps_extract_compute.restype = POINTER(c_double) - ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) return ptr[0] elif style == LMP_STYLE_ATOM: return None elif style == LMP_STYLE_LOCAL: self.lib.lammps_extract_compute.restype = POINTER(c_int) - ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) return ptr[0] - if type == LMP_TYPE_VECTOR: + elif type == LMP_TYPE_VECTOR: self.lib.lammps_extract_compute.restype = POINTER(c_double) - ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) return ptr - if type == LMP_TYPE_ARRAY: + elif type == LMP_TYPE_ARRAY: self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double)) - ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) return ptr - if type == LMP_SIZE_COLS: + elif type == LMP_SIZE_COLS: if style == LMP_STYLE_GLOBAL \ or style == LMP_STYLE_ATOM \ or style == LMP_STYLE_LOCAL: self.lib.lammps_extract_compute.restype = POINTER(c_int) - ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) return ptr[0] - if type == LMP_SIZE_VECTOR \ - or type == LMP_SIZE_ROWS: + elif type == LMP_SIZE_VECTOR or type == LMP_SIZE_ROWS: if style == LMP_STYLE_GLOBAL \ or style == LMP_STYLE_LOCAL: self.lib.lammps_extract_compute.restype = POINTER(c_int) - ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) return ptr[0] return None @@ -947,13 +967,15 @@ class lammps(object): if style == LMP_STYLE_GLOBAL: if type in (LMP_TYPE_SCALAR, LMP_TYPE_VECTOR, LMP_TYPE_ARRAY): self.lib.lammps_extract_fix.restype = POINTER(c_double) - ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) result = ptr[0] self.lib.lammps_free(ptr) return result elif type in (LMP_SIZE_VECTOR, LMP_SIZE_ROWS, LMP_SIZE_COLS): self.lib.lammps_extract_fix.restype = POINTER(c_int) - ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) return ptr[0] else: return None @@ -967,7 +989,8 @@ class lammps(object): self.lib.lammps_extract_fix.restype = POINTER(c_int) else: return None - ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) if type == LMP_SIZE_COLS: return ptr[0] else: @@ -982,7 +1005,8 @@ class lammps(object): self.lib.lammps_extract_fix.restype = POINTER(c_int) else: return None - ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) if type in (LMP_TYPE_VECTOR, LMP_TYPE_ARRAY): return ptr else: @@ -1026,7 +1050,8 @@ class lammps(object): if group: group = group.encode() if vartype == LMP_VAR_EQUAL: self.lib.lammps_extract_variable.restype = POINTER(c_double) - ptr = self.lib.lammps_extract_variable(self.lmp,name,group) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_variable(self.lmp,name,group) if ptr: result = ptr[0] else: return None self.lib.lammps_free(ptr) @@ -1035,7 +1060,8 @@ class lammps(object): nlocal = self.extract_global("nlocal") result = (c_double*nlocal)() self.lib.lammps_extract_variable.restype = POINTER(c_double) - ptr = self.lib.lammps_extract_variable(self.lmp,name,group) + with ExceptionCheck(self): + ptr = self.lib.lammps_extract_variable(self.lmp,name,group) if ptr: for i in range(nlocal): result[i] = ptr[i] self.lib.lammps_free(ptr) @@ -1062,7 +1088,8 @@ class lammps(object): else: return -1 if value: value = str(value).encode() else: return -1 - return self.lib.lammps_set_variable(self.lmp,name,value) + with ExceptionCheck(self): + return self.lib.lammps_set_variable(self.lmp,name,value) # ------------------------------------------------------------------------- @@ -1078,13 +1105,15 @@ class lammps(object): def gather_atoms(self,name,type,count): if name: name = name.encode() natoms = self.get_natoms() - if type == 0: - data = ((count*natoms)*c_int)() - self.lib.lammps_gather_atoms(self.lmp,name,type,count,data) - elif type == 1: - data = ((count*natoms)*c_double)() - self.lib.lammps_gather_atoms(self.lmp,name,type,count,data) - else: return None + with ExceptionCheck(self): + if type == 0: + data = ((count*natoms)*c_int)() + self.lib.lammps_gather_atoms(self.lmp,name,type,count,data) + elif type == 1: + data = ((count*natoms)*c_double)() + self.lib.lammps_gather_atoms(self.lmp,name,type,count,data) + else: + return None return data # ------------------------------------------------------------------------- @@ -1092,24 +1121,28 @@ class lammps(object): def gather_atoms_concat(self,name,type,count): if name: name = name.encode() natoms = self.get_natoms() - if type == 0: - data = ((count*natoms)*c_int)() - self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data) - elif type == 1: - data = ((count*natoms)*c_double)() - self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data) - else: return None + with ExceptionCheck(self): + if type == 0: + data = ((count*natoms)*c_int)() + self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data) + elif type == 1: + data = ((count*natoms)*c_double)() + self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data) + else: + return None return data def gather_atoms_subset(self,name,type,count,ndata,ids): if name: name = name.encode() - if type == 0: - data = ((count*ndata)*c_int)() - self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data) - elif type == 1: - data = ((count*ndata)*c_double)() - self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data) - else: return None + with ExceptionCheck(self): + if type == 0: + data = ((count*ndata)*c_int)() + self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data) + elif type == 1: + data = ((count*ndata)*c_double)() + self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data) + else: + return None return data # ------------------------------------------------------------------------- @@ -1125,13 +1158,15 @@ class lammps(object): def scatter_atoms(self,name,type,count,data): if name: name = name.encode() - self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data) + with ExceptionCheck(self): + self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data) # ------------------------------------------------------------------------- def scatter_atoms_subset(self,name,type,count,ndata,ids,data): if name: name = name.encode() - self.lib.lammps_scatter_atoms_subset(self.lmp,name,type,count,ndata,ids,data) + with ExceptionCheck(self): + self.lib.lammps_scatter_atoms_subset(self.lmp,name,type,count,ndata,ids,data) # return vector of atom/compute/fix properties gathered across procs # 3 variants to match src/library.cpp @@ -1144,36 +1179,42 @@ class lammps(object): def gather(self,name,type,count): if name: name = name.encode() natoms = self.get_natoms() - if type == 0: - data = ((count*natoms)*c_int)() - self.lib.lammps_gather(self.lmp,name,type,count,data) - elif type == 1: - data = ((count*natoms)*c_double)() - self.lib.lammps_gather(self.lmp,name,type,count,data) - else: return None + with ExceptionCheck(self): + if type == 0: + data = ((count*natoms)*c_int)() + self.lib.lammps_gather(self.lmp,name,type,count,data) + elif type == 1: + data = ((count*natoms)*c_double)() + self.lib.lammps_gather(self.lmp,name,type,count,data) + else: + return None return data def gather_concat(self,name,type,count): if name: name = name.encode() natoms = self.get_natoms() - if type == 0: - data = ((count*natoms)*c_int)() - self.lib.lammps_gather_concat(self.lmp,name,type,count,data) - elif type == 1: - data = ((count*natoms)*c_double)() - self.lib.lammps_gather_concat(self.lmp,name,type,count,data) - else: return None + with ExceptionCheck(self): + if type == 0: + data = ((count*natoms)*c_int)() + self.lib.lammps_gather_concat(self.lmp,name,type,count,data) + elif type == 1: + data = ((count*natoms)*c_double)() + self.lib.lammps_gather_concat(self.lmp,name,type,count,data) + else: + return None return data def gather_subset(self,name,type,count,ndata,ids): if name: name = name.encode() - if type == 0: - data = ((count*ndata)*c_int)() - self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data) - elif type == 1: - data = ((count*ndata)*c_double)() - self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data) - else: return None + with ExceptionCheck(self): + if type == 0: + data = ((count*ndata)*c_int)() + self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data) + elif type == 1: + data = ((count*ndata)*c_double)() + self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data) + else: + return None return data # scatter vector of atom/compute/fix properties across procs @@ -1187,11 +1228,13 @@ class lammps(object): def scatter(self,name,type,count,data): if name: name = name.encode() - self.lib.lammps_scatter(self.lmp,name,type,count,data) + with ExceptionCheck(self): + self.lib.lammps_scatter(self.lmp,name,type,count,data) def scatter_subset(self,name,type,count,ndata,ids,data): if name: name = name.encode() - self.lib.lammps_scatter_subset(self.lmp,name,type,count,ndata,ids,data) + with ExceptionCheck(self): + self.lib.lammps_scatter_subset(self.lmp,name,type,count,ndata,ids,data) # ------------------------------------------------------------------------- @@ -1328,7 +1371,8 @@ class lammps(object): POINTER(c_int*n), POINTER(c_double*three_n), POINTER(c_double*three_n), POINTER(self.c_imageint*n), c_int] - return self.lib.lammps_create_atoms(self.lmp, n, id_lmp, type_lmp, x_lmp, v_lmp, img_lmp, se_lmp) + with ExceptionCheck(self): + return self.lib.lammps_create_atoms(self.lmp, n, id_lmp, type_lmp, x_lmp, v_lmp, img_lmp, se_lmp) # ------------------------------------------------------------------------- @@ -1538,10 +1582,12 @@ class lammps(object): if category not in self._available_styles: self._available_styles[category] = [] - nstyles = self.lib.lammps_style_count(self.lmp, category.encode()) + with ExceptionCheck(self): + nstyles = self.lib.lammps_style_count(self.lmp, category.encode()) sb = create_string_buffer(100) for idx in range(nstyles): - self.lib.lammps_style_name(self.lmp, category.encode(), idx, sb, 100) + with ExceptionCheck(self): + self.lib.lammps_style_name(self.lmp, category.encode(), idx, sb, 100) self._available_styles[category].append(sb.value.decode()) return self._available_styles[category] @@ -1607,7 +1653,8 @@ class lammps(object): cCaller = caller self.callback[fix_name] = { 'function': cFunc, 'caller': caller } - self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller) + with ExceptionCheck(self): + self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller) # -------------------------------------------------------------------------