update lammps python wrapper to support passing in a communicator from mpi4py

This commit is contained in:
Axel Kohlmeyer
2015-12-07 20:59:20 -05:00
parent d21bbb3efb
commit cacf29bc56

View File

@ -18,8 +18,19 @@ from ctypes import *
from os.path import dirname,abspath,join from os.path import dirname,abspath,join
from inspect import getsourcefile from inspect import getsourcefile
class lammps: class lammps:
def __init__(self,name="",cmdargs=None,ptr=None): # detect, if we use a version of mpi4py that can pass a communicator
has_mpi4py_v2 = False
try:
from mpi4py import MPI
from mpi4py import __version__ as mpi4py_version
if mpi4py_version.split('.')[0] == '2':
has_mpi4py_v2 = True
except:
pass
def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
# determine module location # determine module location
modpath = dirname(abspath(getsourcefile(lambda:0))) modpath = dirname(abspath(getsourcefile(lambda:0)))
@ -37,6 +48,7 @@ class lammps:
# if no ptr provided, create an instance of LAMMPS # if no ptr provided, create an instance of LAMMPS
# don't know how to pass an MPI communicator from PyPar # don't know how to pass an MPI communicator from PyPar
# but we can pass an MPI communicator from mpi4py v2.0.0 and later
# no_mpi call lets LAMMPS use MPI_COMM_WORLD # no_mpi call lets LAMMPS use MPI_COMM_WORLD
# cargs = array of C strings from args # cargs = array of C strings from args
# if ptr, then are embedding Python in LAMMPS input script # if ptr, then are embedding Python in LAMMPS input script
@ -44,6 +56,35 @@ class lammps:
# just convert it to ctypes ptr and store in self.lmp # just convert it to ctypes ptr and store in self.lmp
if not ptr: if not ptr:
# with mpi4py we can pass communicators into the LAMMPS object but
# we need to adjust type for the MPI communicator object depending
# on whether it is an int (like MPICH) or a void* (like OpenMPI)
if lammps.has_mpi4py_v2 and comm != None:
if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int):
MPI_Comm = c_int
else:
MPI_Comm = c_void_p
narg = 0
cargs = 0
if cmdargs:
cmdargs.insert(0,"lammps.py")
narg = len(cmdargs)
cargs = (c_char_p*narg)(*cmdargs)
self.lib.lammps_open.argtypes = [c_int, c_char_p*narg, \
MPI_Comm, c_void_p()]
else:
self.lib.lammps_open.argtypes = [c_int, c_int, \
MPI_Comm, c_void_p()]
self.lib.lammps_open.restype = None
self.opened = 1
self.lmp = c_void_p()
comm_ptr = lammps.MPI._addressof(comm)
comm_val = MPI_Comm.from_address(comm_ptr)
self.lib.lammps_open(narg,cargs,comm_val,byref(self.lmp))
else:
self.opened = 1 self.opened = 1
if cmdargs: if cmdargs:
cmdargs.insert(0,"lammps.py") cmdargs.insert(0,"lammps.py")