correct lammps.extract_global() method for returned arrays which are returned as list

This commit is contained in:
Axel Kohlmeyer
2021-03-25 10:09:08 -04:00
committed by Richard Berger
parent 3c41c12dbc
commit e0fdd2ad89
2 changed files with 26 additions and 4 deletions

View File

@ -746,12 +746,21 @@ class lammps(object):
:type name: string :type name: string
:param dtype: data type of the returned data (see :ref:`py_datatype_constants`) :param dtype: data type of the returned data (see :ref:`py_datatype_constants`)
:type dtype: int, optional :type dtype: int, optional
:return: value of the property or None :return: value of the property or list of values or None
:rtype: int, float, or NoneType :rtype: int, float, list, or NoneType
""" """
if dtype == LAMMPS_AUTODETECT: if dtype == LAMMPS_AUTODETECT:
dtype = self.extract_global_datatype(name) dtype = self.extract_global_datatype(name)
# set length of vector for items that are not a scalar
vec_dict = { 'boxlo':3, 'boxhi':3, 'sublo':3, 'subhi':3,
'sublo_lambda':3, 'subhi_lambda':3, 'periodicity':3 }
if name in vec_dict:
veclen = vec_dict[name]
else:
veclen = 1
if name: name = name.encode() if name: name = name.encode()
else: return None else: return None
@ -770,10 +779,14 @@ class lammps(object):
ptr = self.lib.lammps_extract_global(self.lmp, name) ptr = self.lib.lammps_extract_global(self.lmp, name)
if ptr: if ptr:
return target_type(ptr[0]) if veclen > 1:
result = []
for i in range(0,veclen):
result.append(target_type(ptr[i]))
return result
else: return target_type(ptr[0])
return None return None
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# extract per-atom info datatype # extract per-atom info datatype

View File

@ -268,6 +268,15 @@ create_atoms 1 single &
self.assertEqual(self.lmp.extract_global("boxyhi"), 2.0) self.assertEqual(self.lmp.extract_global("boxyhi"), 2.0)
self.assertEqual(self.lmp.extract_global("boxzlo"), -3.0) self.assertEqual(self.lmp.extract_global("boxzlo"), -3.0)
self.assertEqual(self.lmp.extract_global("boxzhi"), 3.0) self.assertEqual(self.lmp.extract_global("boxzhi"), 3.0)
self.assertEqual(self.lmp.extract_global("boxlo"), [-1.0, -2.0, -3.0])
self.assertEqual(self.lmp.extract_global("boxhi"), [1.0, 2.0, 3.0])
self.assertEqual(self.lmp.extract_global("sublo"), [-1.0, -2.0, -3.0])
self.assertEqual(self.lmp.extract_global("subhi"), [1.0, 2.0, 3.0])
self.assertEqual(self.lmp.extract_global("periodicity"), [1,1,1])
# only valid for triclinic box
self.lmp.command("change_box all triclinic")
self.assertEqual(self.lmp.extract_global("sublo_lambda"), [0.0, 0.0, 0.0])
self.assertEqual(self.lmp.extract_global("subhi_lambda"), [1.0, 1.0, 1.0])
############################## ##############################
if __name__ == "__main__": if __name__ == "__main__":