diff --git a/python/lammps.py b/python/lammps.py index 01c20bc3b5..5d8b10873a 100644 --- a/python/lammps.py +++ b/python/lammps.py @@ -18,8 +18,19 @@ from ctypes import * from os.path import dirname,abspath,join from inspect import getsourcefile + class lammps: - def __init__(self,name="",cmdargs=None,ptr=None): + # detect, if we use a version of mpi4py that can pass a communicator + has_mpi4py_v2 = False + try: + from mpi4py import MPI + from mpi4py import __version__ as mpi4py_version + if mpi4py_version.split('.')[0] == '2': + has_mpi4py_v2 = True + except: + pass + + def __init__(self,name="",cmdargs=None,ptr=None,comm=None): # determine module location modpath = dirname(abspath(getsourcefile(lambda:0))) @@ -37,6 +48,7 @@ class lammps: # if no ptr provided, create an instance of LAMMPS # don't know how to pass an MPI communicator from PyPar + # but we can pass an MPI communicator from mpi4py v2.0.0 and later # no_mpi call lets LAMMPS use MPI_COMM_WORLD # cargs = array of C strings from args # if ptr, then are embedding Python in LAMMPS input script @@ -44,18 +56,47 @@ class lammps: # just convert it to ctypes ptr and store in self.lmp if not ptr: - self.opened = 1 - if cmdargs: - cmdargs.insert(0,"lammps.py") - narg = len(cmdargs) - cargs = (c_char_p*narg)(*cmdargs) + # with mpi4py we can pass communicators into the LAMMPS object but + # we need to adjust type for the MPI communicator object depending + # on whether it is an int (like MPICH) or a void* (like OpenMPI) + if lammps.has_mpi4py_v2 and comm != None: + if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int): + MPI_Comm = c_int + else: + MPI_Comm = c_void_p + + narg = 0 + cargs = 0 + if cmdargs: + cmdargs.insert(0,"lammps.py") + narg = len(cmdargs) + cargs = (c_char_p*narg)(*cmdargs) + self.lib.lammps_open.argtypes = [c_int, c_char_p*narg, \ + MPI_Comm, c_void_p()] + else: + self.lib.lammps_open.argtypes = [c_int, c_int, \ + MPI_Comm, c_void_p()] + + self.lib.lammps_open.restype = None + self.opened = 1 self.lmp = c_void_p() - self.lib.lammps_open_no_mpi(narg,cargs,byref(self.lmp)) + comm_ptr = lammps.MPI._addressof(comm) + comm_val = MPI_Comm.from_address(comm_ptr) + self.lib.lammps_open(narg,cargs,comm_val,byref(self.lmp)) + else: - self.lmp = c_void_p() - self.lib.lammps_open_no_mpi(0,None,byref(self.lmp)) - # could use just this if LAMMPS lib interface supported it - # self.lmp = self.lib.lammps_open_no_mpi(0,None) + self.opened = 1 + if cmdargs: + cmdargs.insert(0,"lammps.py") + narg = len(cmdargs) + cargs = (c_char_p*narg)(*cmdargs) + self.lmp = c_void_p() + self.lib.lammps_open_no_mpi(narg,cargs,byref(self.lmp)) + else: + self.lmp = c_void_p() + self.lib.lammps_open_no_mpi(0,None,byref(self.lmp)) + # could use just this if LAMMPS lib interface supported it + # self.lmp = self.lib.lammps_open_no_mpi(0,None) else: self.opened = 0 # magic to convert ptr to ctypes ptr