Get the PyTorch interface for ML-IAP working with Kokkos. This uses cuPy, and a simple example is provided.

This commit is contained in:
Matt Bettencourt
2022-11-14 17:49:00 +01:00
parent 07fe2fa29d
commit d47acfc0c4
14 changed files with 742 additions and 38 deletions

View File

@ -31,5 +31,8 @@ if not pylib.Py_IsInitialized():
"in undefined behavior.")
else:
from .loader import load_model, load_unified, activate_mliappy
try:
from .loader import load_model_kokkos, activate_mliappy_kokkos
except:
pass
del sysconfig, ctypes, library, pylib

View File

@ -56,6 +56,7 @@ class DynamicLoader(importlib.abc.Loader):
def activate_mliappy(lmp):
try:
print("activate_mliappy")
library = lmp.lib
module_names = ["mliap_model_python_couple", "mliap_unified_couple"]
api_version = library.lammps_python_api_version()
@ -72,8 +73,28 @@ def activate_mliappy(lmp):
except Exception as ee:
raise ImportError("Could not load ML-IAP python coupling module.") from ee
def activate_mliappy_kokkos(lmp):
    """Activate the Kokkos ML-IAP python coupling for a LAMMPS instance.

    Builds an import loader bound to the LAMMPS shared library and registers
    the ``mliap_model_python_couple_kokkos`` module in ``sys.modules`` so a
    subsequent ``load_model_kokkos`` call can import it.

    Parameters
    ----------
    lmp : lammps
        A LAMMPS python wrapper; its ``lib`` attribute exposes the loaded
        shared library.

    Raises
    ------
    ImportError
        If the coupling module cannot be created for any reason (the
        original cause is chained).
    """
    try:
        library = lmp.lib
        module_names = ["mliap_model_python_couple_kokkos"]
        api_version = library.lammps_python_api_version()
        for module_name in module_names:
            # Build the import machinery bound to the LAMMPS library.
            loader = DynamicLoader(module_name, library, api_version)
            spec = importlib.util.spec_from_loader(module_name, loader)
            # Perform the import and register the module globally.
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
    except Exception as ee:
        raise ImportError("Could not load ML-IAP python coupling module.") from ee
def load_model(model):
try:
print("load_model")
import mliap_model_python_couple
except ImportError as ie:
raise ImportError("ML-IAP python module must be activated before loading\n"
@ -81,6 +102,17 @@ def load_model(model):
) from ie
mliap_model_python_couple.load_from_python(model)
def load_model_kokkos(model):
    """Hand a python ML-IAP model to the Kokkos coupling module.

    Parameters
    ----------
    model : object
        The python model object to install via
        ``mliap_model_python_couple_kokkos.load_from_python``.

    Raises
    ------
    ImportError
        If the Kokkos coupling module has not been activated first
        (via ``lammps.mliap.activate_mliappy_kokkos``).
    """
    try:
        import mliap_model_python_couple_kokkos
    except ImportError as ie:
        # NOTE: direct users to the *_kokkos* activation routine; the
        # non-Kokkos activate_mliappy does not register this module.
        raise ImportError("ML-IAP python module must be activated before loading\n"
                          "the pair style. Call lammps.mliap.activate_mliappy_kokkos(lmp)."
                          ) from ie
    mliap_model_python_couple_kokkos.load_from_python(model)
def load_unified(model):
try:
import mliap_unified_couple
@ -89,3 +121,4 @@ def load_unified(model):
"the pair style. Call lammps.mliap.activate_mliappy(lmp)."
) from ie
mliap_unified_couple.load_from_python(model)

View File

@ -89,7 +89,6 @@ class TorchWrapper(torch.nn.Module):
"""
super().__init__()
self.model = model
self.device = device
self.dtype = dtype
@ -105,7 +104,7 @@ class TorchWrapper(torch.nn.Module):
self.n_descriptors = n_descriptors
self.n_elements = n_elements
def forward(self, elems, descriptors, beta, energy):
def forward(self, elems, descriptors, beta, energy,use_gpu_data=False):
"""
Takes element types and descriptors calculated via lammps and
calculates the per atom energies and forces.
@ -130,20 +129,28 @@ class TorchWrapper(torch.nn.Module):
-------
None
"""
descriptors = torch.from_numpy(descriptors).to(dtype=self.dtype, device=self.device).requires_grad_(True)
elems = torch.from_numpy(elems).to(dtype=torch.long, device=self.device) - 1
descriptors = torch.as_tensor(descriptors,dtype=self.dtype, device=self.device).requires_grad_(True)
elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device)
elems=elems-1
with torch.autograd.enable_grad():
energy_nn = self.model(descriptors, elems)
if energy_nn.ndim > 1:
energy_nn = energy_nn.flatten()
if (use_gpu_data):
energy_nn = torch.as_tensor(energy,dtype=self.dtype, device=self.device)
energy_nn[:] = self.model(descriptors, elems).flatten()
else:
energy_nn = self.model(descriptors, elems).flatten()
energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64)
#if energy_nn.ndim > 1:
# energy_nn = energy_nn.flatten()
if (use_gpu_data):
beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=self.device)
beta_nn[:] = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
else:
beta_nn = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
beta[:] = beta_nn.detach().cpu().numpy().astype(np.float64)
energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64)
beta[:] = beta_nn.detach().cpu().numpy().astype(np.float64)
elems=elems+1
class IgnoreElems(torch.nn.Module):