diff --git a/python/lammps.py b/python/lammps.py index c895037c00..62ed272d4b 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -46,17 +46,15 @@ class MPIAbortException(Exception): def __str__(self): return repr(self.message) - class lammps(object): # detect if Python is using version of mpi4py that can pass a communicator - has_mpi4py_v2 = False + has_mpi4py = False try: from mpi4py import MPI from mpi4py import __version__ as mpi4py_version - if mpi4py_version.split('.')[0] == '2': - has_mpi4py_v2 = True + if mpi4py_version.split('.')[0] in ['2','3']: has_mpi4py = True except: pass @@ -111,7 +109,9 @@ class lammps(object): # need to adjust for type of MPI communicator object # allow for int (like MPICH) or void* (like OpenMPI) - if lammps.has_mpi4py_v2 and comm != None: + if comm: + if not lammps.has_mpi4py: + raise Exception('Python mpi4py version is not 2 or 3') if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int): MPI_Comm = c_int else: