Add Kokkos example

This commit is contained in:
rohskopf
2023-05-26 10:48:58 -06:00
parent b7146c900f
commit 49b2c299a7
3 changed files with 81 additions and 7 deletions

View File

@ -54,11 +54,9 @@ Run example:
mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run
### Wrapping JAX code
### Deploying JAX models on CPU
Take inspiration from the `FitSNAP` ML-IAP wrapper: https://github.com/rohskopf/FitSNAP/blob/mliap-unified/fitsnap3lib/tools/write_unified.py
First define JAX model in `deploy_script.py`, which will wrap model with `write_unified`.
Use `deploy_script.py`, which will wrap model with `write_unified_jax`.
python deploy_script.py
@ -66,4 +64,16 @@ This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
Run LAMMPS with the model:
mpirun -np P lmp -in in.run
mpirun -np P lmp -in in.run
### Deploying JAX models in Kokkos
Use `deploy_script_kokkos.py`, which will wrap model with `write_unified_jax_kokkos`.
python deploy_script_kokkos.py
This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
Run LAMMPS with the model:
mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run

View File

@ -31,5 +31,7 @@ fix 1 all nve
#dump 4 all custom 1 forces.xyz fx fy fz
thermo 50
run 250
dump 1 all xyz 10 dump.xyz
thermo 1
run 250

View File

@ -0,0 +1,62 @@
from lammps.mliap.mliap_unified_abc import MLIAPUnified
import numpy as np
import jax
import jax.dlpack
import jax.numpy as jnp
from jax import jit
from functools import partial
import cupy
@jax.jit
def lj_potential(epsilon, sigma, rij):
# A pure function we can differentiate:
def _tot_e(rij):
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
r6inv = r2inv * r2inv * r2inv
lj1 = 4.0 * epsilon * sigma**12
lj2 = 4.0 * epsilon * sigma**6
eij = r6inv * (lj1 * r6inv - lj2)
return 0.5 * jnp.sum(eij), eij
# Construct a function computing _tot_e and its derivative
(_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij)
return eij, fij
class MLIAPUnifiedJAXKokkos(MLIAPUnified):
"""JAX wrapper for MLIAPUnified."""
epsilon: float
sigma: float
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
super().__init__(None, element_types, 1, 3, rcutfac)
# Mimicking the LJ pair-style:
# pair_style lj/cut 2.5
# pair_coeff * * 1 1
self.epsilon = epsilon
self.sigma = sigma
def compute_gradients(self, data):
"""Test compute_gradients."""
def compute_descriptors(self, data):
"""Test compute_descriptors."""
def compute_forces(self, data):
"""Test compute_forces."""
# NOTE: Use data.rij_max with JAX.
# dlpack requires cudnn:
rij = jax.dlpack.from_dlpack(data.rij_max.toDlpack())
eij, fij = lj_potential(self.epsilon, self.sigma, rij)
# Convert back to cupy.
eij = cupy.from_dlpack(jax.dlpack.to_dlpack(eij)).astype(np.float64)
fij = cupy.from_dlpack(jax.dlpack.to_dlpack(fij)).astype(np.float64)
# Send to LAMMPS.
data.update_pair_energy(eij)
data.update_pair_forces(fij)