diff --git a/python/lammps/core.py b/python/lammps/core.py index c4d5bff591..8610e85811 100644 --- a/python/lammps/core.py +++ b/python/lammps/core.py @@ -416,9 +416,16 @@ class lammps(object): # shut-down LAMMPS instance def __del__(self): - if self.lmp and self.opened: - self.lib.lammps_close(self.lmp) - self.opened = 0 + self.close() + + # ------------------------------------------------------------------------- + # context manager implementation + + def __enter__(self): + return self + + def __exit__(self, ex_type, ex_value, ex_traceback): + self.close() # ------------------------------------------------------------------------- @@ -445,7 +452,8 @@ class lammps(object): This is a wrapper around the :cpp:func:`lammps_close` function of the C-library interface. """ - if self.opened: self.lib.lammps_close(self.lmp) + if self.lmp and self.opened: + self.lib.lammps_close(self.lmp) self.lmp = None self.opened = 0 @@ -454,9 +462,7 @@ class lammps(object): def finalize(self): """Shut down the MPI communication through the library interface by calling :cpp:func:`lammps_finalize`. """ - if self.opened: self.lib.lammps_close(self.lmp) - self.lmp = None - self.opened = 0 + self.close() self.lib.lammps_finalize() # ------------------------------------------------------------------------- diff --git a/unittest/python/python-open.py b/unittest/python/python-open.py index ad4cc24a24..328745ded0 100644 --- a/unittest/python/python-open.py +++ b/unittest/python/python-open.py @@ -50,6 +50,16 @@ class PythonOpen(unittest.TestCase): self.assertIsNot(lmp.lmp,None) self.assertEqual(lmp.opened,1) + def testContextManager(self): + """Automatically clean up LAMMPS instance""" + with lammps(name=self.machine) as lmp: + self.assertIsNot(lmp.lmp,None) + self.assertEqual(lmp.opened,1) + self.assertEqual(has_mpi and has_mpi4py,lmp.has_mpi4py) + self.assertEqual(has_mpi,lmp.has_mpi_support) + self.assertIsNone(lmp.lmp,None) + self.assertEqual(lmp.opened,0) + @unittest.skipIf(not (has_mpi and has_mpi4py),"Skipping MPI test since LAMMPS is not parallel or mpi4py is not found") def testWithMPI(self): from mpi4py import MPI