Merge pull request #3814 from rohskopf/jax
JAX ML-IAP Unified connection & examples
This commit is contained in:
87
examples/mliap/jax/README.md
Normal file
87
examples/mliap/jax/README.md
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# Running JAX from LAMMPS
|
||||||
|
|
||||||
|
### Getting started
|
||||||
|
|
||||||
|
First make a Python environment with dependencies:
|
||||||
|
|
||||||
|
conda create --name jax python=3.10
|
||||||
|
conda activate jax
|
||||||
|
# Upgrade pip
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
# Install JAX:
|
||||||
|
python -m pip install --upgrade "jax[cpu]"
|
||||||
|
# Install other dependencies:
|
||||||
|
python -m pip install numpy scipy torch scikit-learn virtualenv psutil tabulate mpi4py Cython
|
||||||
|
|
||||||
|
Install LAMMPS:
|
||||||
|
|
||||||
|
cd /path/to/lammps
|
||||||
|
mkdir build-jax; cd build-jax
|
||||||
|
cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \
|
||||||
|
-DBUILD_SHARED_LIBS=yes \
|
||||||
|
-DMLIAP_ENABLE_PYTHON=yes \
|
||||||
|
-DPKG_PYTHON=yes \
|
||||||
|
-DPKG_ML-SNAP=yes \
|
||||||
|
-DPKG_ML-IAP=yes \
|
||||||
|
-DPYTHON_EXECUTABLE:FILEPATH=`which python`
|
||||||
|
make -j4
|
||||||
|
make install-python
|
||||||
|
|
||||||
|
### Kokkos install
|
||||||
|
|
||||||
|
Use same Python dependencies as above, with some extra changes:
|
||||||
|
|
||||||
|
1. Make sure you install cupy properly! E.g.
|
||||||
|
|
||||||
|
python -m pip install cupy-cuda12x
|
||||||
|
|
||||||
|
2. Install JAX for GPU/CUDA:
|
||||||
|
|
||||||
|
python -m pip install --trusted-host storage.googleapis.com --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||||
|
|
||||||
|
3. Install cudNN: https://developer.nvidia.com/cudnn
|
||||||
|
|
||||||
|
Install LAMMPS. Take care to change `Kokkos_ARCH_*` flag:
|
||||||
|
|
||||||
|
cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \
|
||||||
|
-DBUILD_SHARED_LIBS=yes \
|
||||||
|
-DPKG_PYTHON=yes \
|
||||||
|
-DPKG_ML-SNAP=yes \
|
||||||
|
-DPKG_ML-IAP=yes \
|
||||||
|
-DMLIAP_ENABLE_PYTHON=yes \
|
||||||
|
-DPKG_KOKKOS=yes \
|
||||||
|
-DKokkos_ARCH_TURING75=yes \
|
||||||
|
-DKokkos_ENABLE_CUDA=yes \
|
||||||
|
-DKokkos_ENABLE_OPENMP=yes \
|
||||||
|
-DCMAKE_CXX_COMPILER=${HOME}/lammps/lib/kokkos/bin/nvcc_wrapper \
|
||||||
|
-DPYTHON_EXECUTABLE:FILEPATH=`which python`
|
||||||
|
make -j
|
||||||
|
make install-python
|
||||||
|
|
||||||
|
Run example:
|
||||||
|
|
||||||
|
mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run
|
||||||
|
|
||||||
|
### Deploying JAX models on CPU
|
||||||
|
|
||||||
|
Use `deploy_script.py`, which will wrap model with `write_unified_jax`.
|
||||||
|
|
||||||
|
python deploy_script.py
|
||||||
|
|
||||||
|
This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
|
||||||
|
|
||||||
|
Run LAMMPS with the model:
|
||||||
|
|
||||||
|
mpirun -np P lmp -in in.run
|
||||||
|
|
||||||
|
### Deploying JAX models in Kokkos
|
||||||
|
|
||||||
|
Use `deploy_script_kokkos.py`, which will wrap model with `write_unified_jax_kokkos`.
|
||||||
|
|
||||||
|
python deploy_script_kokkos.py
|
||||||
|
|
||||||
|
This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
|
||||||
|
|
||||||
|
Run LAMMPS with the model:
|
||||||
|
|
||||||
|
mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run
|
||||||
11
examples/mliap/jax/deploy_script.py
Normal file
11
examples/mliap/jax/deploy_script.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import lammps
|
||||||
|
import lammps.mliap
|
||||||
|
|
||||||
|
#from lammps.mliap.mliap_unified_lj import MLIAPUnifiedLJ
|
||||||
|
from mliap_unified_jax import MLIAPUnifiedJAX
|
||||||
|
|
||||||
|
def create_pickle():
|
||||||
|
unified = MLIAPUnifiedJAX(["Ar"])
|
||||||
|
unified.pickle('mliap_unified_jax_Ar.pkl')
|
||||||
|
|
||||||
|
create_pickle()
|
||||||
37
examples/mliap/jax/in.run
Normal file
37
examples/mliap/jax/in.run
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# 3d Lennard-Jones melt
|
||||||
|
|
||||||
|
units lj
|
||||||
|
atom_style atomic
|
||||||
|
|
||||||
|
lattice fcc 0.8442
|
||||||
|
region box block 0 10 0 10 0 10
|
||||||
|
create_box 1 box
|
||||||
|
create_atoms 1 box
|
||||||
|
mass 1 1.0
|
||||||
|
|
||||||
|
velocity all create 3.0 87287 loop geom
|
||||||
|
|
||||||
|
pair_style mliap unified mliap_unified_jax_Ar.pkl 0
|
||||||
|
pair_coeff * * Ar
|
||||||
|
|
||||||
|
neighbor 0.3 bin
|
||||||
|
neigh_modify every 20 delay 0 check no
|
||||||
|
|
||||||
|
fix 1 all nve
|
||||||
|
|
||||||
|
#dump id all atom 50 dump.melt
|
||||||
|
|
||||||
|
#dump 2 all image 25 image.*.jpg type type &
|
||||||
|
# axes yes 0.8 0.02 view 60 -30
|
||||||
|
#dump_modify 2 pad 3
|
||||||
|
|
||||||
|
#dump 3 all movie 1 movie.mpg type type &
|
||||||
|
# axes yes 0.8 0.02 view 60 -30
|
||||||
|
#dump_modify 3 pad 3
|
||||||
|
|
||||||
|
#dump 4 all custom 1 forces.xyz fx fy fz
|
||||||
|
|
||||||
|
dump 1 all xyz 10 dump.xyz
|
||||||
|
|
||||||
|
thermo 1
|
||||||
|
run 250
|
||||||
BIN
examples/mliap/jax/mliap_jax.pkl
Normal file
BIN
examples/mliap/jax/mliap_jax.pkl
Normal file
Binary file not shown.
61
examples/mliap/jax/mliap_unified_jax.py
Normal file
61
examples/mliap/jax/mliap_unified_jax.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import jit
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Required else get `jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory`
|
||||||
|
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
|
||||||
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
|
||||||
|
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def lj_potential(epsilon, sigma, rij):
|
||||||
|
def _tot_e(rij):
|
||||||
|
"""A differentiable fn for total energy."""
|
||||||
|
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
|
||||||
|
r6inv = r2inv * r2inv * r2inv
|
||||||
|
|
||||||
|
lj1 = 4.0 * epsilon * sigma**12
|
||||||
|
lj2 = 4.0 * epsilon * sigma**6
|
||||||
|
|
||||||
|
eij = r6inv * (lj1 * r6inv - lj2)
|
||||||
|
return 0.5 * jnp.sum(eij), eij
|
||||||
|
# Compute _tot_e and its derivative.
|
||||||
|
(_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij)
|
||||||
|
return eij, fij
|
||||||
|
|
||||||
|
|
||||||
|
class MLIAPUnifiedJAX(MLIAPUnified):
|
||||||
|
"""Test implementation for MLIAPUnified."""
|
||||||
|
|
||||||
|
epsilon: float
|
||||||
|
sigma: float
|
||||||
|
|
||||||
|
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
|
||||||
|
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
|
||||||
|
super().__init__(None, element_types, 1, 3, rcutfac)
|
||||||
|
# Mimicking the LJ pair-style:
|
||||||
|
# pair_style lj/cut 2.5
|
||||||
|
# pair_coeff * * 1 1
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.sigma = sigma
|
||||||
|
|
||||||
|
def compute_gradients(self, data):
|
||||||
|
"""Test compute_gradients."""
|
||||||
|
|
||||||
|
def compute_descriptors(self, data):
|
||||||
|
"""Test compute_descriptors."""
|
||||||
|
|
||||||
|
def compute_forces(self, data):
|
||||||
|
"""Test compute_forces."""
|
||||||
|
|
||||||
|
# NOTE: Use data.rij_max with JAX.
|
||||||
|
rij = data.rij_max
|
||||||
|
|
||||||
|
eij, fij = lj_potential(self.epsilon, self.sigma, rij)
|
||||||
|
|
||||||
|
data.update_pair_energy(np.array(eij, dtype=np.float64))
|
||||||
|
data.update_pair_forces(np.array(fij, dtype=np.float64))
|
||||||
BIN
examples/mliap/jax/mliap_unified_jax_Ar.pkl
Normal file
BIN
examples/mliap/jax/mliap_unified_jax_Ar.pkl
Normal file
Binary file not shown.
69
examples/mliap/jax/mliap_unified_jax_kokkos.py
Normal file
69
examples/mliap/jax/mliap_unified_jax_kokkos.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
import jax.dlpack
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import jit
|
||||||
|
from functools import partial
|
||||||
|
import cupy
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Required else get `jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory`
|
||||||
|
# Does not fix GPU problem with larger num. atoms.
|
||||||
|
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
|
||||||
|
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
|
||||||
|
#os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def lj_potential(epsilon, sigma, rij):
|
||||||
|
# A pure function we can differentiate:
|
||||||
|
def _tot_e(rij):
|
||||||
|
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
|
||||||
|
r6inv = r2inv * r2inv * r2inv
|
||||||
|
|
||||||
|
lj1 = 4.0 * epsilon * sigma**12
|
||||||
|
lj2 = 4.0 * epsilon * sigma**6
|
||||||
|
|
||||||
|
eij = r6inv * (lj1 * r6inv - lj2)
|
||||||
|
return 0.5 * jnp.sum(eij), eij
|
||||||
|
# Construct a function computing _tot_e and its derivative
|
||||||
|
(_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij)
|
||||||
|
return eij, fij
|
||||||
|
|
||||||
|
|
||||||
|
class MLIAPUnifiedJAXKokkos(MLIAPUnified):
|
||||||
|
"""JAX wrapper for MLIAPUnified."""
|
||||||
|
|
||||||
|
epsilon: float
|
||||||
|
sigma: float
|
||||||
|
|
||||||
|
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
|
||||||
|
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
|
||||||
|
super().__init__(None, element_types, 1, 3, rcutfac)
|
||||||
|
# Mimicking the LJ pair-style:
|
||||||
|
# pair_style lj/cut 2.5
|
||||||
|
# pair_coeff * * 1 1
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.sigma = sigma
|
||||||
|
|
||||||
|
def compute_gradients(self, data):
|
||||||
|
"""Test compute_gradients."""
|
||||||
|
|
||||||
|
def compute_descriptors(self, data):
|
||||||
|
"""Test compute_descriptors."""
|
||||||
|
|
||||||
|
def compute_forces(self, data):
|
||||||
|
"""Test compute_forces."""
|
||||||
|
|
||||||
|
# NOTE: Use data.rij_max with JAX.
|
||||||
|
# dlpack requires cudnn:
|
||||||
|
rij = jax.dlpack.from_dlpack(data.rij_max.toDlpack())
|
||||||
|
eij, fij = lj_potential(self.epsilon, self.sigma, rij)
|
||||||
|
|
||||||
|
# Convert back to cupy.
|
||||||
|
eij = cupy.from_dlpack(jax.dlpack.to_dlpack(eij)).astype(np.float64)
|
||||||
|
fij = cupy.from_dlpack(jax.dlpack.to_dlpack(fij)).astype(np.float64)
|
||||||
|
|
||||||
|
# Send to LAMMPS.
|
||||||
|
data.update_pair_energy(eij)
|
||||||
|
data.update_pair_forces(fij)
|
||||||
87
examples/mliap/jax/write_unified.py
Normal file
87
examples/mliap/jax/write_unified.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
interface for creating LAMMPS MLIAP Unified models.
|
||||||
|
"""
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||||
|
#from deploy_script import MyModel
|
||||||
|
|
||||||
|
class MLIAPInterface(MLIAPUnified):
|
||||||
|
"""
|
||||||
|
Class for creating ML-IAP Unified model based on hippynn graphs.
|
||||||
|
"""
|
||||||
|
def __init__(self, model, element_types, cutoff=4.5, ndescriptors=1):
|
||||||
|
"""
|
||||||
|
:param model: class defining the model
|
||||||
|
:param element_types: list of atomic symbols corresponding to element types
|
||||||
|
:param ndescriptors: the number of descriptors to report to LAMMPS
|
||||||
|
:param model_device: the device to send torch data to (cpu or cuda)
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.element_types = element_types
|
||||||
|
self.ndescriptors = ndescriptors
|
||||||
|
#self.model_device = model_device
|
||||||
|
|
||||||
|
|
||||||
|
# Build the calculator
|
||||||
|
# TODO: Make this cutoff depend on model cutoff, ideally from deployed model itself but could
|
||||||
|
# be part of deploy step.
|
||||||
|
#rc = 4.5
|
||||||
|
self.rcutfac = 0.5*cutoff # Actual cutoff will be 2*rc
|
||||||
|
#print(self.model.nparams)
|
||||||
|
self.nparams = 10
|
||||||
|
#self.rcutfac, self.species_set, self.graph = setup_LAMMPS()
|
||||||
|
#self.nparams = sum(p.nelement() for p in self.graph.parameters())
|
||||||
|
#self.graph.to(torch.float64)
|
||||||
|
|
||||||
|
def compute_descriptors(self, data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_gradients(self, data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_forces(self, data):
|
||||||
|
#print(">>>>> hey!")
|
||||||
|
#elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal)
|
||||||
|
|
||||||
|
"""
|
||||||
|
elems = self.as_tensor(data.elems).type(torch.int64) + 1
|
||||||
|
#z_vals = self.species_set[elems+1]
|
||||||
|
pair_i = self.as_tensor(data.pair_i).type(torch.int64)
|
||||||
|
pair_j = self.as_tensor(data.pair_j).type(torch.int64)
|
||||||
|
rij = self.as_tensor(data.rij).type(torch.float64).requires_grad_(True)
|
||||||
|
nlocal = self.as_tensor(data.nlistatoms)
|
||||||
|
"""
|
||||||
|
|
||||||
|
rij = data.rij
|
||||||
|
|
||||||
|
#(total_energy, fij) = self.network(rij, None, None, None, nlocal, elems, pair_i, pair_j, "cpu", dtype=torch.float64, mode="lammps")
|
||||||
|
|
||||||
|
test = self.model(rij)
|
||||||
|
|
||||||
|
#data.update_pair_forces(fij)
|
||||||
|
#data.energy = total_energy.item()
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_LAMMPS(energy):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param energy: energy node for lammps interface
|
||||||
|
:return: graph for computing from lammps MLIAP unified inputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model = TheModelClass(*args, **kwargs)
|
||||||
|
|
||||||
|
save_state_dict = torch.load("Ta_Pytorch.pt")
|
||||||
|
model.load_state_dict(save_state_dict["model_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
|
#model.load_state_dict(torch.load(PATH))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
#model.eval()
|
||||||
|
return model
|
||||||
@ -346,6 +346,16 @@ cdef class MLIAPDataPy:
|
|||||||
return None
|
return None
|
||||||
return create_array(self.data.dev, self.data.rij, [self.npairs,3],False)
|
return create_array(self.data.dev, self.data.rij, [self.npairs,3],False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rij_max(self):
|
||||||
|
if self.data.rij is NULL:
|
||||||
|
return None
|
||||||
|
return create_array(self.data.dev, self.data.rij, [self.nneigh_max,3], False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nneigh_max(self):
|
||||||
|
return self.data.nneigh_max
|
||||||
|
|
||||||
@write_only_property
|
@write_only_property
|
||||||
def graddesc(self, value):
|
def graddesc(self, value):
|
||||||
if self.data.graddesc is NULL:
|
if self.data.graddesc is NULL:
|
||||||
|
|||||||
@ -371,6 +371,25 @@ void LAMMPS_NS::update_pair_forces(MLIAPDataKokkosDevice *data, double *fij)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ----------------------------------------------------------------------
|
||||||
|
set energy for i indexed atoms
|
||||||
|
---------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
void LAMMPS_NS::update_atom_energy(MLIAPDataKokkosDevice *data, double *ei)
|
||||||
|
{
|
||||||
|
auto d_eatoms = data->eatoms;
|
||||||
|
const auto nlistatoms = data->nlistatoms;
|
||||||
|
|
||||||
|
Kokkos::parallel_reduce(nlistatoms, KOKKOS_LAMBDA(int i, double &local_sum){
|
||||||
|
double e = ei[i];
|
||||||
|
// must not count any contribution where i is not a local atom
|
||||||
|
if (i < nlistatoms) {
|
||||||
|
d_eatoms[i] = e;
|
||||||
|
local_sum += e;
|
||||||
|
}
|
||||||
|
},*data->energy);
|
||||||
|
}
|
||||||
|
|
||||||
namespace LAMMPS_NS {
|
namespace LAMMPS_NS {
|
||||||
template class MLIAPDummyModelKokkos<LMPDeviceType>;
|
template class MLIAPDummyModelKokkos<LMPDeviceType>;
|
||||||
template class MLIAPDummyDescriptorKokkos<LMPDeviceType>;
|
template class MLIAPDummyDescriptorKokkos<LMPDeviceType>;
|
||||||
|
|||||||
@ -60,6 +60,7 @@ template <class DeviceType>
|
|||||||
MLIAPBuildUnifiedKokkos_t<DeviceType> build_unified(char *, MLIAPDataKokkos<DeviceType> *, LAMMPS *, char * = NULL);
|
MLIAPBuildUnifiedKokkos_t<DeviceType> build_unified(char *, MLIAPDataKokkos<DeviceType> *, LAMMPS *, char * = NULL);
|
||||||
void update_pair_energy(MLIAPDataKokkosDevice *, double *);
|
void update_pair_energy(MLIAPDataKokkosDevice *, double *);
|
||||||
void update_pair_forces(MLIAPDataKokkosDevice *, double *);
|
void update_pair_forces(MLIAPDataKokkosDevice *, double *);
|
||||||
|
void update_atom_energy(MLIAPDataKokkosDevice *, double *);
|
||||||
|
|
||||||
} // namespace LAMMPS_NS
|
} // namespace LAMMPS_NS
|
||||||
|
|
||||||
|
|||||||
@ -281,6 +281,16 @@ cdef class MLIAPDataPy:
|
|||||||
return None
|
return None
|
||||||
return np.asarray(<double[:self.npairs, :3]> &self.data.rij[0][0])
|
return np.asarray(<double[:self.npairs, :3]> &self.data.rij[0][0])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rij_max(self):
|
||||||
|
if self.data.rij is NULL:
|
||||||
|
return None
|
||||||
|
return np.asarray(<double[:self.nneigh_max, :3]> &self.data.rij[0][0])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nneigh_max(self):
|
||||||
|
return self.data.nneigh_max
|
||||||
|
|
||||||
@write_only_property
|
@write_only_property
|
||||||
def graddesc(self, value):
|
def graddesc(self, value):
|
||||||
if self.data.graddesc is NULL:
|
if self.data.graddesc is NULL:
|
||||||
@ -357,6 +367,7 @@ cdef public object mliap_unified_connect(char *fname, MLIAPDummyModel * model,
|
|||||||
unified_int.descriptor = descriptor
|
unified_int.descriptor = descriptor
|
||||||
|
|
||||||
unified.interface = unified_int
|
unified.interface = unified_int
|
||||||
|
#print(unified_int)
|
||||||
|
|
||||||
if unified.ndescriptors is None:
|
if unified.ndescriptors is None:
|
||||||
raise ValueError("no descriptors set")
|
raise ValueError("no descriptors set")
|
||||||
|
|||||||
Reference in New Issue
Block a user