Update pair_python

This commit is contained in:
Richard Berger
2021-04-06 14:50:08 -04:00
parent 7e9fa25121
commit 0aa9aa96f6
2 changed files with 70 additions and 168 deletions

View File

@ -24,9 +24,10 @@
#include "memory.h" #include "memory.h"
#include "neigh_list.h" #include "neigh_list.h"
#include "python_compat.h" #include "python_compat.h"
#include "python_utils.h"
#include "update.h" #include "update.h"
#include <cstring> #include <string>
#include <Python.h> // IWYU pragma: export #include <Python.h> // IWYU pragma: export
using namespace LAMMPS_NS; using namespace LAMMPS_NS;
@ -50,7 +51,7 @@ PairPython::PairPython(LAMMPS *lmp) : Pair(lmp) {
// add current directory to PYTHONPATH // add current directory to PYTHONPATH
PyGILState_STATE gstate = PyGILState_Ensure(); PyUtils::GIL lock;
PyObject *py_path = PySys_GetObject((char *)"path"); PyObject *py_path = PySys_GetObject((char *)"path");
PyList_Append(py_path, PY_STRING_FROM_STRING(".")); PyList_Append(py_path, PY_STRING_FROM_STRING("."));
@ -61,14 +62,14 @@ PairPython::PairPython(LAMMPS *lmp) : Pair(lmp) {
if (potentials_path != nullptr) { if (potentials_path != nullptr) {
PyList_Append(py_path, PY_STRING_FROM_STRING(potentials_path)); PyList_Append(py_path, PY_STRING_FROM_STRING(potentials_path));
} }
PyGILState_Release(gstate);
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
PairPython::~PairPython() PairPython::~PairPython()
{ {
if (py_potential) Py_DECREF((PyObject*) py_potential); PyUtils::GIL lock;
Py_CLEAR(py_potential);
delete[] skip_types; delete[] skip_types;
if (allocated) { if (allocated) {
@ -103,41 +104,31 @@ void PairPython::compute(int eflag, int vflag)
// prepare access to compute_force and compute_energy functions // prepare access to compute_force and compute_energy functions
PyGILState_STATE gstate = PyGILState_Ensure(); PyUtils::GIL lock;
PyObject *py_pair_instance = (PyObject *) py_potential; PyObject *py_pair_instance = (PyObject *) py_potential;
PyObject *py_compute_force = PyObject_GetAttrString(py_pair_instance,"compute_force"); PyObject *py_compute_force = PyObject_GetAttrString(py_pair_instance,"compute_force");
if (!py_compute_force) { if (!py_compute_force) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not find 'compute_force' method'"); error->all(FLERR,"Could not find 'compute_force' method'");
} }
if (!PyCallable_Check(py_compute_force)) { if (!PyCallable_Check(py_compute_force)) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Python 'compute_force' is not callable"); error->all(FLERR,"Python 'compute_force' is not callable");
} }
PyObject *py_compute_energy = PyObject_GetAttrString(py_pair_instance,"compute_energy"); PyObject *py_compute_energy = PyObject_GetAttrString(py_pair_instance,"compute_energy");
if (!py_compute_energy) { if (!py_compute_energy) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not find 'compute_energy' method'"); error->all(FLERR,"Could not find 'compute_energy' method'");
} }
if (!PyCallable_Check(py_compute_energy)) { if (!PyCallable_Check(py_compute_energy)) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Python 'compute_energy' is not callable"); error->all(FLERR,"Python 'compute_energy' is not callable");
} }
PyObject *py_compute_args = PyTuple_New(3); PyObject *py_compute_args = PyTuple_New(3);
if (!py_compute_args) { if (!py_compute_args) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not create tuple for 'compute' function arguments"); error->all(FLERR,"Could not create tuple for 'compute' function arguments");
} }
@ -179,13 +170,11 @@ void PairPython::compute(int eflag, int vflag)
PyTuple_SetItem(py_compute_args,0,py_rsq); PyTuple_SetItem(py_compute_args,0,py_rsq);
py_value = PyObject_CallObject(py_compute_force,py_compute_args); py_value = PyObject_CallObject(py_compute_force,py_compute_args);
if (!py_value) { if (!py_value) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Calling 'compute_force' function failed"); error->all(FLERR,"Calling 'compute_force' function failed");
} }
fpair = factor_lj*PyFloat_AsDouble(py_value); fpair = factor_lj*PyFloat_AsDouble(py_value);
Py_DECREF(py_value); Py_CLEAR(py_value);
f[i][0] += delx*fpair; f[i][0] += delx*fpair;
f[i][1] += dely*fpair; f[i][1] += dely*fpair;
@ -198,8 +187,12 @@ void PairPython::compute(int eflag, int vflag)
if (eflag) { if (eflag) {
py_value = PyObject_CallObject(py_compute_energy,py_compute_args); py_value = PyObject_CallObject(py_compute_energy,py_compute_args);
if (!py_value) {
PyUtils::Print_Errors();
error->all(FLERR,"Calling 'compute_energy' function failed");
}
evdwl = factor_lj*PyFloat_AsDouble(py_value); evdwl = factor_lj*PyFloat_AsDouble(py_value);
Py_DECREF(py_value); Py_CLEAR(py_value);
} else evdwl = 0.0; } else evdwl = 0.0;
if (evflag) ev_tally(i,j,nlocal,newton_pair, if (evflag) ev_tally(i,j,nlocal,newton_pair,
@ -207,8 +200,7 @@ void PairPython::compute(int eflag, int vflag)
} }
} }
} }
Py_DECREF(py_compute_args); Py_CLEAR(py_compute_args);
PyGILState_Release(gstate);
if (vflag_fdotr) virial_fdotr_compute(); if (vflag_fdotr) virial_fdotr_compute();
} }
@ -261,112 +253,49 @@ void PairPython::coeff(int narg, char **arg)
error->all(FLERR,"Incorrect args for pair coefficients"); error->all(FLERR,"Incorrect args for pair coefficients");
// check if python potential file exists and source it // check if python potential file exists and source it
char * full_cls_name = arg[2]; std::string full_cls_name = arg[2];
char * lastpos = strrchr(full_cls_name, '.'); size_t lastpos = full_cls_name.rfind(".");
if (lastpos == nullptr) { if (lastpos == std::string::npos) {
error->all(FLERR,"Python pair style requires fully qualified class name"); error->all(FLERR,"Python pair style requires fully qualified class name");
} }
size_t module_name_length = strlen(full_cls_name) - strlen(lastpos); std::string module_name = full_cls_name.substr(0, lastpos);
size_t cls_name_length = strlen(lastpos)-1; std::string cls_name = full_cls_name.substr(lastpos+1);
char * module_name = new char[module_name_length+1]; PyUtils::GIL lock;
char * cls_name = new char[cls_name_length+1];
strncpy(module_name, full_cls_name, module_name_length);
module_name[module_name_length] = 0;
strcpy(cls_name, lastpos+1); PyObject * pModule = PyImport_ImportModule(module_name.c_str());
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject * pModule = PyImport_ImportModule(module_name);
if (!pModule) { if (!pModule) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Loading python pair style module failure"); error->all(FLERR,"Loading python pair style module failure");
} }
// create LAMMPS atom type to potential file type mapping in python class // create LAMMPS atom type to potential file type mapping in python class
// by calling 'lammps_pair_style.map_coeff(name,type)' // by calling 'lammps_pair_style.map_coeff(name,type)'
PyObject *py_pair_type = PyObject_GetAttrString(pModule, cls_name); PyObject *py_pair_type = PyObject_GetAttrString(pModule, cls_name.c_str());
if (!py_pair_type) { if (!py_pair_type) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not find pair style class in module'"); error->all(FLERR,"Could not find pair style class in module'");
} }
delete [] module_name;
delete [] cls_name;
PyObject * py_pair_instance = PyObject_CallObject(py_pair_type, nullptr); PyObject * py_pair_instance = PyObject_CallObject(py_pair_type, nullptr);
if (!py_pair_instance) { if (!py_pair_instance) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not instantiate instance of pair style class'"); error->all(FLERR,"Could not instantiate instance of pair style class'");
} }
py_potential = (void *) py_pair_instance; py_potential = (void *) py_pair_instance;
PyObject *py_check_units = PyObject_GetAttrString(py_pair_instance,"check_units"); PyObject *py_value = PyObject_CallMethod(py_pair_instance, "check_units", "s", update->unit_style);
if (!py_check_units) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not find 'check_units' method'");
}
if (!PyCallable_Check(py_check_units)) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Python 'check_units' is not callable");
}
PyObject *py_units_args = PyTuple_New(1);
if (!py_units_args) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not create tuple for 'check_units' function arguments");
}
PyObject *py_name = PY_STRING_FROM_STRING(update->unit_style);
PyTuple_SetItem(py_units_args,0,py_name);
PyObject *py_value = PyObject_CallObject(py_check_units,py_units_args);
if (!py_value) { if (!py_value) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Calling 'check_units' function failed"); error->all(FLERR,"Calling 'check_units' function failed");
} }
Py_DECREF(py_units_args); Py_CLEAR(py_value);
PyObject *py_map_coeff = PyObject_GetAttrString(py_pair_instance,"map_coeff");
if (!py_map_coeff) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not find 'map_coeff' method'");
}
if (!PyCallable_Check(py_map_coeff)) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Python 'map_coeff' is not callable");
}
PyObject *py_map_args = PyTuple_New(2);
if (!py_map_args) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not create tuple for 'map_coeff' function arguments");
}
delete[] skip_types; delete[] skip_types;
skip_types = new int[ntypes+1]; skip_types = new int[ntypes+1];
skip_types[0] = 1; skip_types[0] = 1;
@ -375,25 +304,20 @@ void PairPython::coeff(int narg, char **arg)
skip_types[i] = 1; skip_types[i] = 1;
continue; continue;
} else skip_types[i] = 0; } else skip_types[i] = 0;
PyObject *py_type = PY_INT_FROM_LONG(i); const int type = i;
py_name = PY_STRING_FROM_STRING(arg[2+i]); const char * name = arg[2+i];
PyTuple_SetItem(py_map_args,0,py_name); py_value = PyObject_CallMethod(py_pair_instance, "map_coeff", "si", name, type);
PyTuple_SetItem(py_map_args,1,py_type);
py_value = PyObject_CallObject(py_map_coeff,py_map_args);
if (!py_value) { if (!py_value) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Calling 'map_coeff' function failed"); error->all(FLERR,"Calling 'map_coeff' function failed");
} }
Py_CLEAR(py_value);
for (int j = i; j <= ntypes ; j++) { for (int j = i; j <= ntypes ; j++) {
setflag[i][j] = 1; setflag[i][j] = 1;
cutsq[i][j] = cut_global*cut_global; cutsq[i][j] = cut_global*cut_global;
} }
} }
Py_DECREF(py_map_args);
PyGILState_Release(gstate);
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
@ -417,76 +341,53 @@ double PairPython::single(int /* i */, int /* j */, int itype, int jtype,
// prepare access to compute_force and compute_energy functions // prepare access to compute_force and compute_energy functions
PyGILState_STATE gstate = PyGILState_Ensure(); PyUtils::GIL lock;
PyObject *py_pair_instance = (PyObject *) py_potential; PyObject *py_pair_instance = (PyObject *) py_potential;
PyObject *py_compute_force PyObject *py_compute_force = (PyObject *) get_member_function("compute_force");
= PyObject_GetAttrString(py_pair_instance,"compute_force"); PyObject *py_compute_energy = (PyObject *) get_member_function("compute_energy");
if (!py_compute_force) { PyObject *py_compute_args = Py_BuildValue("(dii)", rsq, itype, jtype);
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not find 'compute_force' method'");
}
if (!PyCallable_Check(py_compute_force)) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Python 'compute_force' is not callable");
}
PyObject *py_compute_energy
= PyObject_GetAttrString(py_pair_instance,"compute_energy");
if (!py_compute_energy) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not find 'compute_energy' method'");
}
if (!PyCallable_Check(py_compute_energy)) {
PyErr_Print();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Python 'compute_energy' is not callable");
}
PyObject *py_rsq, *py_itype, *py_jtype, *py_value;
PyObject *py_compute_args = PyTuple_New(3);
if (!py_compute_args) { if (!py_compute_args) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Could not create tuple for 'compute' function arguments"); error->all(FLERR,"Could not create tuple for 'compute' function arguments");
} }
py_itype = PY_INT_FROM_LONG(itype); PyObject * py_value = PyObject_CallObject(py_compute_force, py_compute_args);
PyTuple_SetItem(py_compute_args,1,py_itype);
py_jtype = PY_INT_FROM_LONG(jtype);
PyTuple_SetItem(py_compute_args,2,py_jtype);
py_rsq = PyFloat_FromDouble(rsq);
PyTuple_SetItem(py_compute_args,0,py_rsq);
py_value = PyObject_CallObject(py_compute_force,py_compute_args);
if (!py_value) { if (!py_value) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Calling 'compute_force' function failed"); error->all(FLERR,"Calling 'compute_force' function failed");
} }
fforce = factor_lj*PyFloat_AsDouble(py_value); fforce = factor_lj*PyFloat_AsDouble(py_value);
Py_DECREF(py_value); Py_CLEAR(py_value);
py_value = PyObject_CallObject(py_compute_energy,py_compute_args); py_value = PyObject_CallObject(py_compute_energy, py_compute_args);
if (!py_value) { if (!py_value) {
PyErr_Print(); PyUtils::Print_Errors();
PyErr_Clear();
PyGILState_Release(gstate);
error->all(FLERR,"Calling 'compute_energy' function failed"); error->all(FLERR,"Calling 'compute_energy' function failed");
} }
double evdwl = factor_lj*PyFloat_AsDouble(py_value); double evdwl = factor_lj*PyFloat_AsDouble(py_value);
Py_DECREF(py_value);
Py_DECREF(py_compute_args); Py_CLEAR(py_value);
PyGILState_Release(gstate); Py_CLEAR(py_compute_args);
return evdwl; return evdwl;
} }
/* ---------------------------------------------------------------------- */
void * PairPython::get_member_function(const char * name)
{
PyUtils::GIL lock;
PyObject *py_pair_instance = (PyObject *) py_potential;
PyObject * py_mfunc = PyObject_GetAttrString(py_pair_instance, name);
if (!py_mfunc) {
PyUtils::Print_Errors();
error->all(FLERR, fmt::format("Could not find '{}' method'", name));
}
if (!PyCallable_Check(py_mfunc)) {
PyUtils::Print_Errors();
error->all(FLERR, fmt::format("Python '{}' is not callable", name));
}
return py_mfunc;
}

View File

@ -50,6 +50,7 @@ class PairPython : public Pair {
int * skip_types; int * skip_types;
virtual void allocate(); virtual void allocate();
void * get_member_function(const char *);
}; };
} }