PyLammps: alternative OutputCapture using tempfiles

This commit is contained in:
Richard Berger
2022-02-03 19:49:52 -05:00
parent f550460ecd
commit 050ce421e9

View File

@ -20,47 +20,47 @@
from __future__ import print_function from __future__ import print_function
import io
import os import os
import re import re
import select import sys
import tempfile
from collections import namedtuple from collections import namedtuple
from .core import lammps from .core import lammps
# -------------------------------------------------------------------------
class OutputCapture(object): class OutputCapture(object):
""" Utility class to capture LAMMPS library output """ """ Utility class to capture LAMMPS library output """
def __init__(self): def __init__(self):
self.stdout_pipe_read, self.stdout_pipe_write = os.pipe() self.stdout_fd = sys.stdout.fileno()
self.stdout_fd = 1 self.captured_output = ""
def __enter__(self): def __enter__(self):
self.stdout = os.dup(self.stdout_fd) self.tmpfile = tempfile.TemporaryFile(mode='w+b')
os.dup2(self.stdout_pipe_write, self.stdout_fd)
sys.stdout.flush()
# make copy of original stdout
self.stdout_orig = os.dup(self.stdout_fd)
# replace stdout and redirect to temp file
os.dup2(self.tmpfile.fileno(), self.stdout_fd)
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
os.dup2(self.stdout, self.stdout_fd) os.dup2(self.stdout_orig, self.stdout_fd)
os.close(self.stdout) os.close(self.stdout_orig)
os.close(self.stdout_pipe_read) self.tmpfile.close()
os.close(self.stdout_pipe_write)
# check if we have more to read from the pipe
def more_data(self, pipe):
r, _, _ = select.select([pipe], [], [], 0)
return bool(r)
# read the whole pipe
def read_pipe(self, pipe):
out = ""
while self.more_data(pipe):
out += os.read(pipe, 1024).decode()
return out
@property @property
def output(self): def output(self):
return self.read_pipe(self.stdout_pipe_read) sys.stdout.flush()
self.tmpfile.flush()
self.tmpfile.seek(0, io.SEEK_SET)
self.captured_output = self.tmpfile.read().decode('utf-8')
return self.captured_output
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@ -109,9 +109,9 @@ class AtomList(object):
""" """
if index not in self._loaded: if index not in self._loaded:
if self.dimensions == 2: if self.dimensions == 2:
atom = Atom2D(self._pylmp, index + 1) atom = Atom2D(self._pylmp, index)
else: else:
atom = Atom(self._pylmp, index + 1) atom = Atom(self._pylmp, index)
self._loaded[index] = atom self._loaded[index] = atom
return self._loaded[index] return self._loaded[index]
@ -612,7 +612,8 @@ class PyLammps(object):
:getter: Returns an object with properties storing the current system state :getter: Returns an object with properties storing the current system state
:type: namedtuple :type: namedtuple
""" """
output = self.info("system") output = self.lmp_info("system")
output = output[output.index("System information:")+1:]
d = self._parse_info_system(output) d = self._parse_info_system(output)
return namedtuple('System', d.keys())(*d.values()) return namedtuple('System', d.keys())(*d.values())
@ -624,7 +625,8 @@ class PyLammps(object):
:getter: Returns an object with properties storing the current communication state :getter: Returns an object with properties storing the current communication state
:type: namedtuple :type: namedtuple
""" """
output = self.info("communication") output = self.lmp_info("communication")
output = output[output.index("Communication information:")+1:]
d = self._parse_info_communication(output) d = self._parse_info_communication(output)
return namedtuple('Communication', d.keys())(*d.values()) return namedtuple('Communication', d.keys())(*d.values())
@ -636,7 +638,8 @@ class PyLammps(object):
:getter: Returns a list of computes that are currently active in this LAMMPS instance :getter: Returns a list of computes that are currently active in this LAMMPS instance
:type: list :type: list
""" """
output = self.info("computes") output = self.lmp_info("computes")
output = output[output.index("Compute information:")+1:]
return self._parse_element_list(output) return self._parse_element_list(output)
@property @property
@ -647,7 +650,8 @@ class PyLammps(object):
:getter: Returns a list of dumps that are currently active in this LAMMPS instance :getter: Returns a list of dumps that are currently active in this LAMMPS instance
:type: list :type: list
""" """
output = self.info("dumps") output = self.lmp_info("dumps")
output = output[output.index("Dump information:")+1:]
return self._parse_element_list(output) return self._parse_element_list(output)
@property @property
@ -658,7 +662,8 @@ class PyLammps(object):
:getter: Returns a list of fixes that are currently active in this LAMMPS instance :getter: Returns a list of fixes that are currently active in this LAMMPS instance
:type: list :type: list
""" """
output = self.info("fixes") output = self.lmp_info("fixes")
output = output[output.index("Fix information:")+1:]
return self._parse_element_list(output) return self._parse_element_list(output)
@property @property
@ -669,7 +674,8 @@ class PyLammps(object):
:getter: Returns a list of atom groups that are currently active in this LAMMPS instance :getter: Returns a list of atom groups that are currently active in this LAMMPS instance
:type: list :type: list
""" """
output = self.info("groups") output = self.lmp_info("groups")
output = output[output.index("Group information:")+1:]
return self._parse_groups(output) return self._parse_groups(output)
@property @property
@ -680,11 +686,12 @@ class PyLammps(object):
:getter: Returns a dictionary of all variables that are defined in this LAMMPS instance :getter: Returns a dictionary of all variables that are defined in this LAMMPS instance
:type: dict :type: dict
""" """
output = self.info("variables") output = self.lmp_info("variables")
vars = {} output = output[output.index("Variable information:")+1:]
variables = {}
for v in self._parse_element_list(output): for v in self._parse_element_list(output):
vars[v['name']] = Variable(self, v['name'], v['style'], v['def']) variables[v['name']] = Variable(self, v['name'], v['style'], v['def'])
return vars return variables
def eval(self, expr): def eval(self, expr):
""" """
@ -709,10 +716,9 @@ class PyLammps(object):
return [x.strip() for x in value.split('=')] return [x.strip() for x in value.split('=')]
def _parse_info_system(self, output): def _parse_info_system(self, output):
lines = output[5:-2]
system = {} system = {}
for line in lines: for line in output:
if line.startswith("Units"): if line.startswith("Units"):
system['units'] = self._get_pair(line)[1] system['units'] = self._get_pair(line)[1]
elif line.startswith("Atom style"): elif line.startswith("Atom style"):
@ -770,10 +776,9 @@ class PyLammps(object):
return system return system
def _parse_info_communication(self, output): def _parse_info_communication(self, output):
lines = output[5:-3]
comm = {} comm = {}
for line in lines: for line in output:
if line.startswith("MPI library"): if line.startswith("MPI library"):
comm['mpi_version'] = line.split(':')[1].strip() comm['mpi_version'] = line.split(':')[1].strip()
elif line.startswith("Comm style"): elif line.startswith("Comm style"):
@ -791,10 +796,10 @@ class PyLammps(object):
return comm return comm
def _parse_element_list(self, output): def _parse_element_list(self, output):
lines = output[5:-3]
elements = [] elements = []
for line in lines: for line in output:
if not line or (":" not in line): continue
element_info = self._split_values(line.split(':')[1].strip()) element_info = self._split_values(line.split(':')[1].strip())
element = {'name': element_info[0]} element = {'name': element_info[0]}
for key, value in [self._get_pair(x) for x in element_info[1:]]: for key, value in [self._get_pair(x) for x in element_info[1:]]:
@ -803,11 +808,10 @@ class PyLammps(object):
return elements return elements
def _parse_groups(self, output): def _parse_groups(self, output):
lines = output[5:-3]
groups = [] groups = []
group_pattern = re.compile(r"(?P<name>.+) \((?P<type>.+)\)") group_pattern = re.compile(r"(?P<name>.+) \((?P<type>.+)\)")
for line in lines: for line in output:
m = group_pattern.match(line.split(':')[1].strip()) m = group_pattern.match(line.split(':')[1].strip())
group = {'name': m.group('name'), 'type': m.group('type')} group = {'name': m.group('name'), 'type': m.group('type')}
groups.append(group) groups.append(group)
@ -830,6 +834,15 @@ class PyLammps(object):
'thermo_modify', 'thermo_style', 'timestep', 'undump', 'unfix', 'units', 'thermo_modify', 'thermo_style', 'timestep', 'undump', 'unfix', 'units',
'variable', 'velocity', 'write_restart'] + self.lmp.available_styles("command"))) 'variable', 'velocity', 'write_restart'] + self.lmp.available_styles("command")))
def lmp_info(self, s):
# skip anything before and after Info-Info-Info
# also skip timestamp line
output = self.__getattr__("info")(s)
indices = [index for index, line in enumerate(output) if line.startswith("Info-Info-Info-Info")]
start = indices[0]
end = indices[1]
return [line for line in output[start+2:end] if line]
def __getattr__(self, name): def __getattr__(self, name):
""" """
This method is where the Python 'magic' happens. If a method is not This method is where the Python 'magic' happens. If a method is not