Update fix_python_invoke

This commit is contained in:
Richard Berger
2021-04-06 14:47:20 -04:00
parent da5bd578ad
commit 5ee24c5b89
2 changed files with 30 additions and 23 deletions

View File

@ -21,6 +21,7 @@
#include "error.h" #include "error.h"
#include "lmppython.h" #include "lmppython.h"
#include "python_compat.h" #include "python_compat.h"
#include "python_utils.h"
#include "update.h" #include "update.h"
#include <cstring> #include <cstring>
@ -51,12 +52,12 @@ FixPythonInvoke::FixPythonInvoke(LAMMPS *lmp, int narg, char **arg) :
} }
// get Python function // get Python function
PyGILState_STATE gstate = PyGILState_Ensure(); PyUtils::GIL lock;
PyObject *pyMain = PyImport_AddModule("__main__"); PyObject *pyMain = PyImport_AddModule("__main__");
if (!pyMain) { if (!pyMain) {
PyGILState_Release(gstate); PyUtils::Print_Errors();
error->all(FLERR,"Could not initialize embedded Python"); error->all(FLERR,"Could not initialize embedded Python");
} }
@ -64,11 +65,19 @@ FixPythonInvoke::FixPythonInvoke(LAMMPS *lmp, int narg, char **arg) :
pFunc = PyObject_GetAttrString(pyMain, fname); pFunc = PyObject_GetAttrString(pyMain, fname);
if (!pFunc) { if (!pFunc) {
PyGILState_Release(gstate); PyUtils::Print_Errors();
error->all(FLERR,"Could not find Python function"); error->all(FLERR,"Could not find Python function");
} }
PyGILState_Release(gstate); lmpPtr = PY_VOID_POINTER(lmp);
}
/* ---------------------------------------------------------------------- */
FixPythonInvoke::~FixPythonInvoke()
{
PyUtils::GIL lock;
Py_CLEAR(lmpPtr);
} }
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
@ -82,36 +91,32 @@ int FixPythonInvoke::setmask()
void FixPythonInvoke::end_of_step() void FixPythonInvoke::end_of_step()
{ {
PyGILState_STATE gstate = PyGILState_Ensure(); PyUtils::GIL lock;
PyObject *ptr = PY_VOID_POINTER(lmp); PyObject * result = PyObject_CallFunction((PyObject*)pFunc, "O", (PyObject*)lmpPtr);
PyObject *arglist = Py_BuildValue("(O)", ptr);
PyObject *result = PyEval_CallObject((PyObject*)pFunc, arglist); if (!result) {
Py_DECREF(arglist); PyUtils::Print_Errors();
if (!result && (comm->me == 0)) PyErr_Print();
PyGILState_Release(gstate);
if (!result)
error->all(FLERR,"Fix python/invoke end_of_step() method failed"); error->all(FLERR,"Fix python/invoke end_of_step() method failed");
} }
Py_CLEAR(result);
}
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
void FixPythonInvoke::post_force(int vflag) void FixPythonInvoke::post_force(int vflag)
{ {
if (update->ntimestep % nevery != 0) return; if (update->ntimestep % nevery != 0) return;
PyGILState_STATE gstate = PyGILState_Ensure(); PyUtils::GIL lock;
PyObject *ptr = PY_VOID_POINTER(lmp); PyObject * result = PyObject_CallFunction((PyObject*)pFunc, "Oi", (PyObject*)lmpPtr, vflag);
PyObject *arglist = Py_BuildValue("(Oi)", ptr, vflag);
PyObject *result = PyEval_CallObject((PyObject*)pFunc, arglist); if (!result) {
Py_DECREF(arglist); PyUtils::Print_Errors();
if (!result && (comm->me == 0)) PyErr_Print();
PyGILState_Release(gstate);
if (!result)
error->all(FLERR,"Fix python/invoke post_force() method failed"); error->all(FLERR,"Fix python/invoke post_force() method failed");
} }
Py_CLEAR(result);
}

View File

@ -21,6 +21,7 @@ FixStyle(python,FixPythonInvoke)
#ifndef LMP_FIX_PYTHON_INVOKE_H #ifndef LMP_FIX_PYTHON_INVOKE_H
#define LMP_FIX_PYTHON_INVOKE_H #define LMP_FIX_PYTHON_INVOKE_H
#include "fix.h" #include "fix.h"
namespace LAMMPS_NS { namespace LAMMPS_NS {
@ -28,12 +29,13 @@ namespace LAMMPS_NS {
class FixPythonInvoke : public Fix { class FixPythonInvoke : public Fix {
public: public:
FixPythonInvoke(class LAMMPS *, int, char **); FixPythonInvoke(class LAMMPS *, int, char **);
virtual ~FixPythonInvoke() {} virtual ~FixPythonInvoke();
int setmask(); int setmask();
virtual void end_of_step(); virtual void end_of_step();
virtual void post_force(int); virtual void post_force(int);
private: private:
void * lmpPtr;
void * pFunc; void * pFunc;
int selected_callback; int selected_callback;
}; };