Add Kokkos example
This commit is contained in:
@ -54,11 +54,9 @@ Run example:
|
||||
|
||||
mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run
|
||||
|
||||
### Wrapping JAX code
|
||||
### Deploying JAX models on CPU
|
||||
|
||||
Take inspiration from the `FitSNAP` ML-IAP wrapper: https://github.com/rohskopf/FitSNAP/blob/mliap-unified/fitsnap3lib/tools/write_unified.py
|
||||
|
||||
First define JAX model in `deploy_script.py`, which will wrap model with `write_unified`.
|
||||
Use `deploy_script.py`, which will wrap model with `write_unified_jax`.
|
||||
|
||||
python deploy_script.py
|
||||
|
||||
@ -66,4 +64,16 @@ This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
|
||||
|
||||
Run LAMMPS with the model:
|
||||
|
||||
mpirun -np P lmp -in in.run
|
||||
mpirun -np P lmp -in in.run
|
||||
|
||||
### Deploying JAX models in Kokkos
|
||||
|
||||
Use `deploy_script_kokkos.py`, which will wrap model with `write_unified_jax_kokkos`.
|
||||
|
||||
python deploy_script_kokkos.py
|
||||
|
||||
This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
|
||||
|
||||
Run LAMMPS with the model:
|
||||
|
||||
mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run
|
||||
@ -31,5 +31,7 @@ fix 1 all nve
|
||||
|
||||
#dump 4 all custom 1 forces.xyz fx fy fz
|
||||
|
||||
thermo 50
|
||||
run 250
|
||||
dump 1 all xyz 10 dump.xyz
|
||||
|
||||
thermo 1
|
||||
run 250
|
||||
|
||||
62
examples/mliap/jax/mliap_unified_jax_kokkos.py
Normal file
62
examples/mliap/jax/mliap_unified_jax_kokkos.py
Normal file
@ -0,0 +1,62 @@
|
||||
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.dlpack
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from functools import partial
|
||||
import cupy
|
||||
|
||||
@jax.jit
|
||||
def lj_potential(epsilon, sigma, rij):
|
||||
# A pure function we can differentiate:
|
||||
def _tot_e(rij):
|
||||
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
|
||||
r6inv = r2inv * r2inv * r2inv
|
||||
|
||||
lj1 = 4.0 * epsilon * sigma**12
|
||||
lj2 = 4.0 * epsilon * sigma**6
|
||||
|
||||
eij = r6inv * (lj1 * r6inv - lj2)
|
||||
return 0.5 * jnp.sum(eij), eij
|
||||
# Construct a function computing _tot_e and its derivative
|
||||
(_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij)
|
||||
return eij, fij
|
||||
|
||||
|
||||
class MLIAPUnifiedJAXKokkos(MLIAPUnified):
|
||||
"""JAX wrapper for MLIAPUnified."""
|
||||
|
||||
epsilon: float
|
||||
sigma: float
|
||||
|
||||
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
|
||||
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
|
||||
super().__init__(None, element_types, 1, 3, rcutfac)
|
||||
# Mimicking the LJ pair-style:
|
||||
# pair_style lj/cut 2.5
|
||||
# pair_coeff * * 1 1
|
||||
self.epsilon = epsilon
|
||||
self.sigma = sigma
|
||||
|
||||
def compute_gradients(self, data):
|
||||
"""Test compute_gradients."""
|
||||
|
||||
def compute_descriptors(self, data):
|
||||
"""Test compute_descriptors."""
|
||||
|
||||
def compute_forces(self, data):
|
||||
"""Test compute_forces."""
|
||||
|
||||
# NOTE: Use data.rij_max with JAX.
|
||||
# dlpack requires cudnn:
|
||||
rij = jax.dlpack.from_dlpack(data.rij_max.toDlpack())
|
||||
eij, fij = lj_potential(self.epsilon, self.sigma, rij)
|
||||
|
||||
# Convert back to cupy.
|
||||
eij = cupy.from_dlpack(jax.dlpack.to_dlpack(eij)).astype(np.float64)
|
||||
fij = cupy.from_dlpack(jax.dlpack.to_dlpack(fij)).astype(np.float64)
|
||||
|
||||
# Send to LAMMPS.
|
||||
data.update_pair_energy(eij)
|
||||
data.update_pair_forces(fij)
|
||||
Reference in New Issue
Block a user