Merge pull request #3664 from hoba87/develop
simplify execution of mliap pytorch example
This commit is contained in:
@ -94,8 +94,12 @@ lmp.commands_string(before_loading)
|
|||||||
|
|
||||||
# Define the model however you like. In this example
|
# Define the model however you like. In this example
|
||||||
# we load it from disk:
|
# we load it from disk:
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
model = torch.load('Ta06A.mliap.pytorch.model.pt')
|
torch_model = 'Ta06A.mliap.pytorch.model.pt'
|
||||||
|
if not os.path.exists(torch_model):
|
||||||
|
raise FileNotFoundError(f"Generate {torch_model} with convert_mliap_Ta06A.py")
|
||||||
|
model = torch.load(torch_model)
|
||||||
|
|
||||||
# Connect the PyTorch model to the mliap pair style.
|
# Connect the PyTorch model to the mliap pair style.
|
||||||
lammps.mliap.load_model(model)
|
lammps.mliap.load_model(model)
|
||||||
|
|||||||
@ -94,8 +94,12 @@ lmp.commands_string(before_loading)
|
|||||||
|
|
||||||
# Define the model however you like. In this example
|
# Define the model however you like. In this example
|
||||||
# we load it from disk:
|
# we load it from disk:
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
model = torch.load('Ta06A.mliap.pytorch.model.pt')
|
torch_model = 'Ta06A.mliap.pytorch.model.pt'
|
||||||
|
if not os.path.exists(torch_model):
|
||||||
|
raise FileNotFoundError(f"Generate {torch_model} with convert_mliap_Ta06A.py")
|
||||||
|
model = torch.load(torch_model)
|
||||||
|
|
||||||
# Connect the PyTorch model to the mliap pair style.
|
# Connect the PyTorch model to the mliap pair style.
|
||||||
lammps.mliap.load_model_kokkos(model)
|
lammps.mliap.load_model_kokkos(model)
|
||||||
|
|||||||
@ -7,18 +7,19 @@ import ctypes
|
|||||||
import platform
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
py_ver = sysconfig.get_config_vars('VERSION')[0]
|
py_ver = sysconfig.get_config_var('VERSION')
|
||||||
|
abi_flags = sysconfig.get_config_var('abiflags')
|
||||||
OS_name = platform.system()
|
OS_name = platform.system()
|
||||||
|
|
||||||
if OS_name == "Darwin":
|
if OS_name == "Darwin":
|
||||||
SHLIB_SUFFIX = '.dylib'
|
SHLIB_SUFFIX = '.dylib'
|
||||||
library = 'libpython' + py_ver + SHLIB_SUFFIX
|
library = 'libpython' + py_ver + abi_flags + SHLIB_SUFFIX
|
||||||
elif OS_name == "Windows":
|
elif OS_name == "Windows":
|
||||||
SHLIB_SUFFIX = '.dll'
|
SHLIB_SUFFIX = '.dll'
|
||||||
library = 'python' + py_ver + SHLIB_SUFFIX
|
library = 'python' + py_ver + abi_flags + SHLIB_SUFFIX
|
||||||
else:
|
else:
|
||||||
SHLIB_SUFFIX = '.so'
|
SHLIB_SUFFIX = '.so'
|
||||||
library = 'libpython' + py_ver + SHLIB_SUFFIX
|
library = 'libpython' + py_ver + abi_flags + SHLIB_SUFFIX
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pylib = ctypes.CDLL(library)
|
pylib = ctypes.CDLL(library)
|
||||||
|
|||||||
Reference in New Issue
Block a user