update lammps python wrapper to support passing in a communicator from mpi4py
This commit is contained in:
@ -18,8 +18,19 @@ from ctypes import *
|
|||||||
from os.path import dirname,abspath,join
|
from os.path import dirname,abspath,join
|
||||||
from inspect import getsourcefile
|
from inspect import getsourcefile
|
||||||
|
|
||||||
|
|
||||||
class lammps:
|
class lammps:
|
||||||
def __init__(self,name="",cmdargs=None,ptr=None):
|
# detect, if we use a version of mpi4py that can pass a communicator
|
||||||
|
has_mpi4py_v2 = False
|
||||||
|
try:
|
||||||
|
from mpi4py import MPI
|
||||||
|
from mpi4py import __version__ as mpi4py_version
|
||||||
|
if mpi4py_version.split('.')[0] == '2':
|
||||||
|
has_mpi4py_v2 = True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self,name="",cmdargs=None,ptr=None,comm=None):
|
||||||
|
|
||||||
# determine module location
|
# determine module location
|
||||||
modpath = dirname(abspath(getsourcefile(lambda:0)))
|
modpath = dirname(abspath(getsourcefile(lambda:0)))
|
||||||
@ -37,6 +48,7 @@ class lammps:
|
|||||||
|
|
||||||
# if no ptr provided, create an instance of LAMMPS
|
# if no ptr provided, create an instance of LAMMPS
|
||||||
# don't know how to pass an MPI communicator from PyPar
|
# don't know how to pass an MPI communicator from PyPar
|
||||||
|
# but we can pass an MPI communicator from mpi4py v2.0.0 and later
|
||||||
# no_mpi call lets LAMMPS use MPI_COMM_WORLD
|
# no_mpi call lets LAMMPS use MPI_COMM_WORLD
|
||||||
# cargs = array of C strings from args
|
# cargs = array of C strings from args
|
||||||
# if ptr, then are embedding Python in LAMMPS input script
|
# if ptr, then are embedding Python in LAMMPS input script
|
||||||
@ -44,6 +56,35 @@ class lammps:
|
|||||||
# just convert it to ctypes ptr and store in self.lmp
|
# just convert it to ctypes ptr and store in self.lmp
|
||||||
|
|
||||||
if not ptr:
|
if not ptr:
|
||||||
|
# with mpi4py we can pass communicators into the LAMMPS object but
|
||||||
|
# we need to adjust type for the MPI communicator object depending
|
||||||
|
# on whether it is an int (like MPICH) or a void* (like OpenMPI)
|
||||||
|
if lammps.has_mpi4py_v2 and comm != None:
|
||||||
|
if lammps.MPI._sizeof(lammps.MPI.Comm) == sizeof(c_int):
|
||||||
|
MPI_Comm = c_int
|
||||||
|
else:
|
||||||
|
MPI_Comm = c_void_p
|
||||||
|
|
||||||
|
narg = 0
|
||||||
|
cargs = 0
|
||||||
|
if cmdargs:
|
||||||
|
cmdargs.insert(0,"lammps.py")
|
||||||
|
narg = len(cmdargs)
|
||||||
|
cargs = (c_char_p*narg)(*cmdargs)
|
||||||
|
self.lib.lammps_open.argtypes = [c_int, c_char_p*narg, \
|
||||||
|
MPI_Comm, c_void_p()]
|
||||||
|
else:
|
||||||
|
self.lib.lammps_open.argtypes = [c_int, c_int, \
|
||||||
|
MPI_Comm, c_void_p()]
|
||||||
|
|
||||||
|
self.lib.lammps_open.restype = None
|
||||||
|
self.opened = 1
|
||||||
|
self.lmp = c_void_p()
|
||||||
|
comm_ptr = lammps.MPI._addressof(comm)
|
||||||
|
comm_val = MPI_Comm.from_address(comm_ptr)
|
||||||
|
self.lib.lammps_open(narg,cargs,comm_val,byref(self.lmp))
|
||||||
|
|
||||||
|
else:
|
||||||
self.opened = 1
|
self.opened = 1
|
||||||
if cmdargs:
|
if cmdargs:
|
||||||
cmdargs.insert(0,"lammps.py")
|
cmdargs.insert(0,"lammps.py")
|
||||||
|
|||||||
Reference in New Issue
Block a user