Initial example
This commit is contained in:
38
examples/mliap/jax/README.md
Normal file
38
examples/mliap/jax/README.md
Normal file
@ -0,0 +1,38 @@
|
||||
# Running JAX from LAMMPS
|
||||
|
||||
### Getting started
|
||||
|
||||
First make a Python environment with dependencies:
|
||||
|
||||
conda create --name jax python=3.10
|
||||
conda activate jax
|
||||
# Upgrade pip
|
||||
python -m pip install --upgrade pip
|
||||
# Install JAX:
|
||||
python -m pip install --upgrade "jax[cpu]"
|
||||
# Install other dependencies:
|
||||
python -m pip install numpy scipy torch scikit-learn virtualenv psutil tabulate mpi4py Cython
|
||||
|
||||
Install LAMMPS:
|
||||
|
||||
cd /path/to/lammps
|
||||
mkdir build-jax; cd build-jax
|
||||
cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \
|
||||
-DBUILD_SHARED_LIBS=yes \
|
||||
-DMLIAP_ENABLE_PYTHON=yes \
|
||||
-DPKG_PYTHON=yes \
|
||||
-DPKG_ML-SNAP=yes \
|
||||
-DPKG_ML-IAP=yes \
|
||||
-DPYTHON_EXECUTABLE:FILEPATH=`which python`
|
||||
make -j4
|
||||
make install-python
|
||||
|
||||
### Wrapping JAX code
|
||||
|
||||
Take inspiration from the `FitSNAP` ML-IAP wrapper: https://github.com/rohskopf/FitSNAP/blob/mliap-unified/fitsnap3lib/tools/write_unified.py
|
||||
|
||||
First define JAX model in `deploy_script.py`, which will wrap model with `write_unified`.
|
||||
|
||||
python deploy_script.py
|
||||
|
||||
Then load model in LAMMPS and run:
|
||||
32
examples/mliap/jax/deploy_script.py
Normal file
32
examples/mliap/jax/deploy_script.py
Normal file
@ -0,0 +1,32 @@
|
||||
import numpy as np
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from write_unified import MLIAPInterface
|
||||
|
||||
class MyModel():
|
||||
def __init__(self,blah):
|
||||
"""
|
||||
coeffs = np.genfromtxt(file,skip_header=6)
|
||||
self.bias = coeffs[0]
|
||||
self.weights = coeffs[1:]
|
||||
"""
|
||||
self.blah = blah
|
||||
self.n_params = 3 #len(coeffs)
|
||||
self.n_descriptors = 1 #len(self.weights)
|
||||
self.n_elements = 1
|
||||
|
||||
def __call__(self,rij):
|
||||
print(rij)
|
||||
#energy[:] = bispectrum @ self.weights + self.bias
|
||||
#beta[:] = self.weights
|
||||
return 5
|
||||
|
||||
model = MyModel(1)
|
||||
|
||||
#unified = MLIAPInterface(model, ["Ta"], model_device="cpu")
|
||||
|
||||
def create_pickle():
|
||||
unified = MLIAPInterface(model, ["Ta"])
|
||||
unified.pickle('mliap_jax.pkl')
|
||||
|
||||
create_pickle()
|
||||
48
examples/mliap/jax/in.run
Normal file
48
examples/mliap/jax/in.run
Normal file
@ -0,0 +1,48 @@
|
||||
# Initialize simulation
|
||||
|
||||
variable nsteps index 10000
|
||||
units metal
|
||||
|
||||
# generate the box and atom positions using a BCC lattice
|
||||
|
||||
#boundary p p p
|
||||
#read_data DATA
|
||||
|
||||
variable nrep equal 2 #10
|
||||
variable a equal 3.316
|
||||
variable nx equal ${nrep}
|
||||
variable ny equal ${nrep}
|
||||
variable nz equal ${nrep}
|
||||
lattice bcc $a
|
||||
region box block 0 ${nx} 0 ${ny} 0 ${nz}
|
||||
create_box 1 box
|
||||
create_atoms 1 box
|
||||
mass 1 180.88
|
||||
|
||||
pair_style mliap unified mliap_jax.pkl 0
|
||||
pair_coeff * * Ta
|
||||
|
||||
compute eatom all pe/atom
|
||||
compute energy all reduce sum c_eatom
|
||||
|
||||
compute satom all stress/atom NULL
|
||||
compute str all reduce sum c_satom[1] c_satom[2] c_satom[3]
|
||||
variable press equal (c_str[1]+c_str[2]+c_str[3])/(3*vol)
|
||||
|
||||
thermo_style custom step temp epair c_energy etotal press v_press
|
||||
thermo 10
|
||||
thermo_modify norm yes
|
||||
|
||||
# Set up NVE run
|
||||
|
||||
timestep 0.5e-3
|
||||
neighbor 1.0 bin
|
||||
# is this neigh modify every 1 slow?
|
||||
neigh_modify once no every 1 delay 0 check yes
|
||||
|
||||
# Run MD
|
||||
|
||||
velocity all create 3200.0 4928459 loop geom
|
||||
dump 1 all xyz 10 dump.xyz
|
||||
fix 1 all nve
|
||||
run ${nsteps}
|
||||
BIN
examples/mliap/jax/mliap_jax.pkl
Normal file
BIN
examples/mliap/jax/mliap_jax.pkl
Normal file
Binary file not shown.
0
examples/mliap/jax/mliap_unified_jax.py
Normal file
0
examples/mliap/jax/mliap_unified_jax.py
Normal file
86
examples/mliap/jax/write_unified.py
Normal file
86
examples/mliap/jax/write_unified.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""
|
||||
interface for creating LAMMPS MLIAP Unified models.
|
||||
"""
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||
|
||||
class MLIAPInterface(MLIAPUnified):
|
||||
"""
|
||||
Class for creating ML-IAP Unified model based on hippynn graphs.
|
||||
"""
|
||||
def __init__(self, model, element_types, cutoff=4.5, ndescriptors=1):
|
||||
"""
|
||||
:param model: class defining the model
|
||||
:param element_types: list of atomic symbols corresponding to element types
|
||||
:param ndescriptors: the number of descriptors to report to LAMMPS
|
||||
:param model_device: the device to send torch data to (cpu or cuda)
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.element_types = element_types
|
||||
self.ndescriptors = ndescriptors
|
||||
#self.model_device = model_device
|
||||
|
||||
|
||||
# Build the calculator
|
||||
# TODO: Make this cutoff depend on model cutoff, ideally from deployed model itself but could
|
||||
# be part of deploy step.
|
||||
#rc = 4.5
|
||||
self.rcutfac = 0.5*cutoff # Actual cutoff will be 2*rc
|
||||
#print(self.model.nparams)
|
||||
self.nparams = 10
|
||||
#self.rcutfac, self.species_set, self.graph = setup_LAMMPS()
|
||||
#self.nparams = sum(p.nelement() for p in self.graph.parameters())
|
||||
#self.graph.to(torch.float64)
|
||||
|
||||
def compute_descriptors(self, data):
|
||||
pass
|
||||
|
||||
def compute_gradients(self, data):
|
||||
pass
|
||||
|
||||
def compute_forces(self, data):
|
||||
#print(">>>>> hey!")
|
||||
#elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal)
|
||||
|
||||
"""
|
||||
elems = self.as_tensor(data.elems).type(torch.int64) + 1
|
||||
#z_vals = self.species_set[elems+1]
|
||||
pair_i = self.as_tensor(data.pair_i).type(torch.int64)
|
||||
pair_j = self.as_tensor(data.pair_j).type(torch.int64)
|
||||
rij = self.as_tensor(data.rij).type(torch.float64).requires_grad_(True)
|
||||
nlocal = self.as_tensor(data.nlistatoms)
|
||||
"""
|
||||
|
||||
rij = data.rij
|
||||
|
||||
#(total_energy, fij) = self.network(rij, None, None, None, nlocal, elems, pair_i, pair_j, "cpu", dtype=torch.float64, mode="lammps")
|
||||
|
||||
test = self.model(rij)
|
||||
|
||||
#data.update_pair_forces(fij)
|
||||
#data.energy = total_energy.item()
|
||||
|
||||
pass
|
||||
|
||||
def setup_LAMMPS(energy):
|
||||
"""
|
||||
|
||||
:param energy: energy node for lammps interface
|
||||
:return: graph for computing from lammps MLIAP unified inputs.
|
||||
"""
|
||||
|
||||
model = TheModelClass(*args, **kwargs)
|
||||
|
||||
save_state_dict = torch.load("Ta_Pytorch.pt")
|
||||
model.load_state_dict(save_state_dict["model_state_dict"])
|
||||
|
||||
|
||||
#model.load_state_dict(torch.load(PATH))
|
||||
model.eval()
|
||||
|
||||
#model.eval()
|
||||
return model
|
||||
Reference in New Issue
Block a user