don't overwrite string type argument variables with their encoded version

This commit is contained in:
Axel Kohlmeyer
2024-07-16 19:00:18 -04:00
parent 4daba292d7
commit 8d4a80729a

View File

@ -533,11 +533,11 @@ class lammps(object):
:param error_text:
:type error_text: string
"""
if error_text: error_text = error_text.encode()
else: error_text = "(unknown error)".encode()
if error_text: new_error_text = error_text.encode()
else: new_error_text = "(unknown error)".encode()
with ExceptionCheck(self):
self.lib.lammps_error(self.lmp, error_type, error_text)
self.lib.lammps_error(self.lmp, error_type, new_error_text)
# -------------------------------------------------------------------------
@ -612,11 +612,11 @@ class lammps(object):
:param path: Name of the file/path with LAMMPS commands
:type path: string
"""
if path: path = path.encode()
if path: newpath = path.encode()
else: return
with ExceptionCheck(self):
self.lib.lammps_file(self.lmp, path)
self.lib.lammps_file(self.lmp, newpath)
# -------------------------------------------------------------------------
@ -629,11 +629,11 @@ class lammps(object):
:param cmd: a single lammps command
:type cmd: string
"""
if cmd: cmd = cmd.encode()
if cmd: newcmd = cmd.encode()
else: return
with ExceptionCheck(self):
self.lib.lammps_command(self.lmp,cmd)
self.lib.lammps_command(self.lmp, newcmd)
# -------------------------------------------------------------------------
@ -667,10 +667,11 @@ class lammps(object):
:param multicmd: text block of lammps commands
:type multicmd: string
"""
if type(multicmd) is str: multicmd = multicmd.encode()
if type(multicmd) is str: newmulticmd = multicmd.encode()
else: newmulticmd = multicmd
with ExceptionCheck(self):
self.lib.lammps_commands_string(self.lmp,c_char_p(multicmd))
self.lib.lammps_commands_string(self.lmp,c_char_p(newmulticmd))
# -------------------------------------------------------------------------
@ -757,11 +758,11 @@ class lammps(object):
:return: value of thermo keyword
:rtype: double or None
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return None
with ExceptionCheck(self):
return self.lib.lammps_get_thermo(self.lmp,name)
return self.lib.lammps_get_thermo(self.lmp, newname)
# -------------------------------------------------------------------------
@property
@ -835,9 +836,9 @@ class lammps(object):
:return: value of the setting
:rtype: int
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return None
return int(self.lib.lammps_extract_setting(self.lmp,name))
return int(self.lib.lammps_extract_setting(self.lmp, newname))
# -------------------------------------------------------------------------
# extract global info datatype
@ -858,9 +859,9 @@ class lammps(object):
:return: data type of global property, see :ref:`py_datatype_constants`
:rtype: int
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return None
return self.lib.lammps_extract_global_datatype(self.lmp, name)
return self.lib.lammps_extract_global_datatype(self.lmp, newname)
# -------------------------------------------------------------------------
# extract global info
@ -904,7 +905,7 @@ class lammps(object):
else:
veclen = 1
if name: name = name.encode()
if name: newname = name.encode()
else: return None
if dtype == LAMMPS_INT:
@ -922,7 +923,7 @@ class lammps(object):
else:
target_type = None
ptr = self.lib.lammps_extract_global(self.lmp, name)
ptr = self.lib.lammps_extract_global(self.lmp, newname)
if ptr:
if dtype == LAMMPS_STRING:
return ptr.decode('utf-8')
@ -940,6 +941,8 @@ class lammps(object):
def extract_pair_dimension(self, name):
"""Retrieve pair style property dimensionality from LAMMPS
.. versionadded:: TBD
This is a wrapper around the :cpp:func:`lammps_extract_pair_dimension`
function of the C-library interface. The list of supported keywords
depends on the pair style. This function returns ``None`` if the keyword
@ -951,10 +954,10 @@ class lammps(object):
:rtype: int
"""
if name:
name = name.encode()
newname = name.encode()
else:
return None
dim = self.lib.lammps_extract_pair_dimension(self.lmp, name)
dim = self.lib.lammps_extract_pair_dimension(self.lmp, newname)
if dim < 0:
return None
@ -967,6 +970,8 @@ class lammps(object):
def extract_pair(self, name):
"""Extract pair style data from LAMMPS.
.. versionadded:: TBD
This is a wrapper around the :cpp:func:`lammps_extract_pair` function
of the C-library interface. Since there are no pointers in Python, this
method will - unlike the C function - return the value or a list of
@ -982,7 +987,7 @@ class lammps(object):
"""
if name:
name = name.encode()
newname = name.encode()
else:
return None
@ -999,7 +1004,7 @@ class lammps(object):
return None
ntypes = self.extract_setting('ntypes')
ptr = self.lib.lammps_extract_pair(self.lmp, name)
ptr = self.lib.lammps_extract_pair(self.lmp, newname)
if ptr:
if dim == 0:
return float(ptr[0])
@ -1061,9 +1066,9 @@ class lammps(object):
:return: data type of per-atom property (see :ref:`py_datatype_constants`)
:rtype: int
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return None
return self.lib.lammps_extract_atom_datatype(self.lmp, name)
return self.lib.lammps_extract_atom_datatype(self.lmp, newname)
# -------------------------------------------------------------------------
# extract per-atom info
@ -1104,7 +1109,7 @@ class lammps(object):
if dtype == LAMMPS_AUTODETECT:
dtype = self.extract_atom_datatype(name)
if name: name = name.encode()
if name: newname = name.encode()
else: return None
if dtype == LAMMPS_INT:
@ -1121,7 +1126,7 @@ class lammps(object):
self.lib.lammps_extract_atom.restype = POINTER(POINTER(c_int64))
else: return None
ptr = self.lib.lammps_extract_atom(self.lmp, name)
ptr = self.lib.lammps_extract_atom(self.lmp, newname)
if ptr: return ptr
else: return None
@ -1149,47 +1154,47 @@ class lammps(object):
:return: requested data as scalar, pointer to 1d or 2d double array, or None
:rtype: c_double, ctypes.POINTER(c_double), ctypes.POINTER(ctypes.POINTER(c_double)), or NoneType
"""
if cid: cid = cid.encode()
if cid: newcid = cid.encode()
else: return None
if ctype == LMP_TYPE_SCALAR:
if cstyle == LMP_STYLE_GLOBAL:
self.lib.lammps_extract_compute.restype = POINTER(c_double)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,cid,cstyle,ctype)
ptr = self.lib.lammps_extract_compute(self.lmp,newcid,cstyle,ctype)
return ptr[0]
elif cstyle == LMP_STYLE_ATOM:
return None
elif cstyle == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,cid,cstyle,ctype)
ptr = self.lib.lammps_extract_compute(self.lmp,newcid,cstyle,ctype)
return ptr[0]
elif ctype == LMP_TYPE_VECTOR:
self.lib.lammps_extract_compute.restype = POINTER(c_double)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,cid,cstyle,ctype)
ptr = self.lib.lammps_extract_compute(self.lmp,newcid,cstyle,ctype)
return ptr
elif ctype == LMP_TYPE_ARRAY:
self.lib.lammps_extract_compute.restype = POINTER(POINTER(c_double))
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,cid,cstyle,ctype)
ptr = self.lib.lammps_extract_compute(self.lmp,newcid,cstyle,ctype)
return ptr
elif ctype == LMP_SIZE_COLS:
if cstyle == LMP_STYLE_GLOBAL or cstyle == LMP_STYLE_ATOM or cstyle == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,cid,cstyle,ctype)
ptr = self.lib.lammps_extract_compute(self.lmp,newcid,cstyle,ctype)
return ptr[0]
elif ctype == LMP_SIZE_VECTOR or ctype == LMP_SIZE_ROWS:
if cstyle == LMP_STYLE_GLOBAL or cstyle == LMP_STYLE_LOCAL:
self.lib.lammps_extract_compute.restype = POINTER(c_int)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_compute(self.lmp,cid,cstyle,ctype)
ptr = self.lib.lammps_extract_compute(self.lmp,newcid,cstyle,ctype)
return ptr[0]
return None
@ -1234,21 +1239,21 @@ class lammps(object):
:rtype: c_double, ctypes.POINTER(c_double), ctypes.POINTER(ctypes.POINTER(c_double)), or NoneType
"""
if fid: fid = fid.encode()
if fid: newfid = fid.encode()
else: return None
if fstyle == LMP_STYLE_GLOBAL:
if ftype in (LMP_TYPE_SCALAR, LMP_TYPE_VECTOR, LMP_TYPE_ARRAY):
self.lib.lammps_extract_fix.restype = POINTER(c_double)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,fid,fstyle,ftype,nrow,ncol)
ptr = self.lib.lammps_extract_fix(self.lmp,newfid,fstyle,ftype,nrow,ncol)
result = ptr[0]
self.lib.lammps_free(ptr)
return result
elif ftype in (LMP_SIZE_VECTOR, LMP_SIZE_ROWS, LMP_SIZE_COLS):
self.lib.lammps_extract_fix.restype = POINTER(c_int)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,fid,fstyle,ftype,nrow,ncol)
ptr = self.lib.lammps_extract_fix(self.lmp,newfid,fstyle,ftype,nrow,ncol)
return ptr[0]
else:
return None
@ -1263,7 +1268,7 @@ class lammps(object):
else:
return None
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,fid,fstyle,ftype,nrow,ncol)
ptr = self.lib.lammps_extract_fix(self.lmp,newfid,fstyle,ftype,nrow,ncol)
if ftype == LMP_SIZE_COLS:
return ptr[0]
else:
@ -1279,7 +1284,7 @@ class lammps(object):
else:
return None
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_fix(self.lmp,fid,fstyle,ftype,nrow,ncol)
ptr = self.lib.lammps_extract_fix(self.lmp,newfid,fstyle,ftype,nrow,ncol)
if ftype in (LMP_TYPE_VECTOR, LMP_TYPE_ARRAY):
return ptr
else:
@ -1320,15 +1325,16 @@ class lammps(object):
:return: the requested data
:rtype: c_double, (c_double), or NoneType
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return None
if group: group = group.encode()
if group: newgroup = group.encode()
else: newgroup = None
if vartype is None :
vartype = self.lib.lammps_extract_variable_datatype(self.lmp, name)
vartype = self.lib.lammps_extract_variable_datatype(self.lmp, newname)
if vartype == LMP_VAR_EQUAL:
self.lib.lammps_extract_variable.restype = POINTER(c_double)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
ptr = self.lib.lammps_extract_variable(self.lmp, newname, newgroup)
if ptr: result = ptr[0]
else: return None
self.lib.lammps_free(ptr)
@ -1338,7 +1344,7 @@ class lammps(object):
result = (c_double*nlocal)()
self.lib.lammps_extract_variable.restype = POINTER(c_double)
with ExceptionCheck(self):
ptr = self.lib.lammps_extract_variable(self.lmp,name,group)
ptr = self.lib.lammps_extract_variable(self.lmp, newname, newgroup)
if ptr:
for i in range(nlocal): result[i] = ptr[i]
self.lib.lammps_free(ptr)
@ -1347,27 +1353,27 @@ class lammps(object):
elif vartype == LMP_VAR_VECTOR :
nvector = 0
self.lib.lammps_extract_variable.restype = POINTER(c_int)
ptr = self.lib.lammps_extract_variable(self.lmp,name,
ptr = self.lib.lammps_extract_variable(self.lmp, newname,
'LMP_SIZE_VECTOR'.encode())
if ptr :
nvector = ptr[0]
self.lib.lammps_free(ptr)
else :
else:
return None
self.lib.lammps_extract_variable.restype = POINTER(c_double)
result = (c_double*nvector)()
values = self.lib.lammps_extract_variable(self.lmp,name,group)
values = self.lib.lammps_extract_variable(self.lmp, newname, newgroup)
if values :
for i in range(nvector) :
result[i] = values[i]
# do NOT free the values pointer (points to internal vector data)
return result
else :
else:
return None
elif vartype == LMP_VAR_STRING :
self.lib.lammps_extract_variable.restype = c_char_p
with ExceptionCheck(self) :
ptr = self.lib.lammps_extract_variable(self.lmp, name, group)
ptr = self.lib.lammps_extract_variable(self.lmp, newname, newgroup)
return ptr.decode('utf-8')
return None
@ -1398,12 +1404,12 @@ class lammps(object):
:return: either 0 on success or -1 on failure
:rtype: int
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return -1
if value: value = str(value).encode()
if value: newvalue = str(value).encode()
else: return -1
with ExceptionCheck(self):
return self.lib.lammps_set_variable(self.lmp,name,value)
return self.lib.lammps_set_variable(self.lmp, newname, newvalue)
# -------------------------------------------------------------------------
@ -1422,12 +1428,12 @@ class lammps(object):
:return: either 0 on success or -1 on failure
:rtype: int
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return -1
if value: value = str(value).encode()
if value: newvalue = str(value).encode()
else: return -1
with ExceptionCheck(self):
return self.lib.lammps_set_string_variable(self.lmp,name,value)
return self.lib.lammps_set_string_variable(self.lmp,newname,newvalue)
# -------------------------------------------------------------------------
@ -1446,10 +1452,10 @@ class lammps(object):
:return: either 0 on success or -1 on failure
:rtype: int
"""
if name: name = name.encode()
if name: newname = name.encode()
else: return -1
with ExceptionCheck(self):
return self.lib.lammps_set_internal_variable(self.lmp,name,value)
return self.lib.lammps_set_internal_variable(self.lmp,newname,value)
# -------------------------------------------------------------------------
@ -1463,15 +1469,16 @@ class lammps(object):
# e.g. for Python list or NumPy or ctypes
def gather_atoms(self,name,dtype,count):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
natoms = self.get_natoms()
with ExceptionCheck(self):
if dtype == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_atoms(self.lmp,name,dtype,count,data)
self.lib.lammps_gather_atoms(self.lmp,newname,dtype,count,data)
elif dtype == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_atoms(self.lmp,name,dtype,count,data)
self.lib.lammps_gather_atoms(self.lmp,newname,dtype,count,data)
else:
return None
return data
@ -1479,28 +1486,30 @@ class lammps(object):
# -------------------------------------------------------------------------
def gather_atoms_concat(self,name,dtype,count):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
natoms = self.get_natoms()
with ExceptionCheck(self):
if dtype == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_atoms_concat(self.lmp,name,dtype,count,data)
self.lib.lammps_gather_atoms_concat(self.lmp,newname,dtype,count,data)
elif dtype == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_atoms_concat(self.lmp,name,dtype,count,data)
self.lib.lammps_gather_atoms_concat(self.lmp,newname,dtype,count,data)
else:
return None
return data
def gather_atoms_subset(self,name,dtype,count,ndata,ids):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
with ExceptionCheck(self):
if dtype == 0:
data = ((count*ndata)*c_int)()
self.lib.lammps_gather_atoms_subset(self.lmp,name,dtype,count,ndata,ids,data)
self.lib.lammps_gather_atoms_subset(self.lmp,newname,dtype,count,ndata,ids,data)
elif dtype == 1:
data = ((count*ndata)*c_double)()
self.lib.lammps_gather_atoms_subset(self.lmp,name,dtype,count,ndata,ids,data)
self.lib.lammps_gather_atoms_subset(self.lmp,newname,dtype,count,ndata,ids,data)
else:
return None
return data
@ -1517,16 +1526,18 @@ class lammps(object):
# e.g. for Python list or NumPy or ctypes
def scatter_atoms(self,name,dtype,count,data):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
with ExceptionCheck(self):
self.lib.lammps_scatter_atoms(self.lmp,name,dtype,count,data)
self.lib.lammps_scatter_atoms(self.lmp,newname,dtype,count,data)
# -------------------------------------------------------------------------
def scatter_atoms_subset(self,name,dtype,count,ndata,ids,data):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
with ExceptionCheck(self):
self.lib.lammps_scatter_atoms_subset(self.lmp,name,dtype,count,ndata,ids,data)
self.lib.lammps_scatter_atoms_subset(self.lmp,newname,dtype,count,ndata,ids,data)
# -------------------------------------------------------------------------
@ -1632,42 +1643,45 @@ class lammps(object):
# NOTE: need to ensure are converting to/from correct Python type
# e.g. for Python list or NumPy or ctypes
def gather(self,name,dtype,count):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
natoms = self.get_natoms()
with ExceptionCheck(self):
if dtype == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather(self.lmp,name,dtype,count,data)
self.lib.lammps_gather(self.lmp,newname,dtype,count,data)
elif dtype == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather(self.lmp,name,dtype,count,data)
self.lib.lammps_gather(self.lmp,newname,dtype,count,data)
else:
return None
return data
def gather_concat(self,name,dtype,count):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
natoms = self.get_natoms()
with ExceptionCheck(self):
if dtype == 0:
data = ((count*natoms)*c_int)()
self.lib.lammps_gather_concat(self.lmp,name,dtype,count,data)
self.lib.lammps_gather_concat(self.lmp,newname,dtype,count,data)
elif dtype == 1:
data = ((count*natoms)*c_double)()
self.lib.lammps_gather_concat(self.lmp,name,dtype,count,data)
self.lib.lammps_gather_concat(self.lmp,newname,dtype,count,data)
else:
return None
return data
def gather_subset(self,name,dtype,count,ndata,ids):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
with ExceptionCheck(self):
if dtype == 0:
data = ((count*ndata)*c_int)()
self.lib.lammps_gather_subset(self.lmp,name,dtype,count,ndata,ids,data)
self.lib.lammps_gather_subset(self.lmp,newname,dtype,count,ndata,ids,data)
elif dtype == 1:
data = ((count*ndata)*c_double)()
self.lib.lammps_gather_subset(self.lmp,name,dtype,count,ndata,ids,data)
self.lib.lammps_gather_subset(self.lmp,newname,dtype,count,ndata,ids,data)
else:
return None
return data
@ -1682,14 +1696,16 @@ class lammps(object):
# e.g. for Python list or NumPy or ctypes
def scatter(self,name,dtype,count,data):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
with ExceptionCheck(self):
self.lib.lammps_scatter(self.lmp,name,dtype,count,data)
self.lib.lammps_scatter(self.lmp,newname,dtype,count,data)
def scatter_subset(self,name,dtype,count,ndata,ids,data):
if name: name = name.encode()
if name: newname = name.encode()
else: newname = None
with ExceptionCheck(self):
self.lib.lammps_scatter_subset(self.lmp,name,dtype,count,ndata,ids,data)
self.lib.lammps_scatter_subset(self.lmp,newname,dtype,count,ndata,ids,data)
# -------------------------------------------------------------------------
@ -2442,9 +2458,9 @@ class lammps(object):
:rtype: int
"""
style = style.encode()
newstyle = style.encode()
exact = int(exact)
idx = self.lib.lammps_find_pair_neighlist(self.lmp, style, exact, nsub, reqid)
idx = self.lib.lammps_find_pair_neighlist(self.lmp, newstyle, exact, nsub, reqid)
return idx
# -------------------------------------------------------------------------
@ -2465,8 +2481,8 @@ class lammps(object):
:rtype: int
"""
fixid = fixid.encode()
idx = self.lib.lammps_find_fix_neighlist(self.lmp, fixid, reqid)
newfixid = fixid.encode()
idx = self.lib.lammps_find_fix_neighlist(self.lmp, newfixid, reqid)
return idx
# -------------------------------------------------------------------------
@ -2488,6 +2504,6 @@ class lammps(object):
:rtype: int
"""
computeid = computeid.encode()
idx = self.lib.lammps_find_compute_neighlist(self.lmp, computeid, reqid)
newcomputeid = computeid.encode()
idx = self.lib.lammps_find_compute_neighlist(self.lmp, newcomputeid, reqid)
return idx