implement numpy wrapper for setting per-atom energy. virial not yet implemented

This commit is contained in:
Axel Kohlmeyer
2021-07-22 16:59:04 -04:00
parent 324ae3181b
commit bf8bde5b03
4 changed files with 60 additions and 10 deletions

View File

@ -1830,7 +1830,7 @@ class lammps(object):
# -------------------------------------------------------------------------
def fix_external_set_energy_peratom(self, fix_id, eatom):
"""Set the global energy contribution for a fix external instance with the given ID.
"""Set the per-atom energy contribution for a fix external instance with the given ID.
This is a wrapper around the :cpp:func:`lammps_fix_external_set_energy_peratom` function
of the C-library interface.
@ -1848,15 +1848,15 @@ class lammps(object):
# -------------------------------------------------------------------------
def fix_external_set_virial_peratom(self, fix_id, virial):
"""Set the global virial contribution for a fix external instance with the given ID.
def fix_external_set_virial_peratom(self, fix_id, vatom):
"""Set the per-atom virial contribution for a fix external instance with the given ID.
This is a wrapper around the :cpp:func:`lammps_fix_external_set_virial_peratom` function
of the C-library interface.
:param fix_id: Fix-ID of a fix external instance
:type: string
:param eng: list of natoms lists with 6 floating point numbers to be added by fix external
:param vatom: list of natoms lists with 6 floating point numbers to be added by fix external
:type: float
"""
@ -1864,14 +1864,14 @@ class lammps(object):
nlocal = self.extract_setting('nlocal')
vbuf = (c_double * 6)
vptr = POINTER(c_double)
cvirial = (vptr * nlocal)()
c_virial = (vptr * nlocal)()
for i in range(nlocal):
cvirial[i] = vbuf()
c_virial[i] = vbuf()
for j in range(6):
cvirial[i][j] = virial[i][j]
c_virial[i][j] = vatom[i][j]
with ExceptionCheck(self):
return self.lib.lammps_fix_external_set_virial_peratom(self.lmp, fix_id.encode(), cvirial)
return self.lib.lammps_fix_external_set_virial_peratom(self.lmp, fix_id.encode(), c_virial)
# -------------------------------------------------------------------------
def fix_external_set_vector_length(self, fix_id, length):