Initial example

This commit is contained in:
rohskopf
2023-05-20 13:53:22 -06:00
parent 4aee151b0e
commit 6977f71eb0
6 changed files with 204 additions and 0 deletions

View File

@ -0,0 +1,38 @@
# Running JAX from LAMMPS
### Getting started
First make a Python environment with dependencies:
conda create --name jax python=3.10
conda activate jax
# Upgrade pip
python -m pip install --upgrade pip
# Install JAX:
python -m pip install --upgrade "jax[cpu]"
# Install other dependencies:
python -m pip install numpy scipy torch scikit-learn virtualenv psutil tabulate mpi4py Cython
Install LAMMPS:
cd /path/to/lammps
mkdir build-jax; cd build-jax
cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \
-DBUILD_SHARED_LIBS=yes \
-DMLIAP_ENABLE_PYTHON=yes \
-DPKG_PYTHON=yes \
-DPKG_ML-SNAP=yes \
-DPKG_ML-IAP=yes \
-DPYTHON_EXECUTABLE:FILEPATH=`which python`
make -j4
make install-python
### Wrapping JAX code
Take inspiration from the `FitSNAP` ML-IAP wrapper: https://github.com/rohskopf/FitSNAP/blob/mliap-unified/fitsnap3lib/tools/write_unified.py
First define JAX model in `deploy_script.py`, which will wrap model with `write_unified`.
python deploy_script.py
Then load model in LAMMPS and run:

View File

@ -0,0 +1,32 @@
import numpy as np
import pickle
from pathlib import Path
from write_unified import MLIAPInterface
class MyModel():
def __init__(self,blah):
"""
coeffs = np.genfromtxt(file,skip_header=6)
self.bias = coeffs[0]
self.weights = coeffs[1:]
"""
self.blah = blah
self.n_params = 3 #len(coeffs)
self.n_descriptors = 1 #len(self.weights)
self.n_elements = 1
def __call__(self,rij):
print(rij)
#energy[:] = bispectrum @ self.weights + self.bias
#beta[:] = self.weights
return 5
model = MyModel(1)
#unified = MLIAPInterface(model, ["Ta"], model_device="cpu")
def create_pickle():
unified = MLIAPInterface(model, ["Ta"])
unified.pickle('mliap_jax.pkl')
create_pickle()

48
examples/mliap/jax/in.run Normal file
View File

@ -0,0 +1,48 @@
# Initialize simulation
variable nsteps index 10000
units metal
# generate the box and atom positions using a BCC lattice
#boundary p p p
#read_data DATA
variable nrep equal 2 #10
variable a equal 3.316
variable nx equal ${nrep}
variable ny equal ${nrep}
variable nz equal ${nrep}
lattice bcc $a
region box block 0 ${nx} 0 ${ny} 0 ${nz}
create_box 1 box
create_atoms 1 box
mass 1 180.88
pair_style mliap unified mliap_jax.pkl 0
pair_coeff * * Ta
compute eatom all pe/atom
compute energy all reduce sum c_eatom
compute satom all stress/atom NULL
compute str all reduce sum c_satom[1] c_satom[2] c_satom[3]
variable press equal (c_str[1]+c_str[2]+c_str[3])/(3*vol)
thermo_style custom step temp epair c_energy etotal press v_press
thermo 10
thermo_modify norm yes
# Set up NVE run
timestep 0.5e-3
neighbor 1.0 bin
# is this neigh modify every 1 slow?
neigh_modify once no every 1 delay 0 check yes
# Run MD
velocity all create 3200.0 4928459 loop geom
dump 1 all xyz 10 dump.xyz
fix 1 all nve
run ${nsteps}

Binary file not shown.

View File

View File

@ -0,0 +1,86 @@
"""
interface for creating LAMMPS MLIAP Unified models.
"""
import pickle
import numpy as np
from lammps.mliap.mliap_unified_abc import MLIAPUnified
class MLIAPInterface(MLIAPUnified):
"""
Class for creating ML-IAP Unified model based on hippynn graphs.
"""
def __init__(self, model, element_types, cutoff=4.5, ndescriptors=1):
"""
:param model: class defining the model
:param element_types: list of atomic symbols corresponding to element types
:param ndescriptors: the number of descriptors to report to LAMMPS
:param model_device: the device to send torch data to (cpu or cuda)
"""
super().__init__()
self.model = model
self.element_types = element_types
self.ndescriptors = ndescriptors
#self.model_device = model_device
# Build the calculator
# TODO: Make this cutoff depend on model cutoff, ideally from deployed model itself but could
# be part of deploy step.
#rc = 4.5
self.rcutfac = 0.5*cutoff # Actual cutoff will be 2*rc
#print(self.model.nparams)
self.nparams = 10
#self.rcutfac, self.species_set, self.graph = setup_LAMMPS()
#self.nparams = sum(p.nelement() for p in self.graph.parameters())
#self.graph.to(torch.float64)
def compute_descriptors(self, data):
pass
def compute_gradients(self, data):
pass
def compute_forces(self, data):
#print(">>>>> hey!")
#elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal)
"""
elems = self.as_tensor(data.elems).type(torch.int64) + 1
#z_vals = self.species_set[elems+1]
pair_i = self.as_tensor(data.pair_i).type(torch.int64)
pair_j = self.as_tensor(data.pair_j).type(torch.int64)
rij = self.as_tensor(data.rij).type(torch.float64).requires_grad_(True)
nlocal = self.as_tensor(data.nlistatoms)
"""
rij = data.rij
#(total_energy, fij) = self.network(rij, None, None, None, nlocal, elems, pair_i, pair_j, "cpu", dtype=torch.float64, mode="lammps")
test = self.model(rij)
#data.update_pair_forces(fij)
#data.energy = total_energy.item()
pass
def setup_LAMMPS(energy):
"""
:param energy: energy node for lammps interface
:return: graph for computing from lammps MLIAP unified inputs.
"""
model = TheModelClass(*args, **kwargs)
save_state_dict = torch.load("Ta_Pytorch.pt")
model.load_state_dict(save_state_dict["model_state_dict"])
#model.load_state_dict(torch.load(PATH))
model.eval()
#model.eval()
return model