From 6977f71eb0ad15e370e330739cf6cf937f4ceca0 Mon Sep 17 00:00:00 2001 From: rohskopf Date: Sat, 20 May 2023 13:53:22 -0600 Subject: [PATCH] Initial example --- examples/mliap/jax/README.md | 38 +++++++++++ examples/mliap/jax/deploy_script.py | 32 +++++++++ examples/mliap/jax/in.run | 48 +++++++++++++ examples/mliap/jax/mliap_jax.pkl | Bin 0 -> 234 bytes examples/mliap/jax/mliap_unified_jax.py | 0 examples/mliap/jax/write_unified.py | 86 ++++++++++++++++++++++++ 6 files changed, 204 insertions(+) create mode 100644 examples/mliap/jax/README.md create mode 100644 examples/mliap/jax/deploy_script.py create mode 100644 examples/mliap/jax/in.run create mode 100644 examples/mliap/jax/mliap_jax.pkl create mode 100644 examples/mliap/jax/mliap_unified_jax.py create mode 100644 examples/mliap/jax/write_unified.py diff --git a/examples/mliap/jax/README.md b/examples/mliap/jax/README.md new file mode 100644 index 0000000000..14a049221a --- /dev/null +++ b/examples/mliap/jax/README.md @@ -0,0 +1,38 @@ +# Running JAX from LAMMPS + +### Getting started + +First make a Python environment with dependencies: + + conda create --name jax python=3.10 + conda activate jax + # Upgrade pip + python -m pip install --upgrade pip + # Install JAX: + python -m pip install --upgrade "jax[cpu]" + # Install other dependencies: + python -m pip install numpy scipy torch scikit-learn virtualenv psutil tabulate mpi4py Cython + +Install LAMMPS: + + cd /path/to/lammps + mkdir build-jax; cd build-jax + cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \ + -DBUILD_SHARED_LIBS=yes \ + -DMLIAP_ENABLE_PYTHON=yes \ + -DPKG_PYTHON=yes \ + -DPKG_ML-SNAP=yes \ + -DPKG_ML-IAP=yes \ + -DPYTHON_EXECUTABLE:FILEPATH=`which python` + make -j4 + make install-python + +### Wrapping JAX code + +Take inspiration from the `FitSNAP` ML-IAP wrapper: https://github.com/rohskopf/FitSNAP/blob/mliap-unified/fitsnap3lib/tools/write_unified.py + +First define JAX model in `deploy_script.py`, which will wrap model with `write_unified`. + + python deploy_script.py + +Then load model in LAMMPS and run: \ No newline at end of file diff --git a/examples/mliap/jax/deploy_script.py b/examples/mliap/jax/deploy_script.py new file mode 100644 index 0000000000..58af797766 --- /dev/null +++ b/examples/mliap/jax/deploy_script.py @@ -0,0 +1,32 @@ +import numpy as np +import pickle +from pathlib import Path +from write_unified import MLIAPInterface + +class MyModel(): + def __init__(self,blah): + """ + coeffs = np.genfromtxt(file,skip_header=6) + self.bias = coeffs[0] + self.weights = coeffs[1:] + """ + self.blah = blah + self.n_params = 3 #len(coeffs) + self.n_descriptors = 1 #len(self.weights) + self.n_elements = 1 + + def __call__(self,rij): + print(rij) + #energy[:] = bispectrum @ self.weights + self.bias + #beta[:] = self.weights + return 5 + +model = MyModel(1) + +#unified = MLIAPInterface(model, ["Ta"], model_device="cpu") + +def create_pickle(): + unified = MLIAPInterface(model, ["Ta"]) + unified.pickle('mliap_jax.pkl') + +create_pickle() \ No newline at end of file diff --git a/examples/mliap/jax/in.run b/examples/mliap/jax/in.run new file mode 100644 index 0000000000..1ffc2d3c0f --- /dev/null +++ b/examples/mliap/jax/in.run @@ -0,0 +1,48 @@ +# Initialize simulation + +variable nsteps index 10000 +units metal + +# generate the box and atom positions using a BCC lattice + +#boundary p p p +#read_data DATA + +variable nrep equal 2 #10 +variable a equal 3.316 +variable nx equal ${nrep} +variable ny equal ${nrep} +variable nz equal ${nrep} +lattice bcc $a +region box block 0 ${nx} 0 ${ny} 0 ${nz} +create_box 1 box +create_atoms 1 box +mass 1 180.88 + +pair_style mliap unified mliap_jax.pkl 0 +pair_coeff * * Ta + +compute eatom all pe/atom +compute energy all reduce sum c_eatom + +compute satom all stress/atom NULL +compute str all reduce sum c_satom[1] c_satom[2] c_satom[3] +variable press equal (c_str[1]+c_str[2]+c_str[3])/(3*vol) + +thermo_style custom step temp epair c_energy etotal press v_press +thermo 10 +thermo_modify norm yes + +# Set up NVE run + +timestep 0.5e-3 +neighbor 1.0 bin +# is this neigh modify every 1 slow? +neigh_modify once no every 1 delay 0 check yes + +# Run MD + +velocity all create 3200.0 4928459 loop geom +dump 1 all xyz 10 dump.xyz +fix 1 all nve +run ${nsteps} \ No newline at end of file diff --git a/examples/mliap/jax/mliap_jax.pkl b/examples/mliap/jax/mliap_jax.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2cbbedb7c1bd8dd8af97d7312471cf2054e163d1 GIT binary patch literal 234 zcmZ9GyAFat5Ji1pP$GVZHhzOvhA76u)=U-{<0deWWr?9M(SAC|@AFs0XW^}GZgI{% zJGXo59rlk#TZD@AUSlFW3Rl5=6Ocu-S24;}CKox$m>>>> hey!") + #elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal) + + """ + elems = self.as_tensor(data.elems).type(torch.int64) + 1 + #z_vals = self.species_set[elems+1] + pair_i = self.as_tensor(data.pair_i).type(torch.int64) + pair_j = self.as_tensor(data.pair_j).type(torch.int64) + rij = self.as_tensor(data.rij).type(torch.float64).requires_grad_(True) + nlocal = self.as_tensor(data.nlistatoms) + """ + + rij = data.rij + + #(total_energy, fij) = self.network(rij, None, None, None, nlocal, elems, pair_i, pair_j, "cpu", dtype=torch.float64, mode="lammps") + + test = self.model(rij) + + #data.update_pair_forces(fij) + #data.energy = total_energy.item() + + pass + +def setup_LAMMPS(energy): + """ + + :param energy: energy node for lammps interface + :return: graph for computing from lammps MLIAP unified inputs. + """ + + model = TheModelClass(*args, **kwargs) + + save_state_dict = torch.load("Ta_Pytorch.pt") + model.load_state_dict(save_state_dict["model_state_dict"]) + + + #model.load_state_dict(torch.load(PATH)) + model.eval() + + #model.eval() + return model \ No newline at end of file