Add missing checks for C++ exceptions

This commit is contained in:
Richard Berger
2021-03-18 14:26:18 -04:00
parent b6498c8b9b
commit 64ba2f4ee2

View File

@ -39,6 +39,20 @@ class MPIAbortException(Exception):
# -------------------------------------------------------------------------
class ExceptionCheck:
"""Utility class to rethrow LAMMPS C++ exceptions as Python exceptions"""
def __init__(self, lmp):
self.lmp = lmp
def __enter__(self):
pass
def __exit__(self, type, value, traceback):
if self.lmp.has_exceptions and self.lmp.lib.lammps_has_error(self.lmp.lmp):
raise self.lmp._lammps_exception
# -------------------------------------------------------------------------
class lammps(object):
"""Create an instance of the LAMMPS Python class.
@ -519,10 +533,9 @@ class lammps(object):
"""
if path: path = path.encode()
else: return
self.lib.lammps_file(self.lmp, path)
if self.has_exceptions and self.lib.lammps_has_error(self.lmp):
raise self._lammps_exception
with ExceptionCheck(self):
self.lib.lammps_file(self.lmp, path)
# -------------------------------------------------------------------------
@ -537,10 +550,9 @@ class lammps(object):
"""
if cmd: cmd = cmd.encode()
else: return
self.lib.lammps_command(self.lmp,cmd)
if self.has_exceptions and self.lib.lammps_has_error(self.lmp):
raise self._lammps_exception
with ExceptionCheck(self):
self.lib.lammps_command(self.lmp,cmd)
# -------------------------------------------------------------------------
@ -558,10 +570,9 @@ class lammps(object):
narg = len(cmdlist)
args = (c_char_p * narg)(*cmds)
self.lib.lammps_commands_list.argtypes = [c_void_p, c_int, c_char_p * narg]
self.lib.lammps_commands_list(self.lmp,narg,args)
if self.has_exceptions and self.lib.lammps_has_error(self.lmp):
raise self._lammps_exception
with ExceptionCheck(self):
self.lib.lammps_commands_list(self.lmp,narg,args)
# -------------------------------------------------------------------------
@ -576,10 +587,9 @@ class lammps(object):
:type multicmd: string
"""
if type(multicmd) is str: multicmd = multicmd.encode()
self.lib.lammps_commands_string(self.lmp,c_char_p(multicmd))
if self.has_exceptions and self.lib.lammps_has_error(self.lmp):
raise self._lammps_exception
with ExceptionCheck(self):
self.lib.lammps_commands_string(self.lmp,c_char_p(multicmd))
# -------------------------------------------------------------------------
@ -614,9 +624,10 @@ class lammps(object):
periodicity = (3*c_int)()
box_change = c_int()
self.lib.lammps_extract_box(self.lmp,boxlo,boxhi,
byref(xy),byref(yz),byref(xz),
periodicity,byref(box_change))
with ExceptionCheck(self):
self.lib.lammps_extract_box(self.lmp,boxlo,boxhi,
byref(xy),byref(yz),byref(xz),
periodicity,byref(box_change))
boxlo = boxlo[:3]
boxhi = boxhi[:3]
@ -649,7 +660,8 @@ class lammps(object):
"""
cboxlo = (3*c_double)(*boxlo)
cboxhi = (3*c_double)(*boxhi)
self.lib.lammps_reset_box(self.lmp,cboxlo,cboxhi,xy,yz,xz)
with ExceptionCheck(self):
self.lib.lammps_reset_box(self.lmp,cboxlo,cboxhi,xy,yz,xz)
# -------------------------------------------------------------------------
@ -666,7 +678,9 @@ class lammps(object):
"""
if name: name = name.encode()
else: return None
return self.lib.lammps_get_thermo(self.lmp,name)
with ExceptionCheck(self):
return self.lib.lammps_get_thermo(self.lmp,name)
# -------------------------------------------------------------------------
@ -838,6 +852,7 @@ class lammps(object):
elif dtype == LAMMPS_INT64_2D:
self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int64))
else: return None
ptr = self.lib.lammps_extract_atom(self.lmp, name)
if ptr: return ptr
else: return None
@ -872,39 +887,44 @@ class lammps(object):
if type == LMP_TYPE_SCALAR:
if style == LMP_STYLE_GLOBAL:
self.lib.lammps_extract_compute.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0]
elif style == LMP_STYLE_ATOM:
return None
elif style == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0]
if type == LMP_TYPE_VECTOR:
elif type == LMP_TYPE_VECTOR:
self.lib.lammps_extract_compute.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr
if type == LMP_TYPE_ARRAY:
elif type == LMP_TYPE_ARRAY:
self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr
if type == LMP_SIZE_COLS:
elif type == LMP_SIZE_COLS:
if style == LMP_STYLE_GLOBAL \
or style == LMP_STYLE_ATOM \
or style == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0]
if type == LMP_SIZE_VECTOR \
or type == LMP_SIZE_ROWS:
elif type == LMP_SIZE_VECTOR or type == LMP_SIZE_ROWS:
if style == LMP_STYLE_GLOBAL \
or style == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,id,style,type)
return ptr[0]
return None
@ -947,13 +967,15 @@ class lammps(object):
if style == LMP_STYLE_GLOBAL:
if type in (LMP_TYPE_SCALAR, LMP_TYPE_VECTOR, LMP_TYPE_ARRAY):
self.lib.lammps_extract_fix.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
result = ptr[0]
self.lib.lammps_free(ptr)
return result
elif type in (LMP_SIZE_VECTOR, LMP_SIZE_ROWS, LMP_SIZE_COLS):
self.lib.lammps_extract_fix.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
return ptr[0]
else:
return None
@ -967,7 +989,8 @@ class lammps(object):
self.lib.lammps_extract_fix.restype = POINTER(c_int)
else:
return None
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
if type == LMP_SIZE_COLS:
return ptr[0]
else:
@ -982,7 +1005,8 @@ class lammps(object):
self.lib.lammps_extract_fix.restype = POINTER(c_int)
else:
return None
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,id,style,type,nrow,ncol)
if type in (LMP_TYPE_VECTOR, LMP_TYPE_ARRAY):
return ptr
else:
@ -1026,7 +1050,8 @@ class lammps(object):
if group: group = group.encode()
if vartype == LMP_VAR_EQUAL:
self.lib.lammps_extract_variable.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
if ptr: result = ptr[0]
else: return None
self.lib.lammps_free(ptr)
@ -1035,7 +1060,8 @@ class lammps(object):
nlocal = self.extract_global("nlocal")
result = (c_double*nlocal)()
self.lib.lammps_extract_variable.restype = POINTER(c_double)
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
if ptr:
for i in range(nlocal): result[i] = ptr[i]
self.lib.lammps_free(ptr)
@ -1062,7 +1088,8 @@ class lammps(object):
else: return -1
if value: value = str(value).encode()
else: return -1
return self.lib.lammps_set_variable(self.lmp,name,value)
with ExceptionCheck(self):
return self.lib.lammps_set_variable(self.lmp,name,value)
# -------------------------------------------------------------------------
@ -1078,13 +1105,15 @@ class lammps(object):
def gather_atoms(self,name,type,count):
if name: name = name.encode()
natoms = self.get_natoms()
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
else: return None
with ExceptionCheck(self):
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_atoms(self.lmp,name,type,count,data)
else:
return None
return data
# -------------------------------------------------------------------------
@ -1092,24 +1121,28 @@ class lammps(object):
def gather_atoms_concat(self,name,type,count):
if name: name = name.encode()
natoms = self.get_natoms()
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data)
else: return None
with ExceptionCheck(self):
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_atoms_concat(self.lmp,name,type,count,data)
else:
return None
return data
def gather_atoms_subset(self,name,type,count,ndata,ids):
if name: name = name.encode()
if type == 0:
data = ((count*ndata)*c_int)()
self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
elif type == 1:
data = ((count*ndata)*c_double)()
self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
else: return None
with ExceptionCheck(self):
if type == 0:
data = ((count*ndata)*c_int)()
self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
elif type == 1:
data = ((count*ndata)*c_double)()
self.lib.lammps_gather_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
else:
return None
return data
# -------------------------------------------------------------------------
@ -1125,13 +1158,15 @@ class lammps(object):
def scatter_atoms(self,name,type,count,data):
if name: name = name.encode()
self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data)
with ExceptionCheck(self):
self.lib.lammps_scatter_atoms(self.lmp,name,type,count,data)
# -------------------------------------------------------------------------
def scatter_atoms_subset(self,name,type,count,ndata,ids,data):
if name: name = name.encode()
self.lib.lammps_scatter_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
with ExceptionCheck(self):
self.lib.lammps_scatter_atoms_subset(self.lmp,name,type,count,ndata,ids,data)
# return vector of atom/compute/fix properties gathered across procs
# 3 variants to match src/library.cpp
@ -1144,36 +1179,42 @@ class lammps(object):
def gather(self,name,type,count):
if name: name = name.encode()
natoms = self.get_natoms()
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather(self.lmp,name,type,count,data)
else: return None
with ExceptionCheck(self):
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather(self.lmp,name,type,count,data)
else:
return None
return data
def gather_concat(self,name,type,count):
if name: name = name.encode()
natoms = self.get_natoms()
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_concat(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_concat(self.lmp,name,type,count,data)
else: return None
with ExceptionCheck(self):
if type == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_concat(self.lmp,name,type,count,data)
elif type == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_concat(self.lmp,name,type,count,data)
else:
return None
return data
def gather_subset(self,name,type,count,ndata,ids):
if name: name = name.encode()
if type == 0:
data = ((count*ndata)*c_int)()
self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data)
elif type == 1:
data = ((count*ndata)*c_double)()
self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data)
else: return None
with ExceptionCheck(self):
if type == 0:
data = ((count*ndata)*c_int)()
self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data)
elif type == 1:
data = ((count*ndata)*c_double)()
self.lib.lammps_gather_subset(self.lmp,name,type,count,ndata,ids,data)
else:
return None
return data
# scatter vector of atom/compute/fix properties across procs
@ -1187,11 +1228,13 @@ class lammps(object):
def scatter(self,name,type,count,data):
if name: name = name.encode()
self.lib.lammps_scatter(self.lmp,name,type,count,data)
with ExceptionCheck(self):
self.lib.lammps_scatter(self.lmp,name,type,count,data)
def scatter_subset(self,name,type,count,ndata,ids,data):
if name: name = name.encode()
self.lib.lammps_scatter_subset(self.lmp,name,type,count,ndata,ids,data)
with ExceptionCheck(self):
self.lib.lammps_scatter_subset(self.lmp,name,type,count,ndata,ids,data)
# -------------------------------------------------------------------------
@ -1328,7 +1371,8 @@ class lammps(object):
POINTER(c_int*n), POINTER(c_double*three_n),
POINTER(c_double*three_n),
POINTER(self.c_imageint*n), c_int]
return self.lib.lammps_create_atoms(self.lmp, n, id_lmp, type_lmp, x_lmp, v_lmp, img_lmp, se_lmp)
with ExceptionCheck(self):
return self.lib.lammps_create_atoms(self.lmp, n, id_lmp, type_lmp, x_lmp, v_lmp, img_lmp, se_lmp)
# -------------------------------------------------------------------------
@ -1538,10 +1582,12 @@ class lammps(object):
if category not in self._available_styles:
self._available_styles[category] = []
nstyles = self.lib.lammps_style_count(self.lmp, category.encode())
with ExceptionCheck(self):
nstyles = self.lib.lammps_style_count(self.lmp, category.encode())
sb = create_string_buffer(100)
for idx in range(nstyles):
self.lib.lammps_style_name(self.lmp, category.encode(), idx, sb, 100)
with ExceptionCheck(self):
self.lib.lammps_style_name(self.lmp, category.encode(), idx, sb, 100)
self._available_styles[category].append(sb.value.decode())
return self._available_styles[category]
@ -1607,7 +1653,8 @@ class lammps(object):
cCaller = caller
self.callback[fix_name] = { 'function': cFunc, 'caller': caller }
self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller)
with ExceptionCheck(self):
self.lib.lammps_set_fix_external_callback(self.lmp, fix_name.encode(), cFunc, cCaller)
# -------------------------------------------------------------------------