Simplify access to system, comm, variable and atom data

This commit is contained in:
Richard Berger
2016-05-26 16:52:36 -04:00
parent 171e131878
commit e7262acc46

View File

@ -17,6 +17,7 @@ import sys, traceback, types
from ctypes import *
from os.path import dirname, abspath, join
from inspect import getsourcefile
from collections import namedtuple
import os
import select
import re
@ -297,31 +298,94 @@ class OutputCapture(object):
return self.read_pipe(self.stdout_pipe_read)
class LammpsVariable(object):
class Variable(object):
def __init__(self, lammps_wrapper_instance, name, style, definition):
self.lmp = lammps_wrapper_instance
self.name = name
self.style = style
self.definition = definition
self.definition = definition.split()
@property
def value(self):
return float(self.lmp.print('"${%s}"' % self.name))
value = self.lmp.print('"${%s}"' % self.name).strip()
try:
return float(value)
except ValueError:
return value
class AtomList(object):
def __init__(self, lammps_wrapper_instance):
self.lmp = lammps_wrapper_instance
self.natoms = self.lmp.system.natoms
def __getitem__(self, index):
return Atom(self.lmp, index+1)
class Atom(object):
def __init__(self, lammps_wrapper_instance, index):
self.lmp = lammps_wrapper_instance
self.index = index
@property
def id(self):
return int(self.lmp.eval("id[%d]" % self.index))
@property
def type(self):
return int(self.lmp.eval("type[%d]" % self.index))
@property
def mol(self):
return self.lmp.eval("mol[%d]" % self.index)
@property
def mass(self):
return self.lmp.eval("mass[%d]" % self.index)
@property
def position(self):
return (self.lmp.eval("x[%d]" % self.index),
self.lmp.eval("y[%d]" % self.index),
self.lmp.eval("z[%d]" % self.index))
@property
def velocity(self):
return (self.lmp.eval("vx[%d]" % self.index),
self.lmp.eval("vy[%d]" % self.index),
self.lmp.eval("vz[%d]" % self.index))
@property
def force(self):
return (self.lmp.eval("fx[%d]" % self.index),
self.lmp.eval("fy[%d]" % self.index),
self.lmp.eval("fz[%d]" % self.index))
@property
def charge(self):
return self.lmp.eval("q[%d]" % self.index)
class LammpsWrapper(object):
def __init__(self, lmp):
self.lmp = lmp
@property
def atoms(self):
return AtomList(self)
@property
def system(self):
output = self.info("system")
return self._parse_info_system(output)
d = self._parse_info_system(output)
return namedtuple('System', d.keys())(*d.values())
@property
def communication(self):
output = self.info("communication")
return self._parse_info_communication(output)
d = self._parse_info_communication(output)
return namedtuple('Communication', d.keys())(*d.values())
@property
def computes(self):
@ -348,11 +412,16 @@ class LammpsWrapper(object):
output = self.info("variables")
vars = {}
for v in self._parse_element_list(output):
vars[v['name']] = LammpsVariable(self, v['name'], v['style'], v['def'])
vars[v['name']] = Variable(self, v['name'], v['style'], v['def'])
return vars
def eval(self, expr):
return float(self.print('"$(%s)"' % expr))
value = self.print('"$(%s)"' % expr).strip()
try:
return float(value)
except ValueError:
return value
def _split_values(self, line):
return [x.strip() for x in line.split(',')]
@ -361,7 +430,7 @@ class LammpsWrapper(object):
return [x.strip() for x in value.split('=')]
def _parse_info_system(self, output):
lines = output.splitlines()[6:-2]
lines = output[6:-2]
system = {}
for line in lines:
@ -373,8 +442,8 @@ class LammpsWrapper(object):
system['atom_map'] = self._get_pair(line)[1]
elif line.startswith("Atoms"):
parts = self._split_values(line)
system['natoms'] = self._get_pair(parts[0])[1]
system['ntypes'] = self._get_pair(parts[1])[1]
system['natoms'] = int(self._get_pair(parts[0])[1])
system['ntypes'] = int(self._get_pair(parts[1])[1])
system['style'] = self._get_pair(parts[2])[1]
elif line.startswith("Kspace style"):
system['kspace_style'] = self._get_pair(line)[1]
@ -399,7 +468,7 @@ class LammpsWrapper(object):
return system
def _parse_info_communication(self, output):
lines = output.splitlines()[6:-3]
lines = output[6:-3]
comm = {}
for line in lines:
@ -420,7 +489,7 @@ class LammpsWrapper(object):
return comm
def _parse_element_list(self, output):
lines = output.splitlines()[6:-3]
lines = output[6:-3]
elements = []
for line in lines:
@ -432,7 +501,7 @@ class LammpsWrapper(object):
return elements
def _parse_groups(self, output):
lines = output.splitlines()[6:-3]
lines = output[6:-3]
groups = []
group_pattern = re.compile(r"(?P<name>.+) \((?P<type>.+)\)")
@ -449,7 +518,13 @@ class LammpsWrapper(object):
with OutputCapture() as capture:
self.lmp.command(' '.join(cmd_args))
output = capture.output
return output
lines = output.splitlines()
if len(lines) > 1:
return lines
elif len(lines) == 1:
return lines[0]
return None
return handler