Merge pull request #639 from rbberger/python_interface_improvements

Python interface improvements
This commit is contained in:
Steve Plimpton
2017-09-06 08:51:27 -06:00
committed by GitHub
2 changed files with 81 additions and 0 deletions

View File

@ -32,6 +32,13 @@ import select
import re
import sys
def get_ctypes_int(size):
if size == 4:
return c_int32
elif size == 8:
return c_int64
return c_int
class MPIAbortException(Exception):
def __init__(self, message):
self.message = message
@ -162,6 +169,14 @@ class lammps(object):
pythonapi.PyCObject_AsVoidPtr.argtypes = [py_object]
self.lmp = c_void_p(pythonapi.PyCObject_AsVoidPtr(ptr))
# optional numpy support (lazy loading)
self._numpy = None
# set default types
self.c_bigint = get_ctypes_int(self.extract_setting("bigint"))
self.c_tagint = get_ctypes_int(self.extract_setting("tagint"))
self.c_imageint = get_ctypes_int(self.extract_setting("imageint"))
def __del__(self):
if self.lmp and self.opened:
self.lib.lammps_close(self.lmp)
@ -236,6 +251,48 @@ class lammps(object):
ptr = self.lib.lammps_extract_atom(self.lmp,name)
return ptr
# extract lammps type byte sizes
def extract_setting(self, name):
if name: name = name.encode()
self.lib.lammps_extract_atom.restype = c_int
return int(self.lib.lammps_extract_setting(self.lmp,name))
@property
def numpy(self):
if not self._numpy:
import numpy as np
class LammpsNumpyWrapper:
def __init__(self, lmp):
self.lmp = lmp
def extract_atom_iarray(self, name, nelem, dim=1):
if dim == 1:
tmp = self.lmp.extract_atom(name, 0)
ptr = cast(tmp, POINTER(c_int * nelem))
else:
tmp = self.lmp.extract_atom(name, 1)
ptr = cast(tmp[0], POINTER(c_int * nelem * dim))
a = np.frombuffer(ptr.contents, dtype=np.intc)
a.shape = (nelem, dim)
return a
def extract_atom_darray(self, name, nelem, dim=1):
if dim == 1:
tmp = self.lmp.extract_atom(name, 2)
ptr = cast(tmp, POINTER(c_double * nelem))
else:
tmp = self.lmp.extract_atom(name, 3)
ptr = cast(tmp[0], POINTER(c_double * nelem * dim))
a = np.frombuffer(ptr.contents)
a.shape = (nelem, dim)
return a
self._numpy = LammpsNumpyWrapper(self)
return self._numpy
# extract compute info
def extract_compute(self,id,style,type):

View File

@ -37,6 +37,7 @@
#include "comm.h"
#include "memory.h"
#include "error.h"
#include "force.h"
using namespace LAMMPS_NS;
@ -370,6 +371,7 @@ void *lammps_extract_global(void *ptr, char *name)
if (strcmp(name,"nlocal") == 0) return (void *) &lmp->atom->nlocal;
if (strcmp(name,"nghost") == 0) return (void *) &lmp->atom->nghost;
if (strcmp(name,"nmax") == 0) return (void *) &lmp->atom->nmax;
if (strcmp(name,"ntypes") == 0) return (void *) &lmp->atom->ntypes;
if (strcmp(name,"ntimestep") == 0) return (void *) &lmp->update->ntimestep;
if (strcmp(name,"units") == 0) return (void *) lmp->update->unit_style;
@ -384,6 +386,28 @@ void *lammps_extract_global(void *ptr, char *name)
if (strcmp(name,"atime") == 0) return (void *) &lmp->update->atime;
if (strcmp(name,"atimestep") == 0) return (void *) &lmp->update->atimestep;
// global constants defined by units
if (strcmp(name,"boltz") == 0) return (void *) &lmp->force->boltz;
if (strcmp(name,"hplanck") == 0) return (void *) &lmp->force->hplanck;
if (strcmp(name,"mvv2e") == 0) return (void *) &lmp->force->mvv2e;
if (strcmp(name,"ftm2v") == 0) return (void *) &lmp->force->ftm2v;
if (strcmp(name,"mv2d") == 0) return (void *) &lmp->force->mv2d;
if (strcmp(name,"nktv2p") == 0) return (void *) &lmp->force->nktv2p;
if (strcmp(name,"qqr2e") == 0) return (void *) &lmp->force->qqr2e;
if (strcmp(name,"qe2f") == 0) return (void *) &lmp->force->qe2f;
if (strcmp(name,"vxmu2f") == 0) return (void *) &lmp->force->vxmu2f;
if (strcmp(name,"xxt2kmu") == 0) return (void *) &lmp->force->xxt2kmu;
if (strcmp(name,"dielectric") == 0) return (void *) &lmp->force->dielectric;
if (strcmp(name,"qqrd2e") == 0) return (void *) &lmp->force->qqrd2e;
if (strcmp(name,"e_mass") == 0) return (void *) &lmp->force->e_mass;
if (strcmp(name,"hhmrr2e") == 0) return (void *) &lmp->force->hhmrr2e;
if (strcmp(name,"mvh2r") == 0) return (void *) &lmp->force->mvh2r;
if (strcmp(name,"angstrom") == 0) return (void *) &lmp->force->angstrom;
if (strcmp(name,"femtosecond") == 0) return (void *) &lmp->force->femtosecond;
if (strcmp(name,"qelectron") == 0) return (void *) &lmp->force->qelectron;
return NULL;
}