Add missing checks for C++ exceptions

This commit is contained in:
Richard Berger
2021-03-18 14:26:18 -04:00
parent b6498c8b9b
commit 64ba2f4ee2

View File

@ -39,6 +39,20 @@ class MPIAbortException(Exception):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
class ExceptionCheck:
"""Utility class to rethrow LAMMPS C++ exceptions as Python exceptions"""
def __init__(self, lmp):
self.lmp = lmp
def __enter__(self):
pass
def __exit__(self, type, value, traceback):
if self.lmp.has_exceptions and self.lmp.lib.lammps_has_error(self.lmp.lmp):
raise self.lmp._lammps_exception
# -------------------------------------------------------------------------
class lammps(object): class lammps(object):
"""Create an instance of the LAMMPS Python class. """Create an instance of the LAMMPS Python class.
@ -519,10 +533,9 @@ class lammps(object):
""" """
if path: path = path.encode() if path: path = path.encode()
else: return else: return
self.lib.lammps_file(self.lmp, path)
if self.has_exceptions and self.lib.lammps_has_error(self.lmp): with ExceptionCheck(self):
raise self._lammps_exception self.lib.lammps_file(self.lmp, path)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -537,10 +550,9 @@ class lammps(object):
""" """
if cmd: cmd = cmd.encode() if cmd: cmd = cmd.encode()
else: return else: return
self.lib.lammps_command(self.lmp,cmd)
if self.has_exceptions and self.lib.lammps_has_error(self.lmp): with ExceptionCheck(self):
raise self._lammps_exception self.lib.lammps_command(self.lmp,cmd)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -558,10 +570,9 @@ class lammps(object):
narg = len(cmdlist) narg = len(cmdlist)
args = (c_char_p * narg)(*cmds) args = (c_char_p * narg)(*cmds)
self.lib.lammps_commands_list.argtypes = [c_void_p, c_int, c_char_p * narg] self.lib.lammps_commands_list.argtypes = [c_void_p, c_int, c_char_p * narg]
self.lib.lammps_commands_list(self.lmp,narg,args)
if self.has_exceptions and self.lib.lammps_has_error(self.lmp): with ExceptionCheck(self):
raise self._lammps_exception self.lib.lammps_commands_list(self.lmp,narg,args)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -576,10 +587,9 @@ class lammps(object):
:type multicmd: string :type multicmd: string
""" """
if type(multicmd) is str: multicmd = multicmd.encode() if type(multicmd) is str: multicmd = multicmd.encode()
self.lib.lammps_commands_string(self.lmp,c_char_p(multicmd))
if self.has_exceptions and self.lib.lammps_has_error(self.lmp): with ExceptionCheck(self):
raise self._lammps_exception self.lib.lammps_commands_string(self.lmp,c_char_p(multicmd))
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -614,9 +624,10 @@ class lammps(object):
periodicity = (3*c_int)() periodicity = (3*c_int)()
box_change = c_int() box_change = c_int()
self.lib.lammps_extract_box(self.lmp,boxlo,boxhi, with ExceptionCheck(self):
byref(xy),byref(yz),byref(xz), self.lib.lammps_extract_box(self.lmp,boxlo,boxhi,
periodicity,byref(box_change)) byref(xy),byref(yz),byref(xz),
periodicity,byref(box_change))
boxlo = boxlo[:3] boxlo = boxlo[:3]
boxhi = boxhi[:3] boxhi = boxhi[:3]
@ -649,7 +660,8 @@ class lammps(object):
""" """
cboxlo = (3*c_double)(*boxlo) cboxlo = (3*c_double)(*boxlo)
cboxhi = (3*c_double)(*boxhi) cboxhi = (3*c_double)(*boxhi)
self.lib.lammps_reset_box(self.lmp,cboxlo,cboxhi,xy,yz,xz) with ExceptionCheck(self):
self.lib.lammps_reset_box(self.lmp,cboxlo,cboxhi,xy,yz,xz)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -666,7 +678,9 @@ class lammps(object):
""" """
if name: name = name.encode() if name: name = name.encode()
else: return None else: return None
return self.lib.lammps_get_thermo(self.lmp,name)
with ExceptionCheck(self):
return self.lib.lammps_get_thermo(self.lmp,name)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -838,6 +852,7 @@ class lammps(object):
elif dtype == LAMMPS_INT64_2D: elif dtype == LAMMPS_INT64_2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int64)) self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int64))
else: return None else: return None
ptr = self.lib.lammps_extract_atom(self.lmp, name) ptr = self.lib.lammps_extract_atom(self.lmp, name)
if ptr: return ptr if ptr: return ptr
else: return None else: return None
@ -872,39 +887,44 @@ class lammps(object):
if type == LMP_TYPE_SCALAR: if type == LMP_TYPE_SCALAR:
if style == LMP_STYLE_GLOBAL: if style == LMP_STYLE_GLOBAL:
self.lib.lammps_extract_compute.restype = POINTER(c_double) self.lib.lammps_extract_compute.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0] return ptr[0]
elif style == LMP_STYLE_ATOM: elif style == LMP_STYLE_ATOM:
return None return None
elif style == LMP_STYLE_LOCAL: elif style == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int) self.lib.lammps_extract_compute.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0] return ptr[0]
if type == LMP_TYPE_VECTOR: elif type == LMP_TYPE_VECTOR:
self.lib.lammps_extract_compute.restype = POINTER(c_double) self.lib.lammps_extract_compute.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr return ptr
if type == LMP_TYPE_ARRAY: elif type == LMP_TYPE_ARRAY:
self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double)) self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr return ptr
if type == LMP_SIZE_COLS: elif type == LMP_SIZE_COLS:
if style == LMP_STYLE_GLOBAL \ if style == LMP_STYLE_GLOBAL \
or style == LMP_STYLE_ATOM \ or style == LMP_STYLE_ATOM \
or style == LMP_STYLE_LOCAL: or style == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int) self.lib.lammps_extract_compute.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0] return ptr[0]
if type == LMP_SIZE_VECTOR \ elif type == LMP_SIZE_VECTOR or type == LMP_SIZE_ROWS:
or type == LMP_SIZE_ROWS:
if style == LMP_STYLE_GLOBAL \ if style == LMP_STYLE_GLOBAL \
or style == LMP_STYLE_LOCAL: or style == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int) self.lib.lammps_extract_compute.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0] return ptr[0]
return None return None
@ -947,13 +967,15 @@ class lammps(object):
if style == LMP_STYLE_GLOBAL: if style == LMP_STYLE_GLOBAL:
if type in (LMP_TYPE_SCALAR, LMP_TYPE_VECTOR, LMP_TYPE_ARRAY): if type in (LMP_TYPE_SCALAR, LMP_TYPE_VECTOR, LMP_TYPE_ARRAY):
self.lib.lammps_extract_fix.restype = POINTER(c_double) self.lib.lammps_extract_fix.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
result = ptr[0] result = ptr[0]
self.lib.lammps_free(ptr) self.lib.lammps_free(ptr)
return result return result
elif type in (LMP_SIZE_VECTOR, LMP_SIZE_ROWS, LMP_SIZE_COLS): elif type in (LMP_SIZE_VECTOR, LMP_SIZE_ROWS, LMP_SIZE_COLS):
self.lib.lammps_extract_fix.restype = POINTER(c_int) self.lib.lammps_extract_fix.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
return ptr[0] return ptr[0]
else: else:
return None return None
@ -967,7 +989,8 @@ class lammps(object):
self.lib.lammps_extract_fix.restype = POINTER(c_int) self.lib.lammps_extract_fix.restype = POINTER(c_int)
else: else:
return None return None
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
if type == LMP_SIZE_COLS: if type == LMP_SIZE_COLS:
return ptr[0] return ptr[0]
else: else:
@ -982,7 +1005,8 @@ class lammps(object):
self.lib.lammps_extract_fix.restype = POINTER(c_int) self.lib.lammps_extract_fix.restype = POINTER(c_int)
else: else:
return None return None
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
if type in (LMP_TYPE_VECTOR, LMP_TYPE_ARRAY): if type in (LMP_TYPE_VECTOR, LMP_TYPE_ARRAY):
return ptr return ptr
else: else:
@ -1026,7 +1050,8 @@ class lammps(object):
if group: group = group.encode() if group: group = group.encode()
if vartype == LMP_VAR_EQUAL: if vartype == LMP_VAR_EQUAL:
self.lib.lammps_extract_variable.restype = POINTER(c_double) self.lib.lammps_extract_variable.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_variable(self.lmp,name,group) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
if ptr: result = ptr[0] if ptr: result = ptr[0]
else: return None else: return None
self.lib.lammps_free(ptr) self.lib.lammps_free(ptr)
@ -1035,7 +1060,8 @@ class lammps(object):
nlocal = self.extract_global("nlocal") nlocal = self.extract_global("nlocal")
result = (c_double*nlocal)() result = (c_double*nlocal)()
self.lib.lammps_extract_variable.restype = POINTER(c_double) self.lib.lammps_extract_variable.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_variable(self.lmp,name,group) with ExceptionCheck(self):
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
if ptr: if ptr:
for i in range(nlocal): result[i] = ptr[i] for i in range(nlocal): result[i] = ptr[i]
self.lib.lammps_free(ptr) self.lib.lammps_free(ptr)
@ -1062,7 +1088,8 @@ class lammps(object):
else: return -1 else: return -1
if value: value = str(value).encode() if value: value = str(value).encode()
else: return -1 else: return -1
return self.lib.lammps_set_variable(self.lmp,name,value) with ExceptionCheck(self):
return self.lib.lammps_set_variable(self.lmp,name,value)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -1078,13 +1105,15 @@ class lammps(object):
def gather_atoms(self,name,type,count): def gather_atoms(self,name,type,count):
if name: name = name.encode() if name: name = name.encode()
natoms = self.get_natoms() natoms = self.get_natoms()
if type == 0: with ExceptionCheck(self):
data = ((count*natoms)*c_int)() if type == 0:
self.lib.lammps_gather_atoms(self.lmp,name,type,count,data) data = ((count*natoms)*c_int)()
elif type == 1: self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
data = ((count*natoms)*c_double)() elif type == 1:
self.lib.lammps_gather_atoms(self.lmp,name,type,count,data) data = ((count*natoms)*c_double)()
else: return None self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
else:
return None
return data return data
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -1092,24 +1121,28 @@ class lammps(object):
def gather_atoms_concat(self,name,type,count): def gather_atoms_concat(self,name,type,count):
if name: name = name.encode() if name: name = name.encode()
natoms = self.get_natoms() natoms = self.get_natoms()
if type == 0: with ExceptionCheck(self):
data = ((count*natoms)*c_int)() if type == 0:
self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data) data = ((count*natoms)*c_int)()
elif type == 1: self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data)
data = ((count*natoms)*c_double)() elif type == 1:
self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data) data = ((count*natoms)*c_double)()
else: return None self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data)
else:
return None
return data return data
def gather_atoms_subset(self,name,type,count,ndata,ids): def gather_atoms_subset(self,name,type,count,ndata,ids):
if name: name = name.encode() if name: name = name.encode()
if type == 0: with ExceptionCheck(self):
data = ((count*ndata)*c_int)() if type == 0:
self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data) data = ((count*ndata)*c_int)()
elif type == 1: self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
data = ((count*ndata)*c_double)() elif type == 1:
self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data) data = ((count*ndata)*c_double)()
else: return None self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
else:
return None
return data return data
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -1125,13 +1158,15 @@ class lammps(object):
def scatter_atoms(self,name,type,count,data): def scatter_atoms(self,name,type,count,data):
if name: name = name.encode() if name: name = name.encode()
self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data) with ExceptionCheck(self):
self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def scatter_atoms_subset(self,name,type,count,ndata,ids,data): def scatter_atoms_subset(self,name,type,count,ndata,ids,data):
if name: name = name.encode() if name: name = name.encode()
self.lib.lammps_scatter_atoms_subset(self.lmp,name,type,count,ndata,ids,data) with ExceptionCheck(self):
self.lib.lammps_scatter_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
# return vector of atom/compute/fix properties gathered across procs # return vector of atom/compute/fix properties gathered across procs
# 3 variants to match src/library.cpp # 3 variants to match src/library.cpp
@ -1144,36 +1179,42 @@ class lammps(object):
def gather(self,name,type,count): def gather(self,name,type,count):
if name: name = name.encode() if name: name = name.encode()
natoms = self.get_natoms() natoms = self.get_natoms()
if type == 0: with ExceptionCheck(self):
data = ((count*natoms)*c_int)() if type == 0:
self.lib.lammps_gather(self.lmp,name,type,count,data) data = ((count*natoms)*c_int)()
elif type == 1: self.lib.lammps_gather(self.lmp,name,type,count,data)
data = ((count*natoms)*c_double)() elif type == 1:
self.lib.lammps_gather(self.lmp,name,type,count,data) data = ((count*natoms)*c_double)()
else: return None self.lib.lammps_gather(self.lmp,name,type,count,data)
else:
return None
return data return data
def gather_concat(self,name,type,count): def gather_concat(self,name,type,count):
if name: name = name.encode() if name: name = name.encode()
natoms = self.get_natoms() natoms = self.get_natoms()
if type == 0: with ExceptionCheck(self):
data = ((count*natoms)*c_int)() if type == 0:
self.lib.lammps_gather_concat(self.lmp,name,type,count,data) data = ((count*natoms)*c_int)()
elif type == 1: self.lib.lammps_gather_concat(self.lmp,name,type,count,data)
data = ((count*natoms)*c_double)() elif type == 1:
self.lib.lammps_gather_concat(self.lmp,name,type,count,data) data = ((count*natoms)*c_double)()
else: return None self.lib.lammps_gather_concat(self.lmp,name,type,count,data)
else:
return None
return data return data
def gather_subset(self,name,type,count,ndata,ids): def gather_subset(self,name,type,count,ndata,ids):
if name: name = name.encode() if name: name = name.encode()
if type == 0: with ExceptionCheck(self):
data = ((count*ndata)*c_int)() if type == 0:
self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data) data = ((count*ndata)*c_int)()
elif type == 1: self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data)
data = ((count*ndata)*c_double)() elif type == 1:
self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data) data = ((count*ndata)*c_double)()
else: return None self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data)
else:
return None
return data return data
# scatter vector of atom/compute/fix properties across procs # scatter vector of atom/compute/fix properties across procs
@ -1187,11 +1228,13 @@ class lammps(object):
def scatter(self,name,type,count,data): def scatter(self,name,type,count,data):
if name: name = name.encode() if name: name = name.encode()
self.lib.lammps_scatter(self.lmp,name,type,count,data) with ExceptionCheck(self):
self.lib.lammps_scatter(self.lmp,name,type,count,data)
def scatter_subset(self,name,type,count,ndata,ids,data): def scatter_subset(self,name,type,count,ndata,ids,data):
if name: name = name.encode() if name: name = name.encode()
self.lib.lammps_scatter_subset(self.lmp,name,type,count,ndata,ids,data) with ExceptionCheck(self):
self.lib.lammps_scatter_subset(self.lmp,name,type,count,ndata,ids,data)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -1328,7 +1371,8 @@ class lammps(object):
POINTER(c_int*n), POINTER(c_double*three_n), POINTER(c_int*n), POINTER(c_double*three_n),
POINTER(c_double*three_n), POINTER(c_double*three_n),
POINTER(self.c_imageint*n), c_int] POINTER(self.c_imageint*n), c_int]
return self.lib.lammps_create_atoms(self.lmp, n, id_lmp, type_lmp, x_lmp, v_lmp, img_lmp, se_lmp) with ExceptionCheck(self):
return self.lib.lammps_create_atoms(self.lmp, n, id_lmp, type_lmp, x_lmp, v_lmp, img_lmp, se_lmp)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -1538,10 +1582,12 @@ class lammps(object):
if category not in self._available_styles: if category not in self._available_styles:
self._available_styles[category] = [] self._available_styles[category] = []
nstyles = self.lib.lammps_style_count(self.lmp, category.encode()) with ExceptionCheck(self):
nstyles = self.lib.lammps_style_count(self.lmp, category.encode())
sb = create_string_buffer(100) sb = create_string_buffer(100)
for idx in range(nstyles): for idx in range(nstyles):
self.lib.lammps_style_name(self.lmp, category.encode(), idx, sb, 100) with ExceptionCheck(self):
self.lib.lammps_style_name(self.lmp, category.encode(), idx, sb, 100)
self._available_styles[category].append(sb.value.decode()) self._available_styles[category].append(sb.value.decode())
return self._available_styles[category] return self._available_styles[category]
@ -1607,7 +1653,8 @@ class lammps(object):
cCaller = caller cCaller = caller
self.callback[fix_name] = { 'function': cFunc, 'caller': caller } self.callback[fix_name] = { 'function': cFunc, 'caller': caller }
self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller) with ExceptionCheck(self):
self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller)
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------