diff --git a/examples/mliap/jax/README.md b/examples/mliap/jax/README.md new file mode 100644 index 0000000000..5cc0c49f16 --- /dev/null +++ b/examples/mliap/jax/README.md @@ -0,0 +1,87 @@ +# Running JAX from LAMMPS + +### Getting started + +First, create a Python environment with the required dependencies: + + conda create --name jax python=3.10 + conda activate jax + # Upgrade pip + python -m pip install --upgrade pip + # Install JAX: + python -m pip install --upgrade "jax[cpu]" + # Install other dependencies: + python -m pip install numpy scipy torch scikit-learn virtualenv psutil tabulate mpi4py Cython + +Install LAMMPS: + + cd /path/to/lammps + mkdir build-jax; cd build-jax + cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \ + -DBUILD_SHARED_LIBS=yes \ + -DMLIAP_ENABLE_PYTHON=yes \ + -DPKG_PYTHON=yes \ + -DPKG_ML-SNAP=yes \ + -DPKG_ML-IAP=yes \ + -DPYTHON_EXECUTABLE:FILEPATH=`which python` + make -j4 + make install-python + +### Kokkos install + +Use the same Python dependencies as above, with some additional steps: + +1. Install the CuPy build that matches your CUDA version, e.g. + + python -m pip install cupy-cuda12x + +2. Install JAX for GPU/CUDA: + + python -m pip install --trusted-host storage.googleapis.com --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + +3. Install cuDNN: https://developer.nvidia.com/cudnn + +Install LAMMPS. 
Take care to change the `Kokkos_ARCH_*` flag to match your GPU architecture: + + cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \ + -DBUILD_SHARED_LIBS=yes \ + -DPKG_PYTHON=yes \ + -DPKG_ML-SNAP=yes \ + -DPKG_ML-IAP=yes \ + -DMLIAP_ENABLE_PYTHON=yes \ + -DPKG_KOKKOS=yes \ + -DKokkos_ARCH_TURING75=yes \ + -DKokkos_ENABLE_CUDA=yes \ + -DKokkos_ENABLE_OPENMP=yes \ + -DCMAKE_CXX_COMPILER=${HOME}/lammps/lib/kokkos/bin/nvcc_wrapper \ + -DPYTHON_EXECUTABLE:FILEPATH=`which python` + make -j + make install-python + +Run the example: + + mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run + +### Deploying JAX models on CPU + +Use `deploy_script.py`, which wraps the model with `write_unified_jax`. + + python deploy_script.py + +This creates a `.pkl` file to be loaded by LAMMPS ML-IAP Unified. + +Run LAMMPS with the model: + + mpirun -np P lmp -in in.run + +### Deploying JAX models in Kokkos + +Use `deploy_script_kokkos.py`, which wraps the model with `write_unified_jax_kokkos`. + + python deploy_script_kokkos.py + +This creates a `.pkl` file to be loaded by LAMMPS ML-IAP Unified. 
+ +Run LAMMPS with the model: + + mpirun -np 1 lmp -k on g 1 -sf kk -pk kokkos newton on -in in.run \ No newline at end of file diff --git a/examples/mliap/jax/deploy_script.py b/examples/mliap/jax/deploy_script.py new file mode 100644 index 0000000000..5e73995565 --- /dev/null +++ b/examples/mliap/jax/deploy_script.py @@ -0,0 +1,11 @@ +import lammps +import lammps.mliap + +#from lammps.mliap.mliap_unified_lj import MLIAPUnifiedLJ +from mliap_unified_jax import MLIAPUnifiedJAX + +def create_pickle():  # build the single-element (Ar) JAX model and pickle it for LAMMPS + unified = MLIAPUnifiedJAX(["Ar"]) + unified.pickle('mliap_unified_jax_Ar.pkl')  # same file name that in.run passes to pair_style + +create_pickle()  # runs on import as well as when executed as a script \ No newline at end of file diff --git a/examples/mliap/jax/in.run b/examples/mliap/jax/in.run new file mode 100644 index 0000000000..d0560e6620 --- /dev/null +++ b/examples/mliap/jax/in.run @@ -0,0 +1,37 @@ +# 3d Lennard-Jones melt driven by the pickled ML-IAP unified JAX model (mliap_unified_jax_Ar.pkl) + +units lj +atom_style atomic + +lattice fcc 0.8442 +region box block 0 10 0 10 0 10 +create_box 1 box +create_atoms 1 box +mass 1 1.0 + +velocity all create 3.0 87287 loop geom + +pair_style mliap unified mliap_unified_jax_Ar.pkl 0 +pair_coeff * * Ar + +neighbor 0.3 bin +neigh_modify every 20 delay 0 check no + +fix 1 all nve + +#dump id all atom 50 dump.melt + +#dump 2 all image 25 image.*.jpg type type & +# axes yes 0.8 0.02 view 60 -30 +#dump_modify 2 pad 3 + +#dump 3 all movie 1 movie.mpg type type & +# axes yes 0.8 0.02 view 60 -30 +#dump_modify 3 pad 3 + +#dump 4 all custom 1 forces.xyz fx fy fz + +dump 1 all xyz 10 dump.xyz + +thermo 1 +run 250 diff --git a/examples/mliap/jax/mliap_jax.pkl b/examples/mliap/jax/mliap_jax.pkl new file mode 100644 index 0000000000..2cbbedb7c1 Binary files /dev/null and b/examples/mliap/jax/mliap_jax.pkl differ diff --git a/examples/mliap/jax/mliap_unified_jax.py b/examples/mliap/jax/mliap_unified_jax.py new file mode 100644 index 0000000000..69cbee6221 --- /dev/null +++ b/examples/mliap/jax/mliap_unified_jax.py @@ -0,0 +1,61 @@ +from lammps.mliap.mliap_unified_abc import MLIAPUnified +import numpy as np 
+import jax +import jax.numpy as jnp +from jax import jit +from functools import partial +import os + +# Required else get `jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory` +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false" +#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"  # disabled: ".XX" is a placeholder, not a parsable fraction (use e.g. ".50") +os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" + +@jax.jit +def lj_potential(epsilon, sigma, rij): + def _tot_e(rij): + """Return (0.5 * sum of pair energies, per-pair energies) as a function of rij.""" + r2inv = 1.0 / jnp.sum(rij ** 2, axis=1) + r6inv = r2inv * r2inv * r2inv + + lj1 = 4.0 * epsilon * sigma**12 + lj2 = 4.0 * epsilon * sigma**6 + + eij = r6inv * (lj1 * r6inv - lj2) + return 0.5 * jnp.sum(eij), eij + # Evaluate _tot_e and its gradient w.r.t. rij; eij is returned as the auxiliary output. + (_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij) + return eij, fij + + +class MLIAPUnifiedJAX(MLIAPUnified): + """Lennard-Jones test model for MLIAPUnified, evaluated with JAX.""" + + epsilon: float + sigma: float + + def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25): + # ARGS: interface, element_types, ndescriptors, nparams, rcutfac + super().__init__(None, element_types, 1, 3, rcutfac) + # Mimicking the LJ pair-style: + # pair_style lj/cut 2.5 + # pair_coeff * * 1 1 + self.epsilon = epsilon + self.sigma = sigma + + def compute_gradients(self, data): + """No-op: this model does not compute gradients.""" + + def compute_descriptors(self, data): + """No-op: this model does not compute descriptors.""" + + def compute_forces(self, data): + """Evaluate LJ pair energies and their rij-gradients with JAX and pass both to LAMMPS.""" + + # NOTE: Use data.rij_max with JAX (presumably keeps the jitted input shape fixed - confirm). 
+ rij = data.rij_max + + eij, fij = lj_potential(self.epsilon, self.sigma, rij) + + data.update_pair_energy(np.array(eij, dtype=np.float64)) + data.update_pair_forces(np.array(fij, dtype=np.float64)) diff --git a/examples/mliap/jax/mliap_unified_jax_Ar.pkl b/examples/mliap/jax/mliap_unified_jax_Ar.pkl new file mode 100644 index 0000000000..04b1730315 Binary files /dev/null and b/examples/mliap/jax/mliap_unified_jax_Ar.pkl differ diff --git a/examples/mliap/jax/mliap_unified_jax_kokkos.py b/examples/mliap/jax/mliap_unified_jax_kokkos.py new file mode 100644 index 0000000000..fd7f106f47 --- /dev/null +++ b/examples/mliap/jax/mliap_unified_jax_kokkos.py @@ -0,0 +1,69 @@ +from lammps.mliap.mliap_unified_abc import MLIAPUnified +import numpy as np +import jax +import jax.dlpack +import jax.numpy as jnp +from jax import jit +from functools import partial +import cupy +import os + +# Required else get `jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory` +# Does not fix GPU problem with larger num. atoms. 
+#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false" +#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX" +#os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform" + +@jax.jit +def lj_potential(epsilon, sigma, rij): + # Total energy written as a pure function of rij so that JAX can differentiate it: + def _tot_e(rij): + r2inv = 1.0 / jnp.sum(rij ** 2, axis=1) + r6inv = r2inv * r2inv * r2inv + + lj1 = 4.0 * epsilon * sigma**12 + lj2 = 4.0 * epsilon * sigma**6 + + eij = r6inv * (lj1 * r6inv - lj2) + return 0.5 * jnp.sum(eij), eij + # Evaluate _tot_e and its gradient w.r.t. rij; eij is returned as the auxiliary output. + (_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij) + return eij, fij + + +class MLIAPUnifiedJAXKokkos(MLIAPUnified): + """Lennard-Jones test model for MLIAPUnified using JAX, exchanging arrays with CuPy via DLPack.""" + + epsilon: float + sigma: float + + def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25): + # ARGS: interface, element_types, ndescriptors, nparams, rcutfac + super().__init__(None, element_types, 1, 3, rcutfac) + # Mimicking the LJ pair-style: + # pair_style lj/cut 2.5 + # pair_coeff * * 1 1 + self.epsilon = epsilon + self.sigma = sigma + + def compute_gradients(self, data): + """No-op: this model does not compute gradients.""" + + def compute_descriptors(self, data): + """No-op: this model does not compute descriptors.""" + + def compute_forces(self, data): + """Evaluate LJ pair energies and their rij-gradients with JAX and pass both to LAMMPS.""" + + # NOTE: Use data.rij_max with JAX. + # Hand rij_max to JAX through DLPack (the original note says dlpack requires cuDNN): + rij = jax.dlpack.from_dlpack(data.rij_max.toDlpack()) + eij, fij = lj_potential(self.epsilon, self.sigma, rij) + + # Convert back to CuPy arrays, cast to float64. + eij = cupy.from_dlpack(jax.dlpack.to_dlpack(eij)).astype(np.float64) + fij = cupy.from_dlpack(jax.dlpack.to_dlpack(fij)).astype(np.float64) + + # Send to LAMMPS. + data.update_pair_energy(eij) + data.update_pair_forces(fij) diff --git a/examples/mliap/jax/write_unified.py b/examples/mliap/jax/write_unified.py new file mode 100644 index 0000000000..af2627dfef --- /dev/null +++ b/examples/mliap/jax/write_unified.py @@ -0,0 +1,87 @@ +""" +interface for creating LAMMPS MLIAP Unified models. 
+""" +import pickle + +import numpy as np + +from lammps.mliap.mliap_unified_abc import MLIAPUnified +#from deploy_script import MyModel + +class MLIAPInterface(MLIAPUnified): + """ + Class for creating an ML-IAP Unified model that wraps a user-supplied callable model. + """ + def __init__(self, model, element_types, cutoff=4.5, ndescriptors=1): + """ + :param model: callable model; compute_forces invokes it as model(rij) + :param element_types: list of atomic symbols corresponding to element types + :param cutoff: interaction cutoff; rcutfac is set to half of this value + :param ndescriptors: the number of descriptors to report to LAMMPS + """ + super().__init__() + self.model = model + self.element_types = element_types + self.ndescriptors = ndescriptors + #self.model_device = model_device + + + # Build the calculator + # TODO: Make this cutoff depend on model cutoff, ideally from deployed model itself but could + # be part of deploy step. + #rc = 4.5 + self.rcutfac = 0.5*cutoff # Actual cutoff will be 2*rcutfac + #print(self.model.nparams) + self.nparams = 10 # NOTE(review): hard-coded placeholder - confirm the intended value + #self.rcutfac, self.species_set, self.graph = setup_LAMMPS() + #self.nparams = sum(p.nelement() for p in self.graph.parameters()) + #self.graph.to(torch.float64) + + def compute_descriptors(self, data): + pass + + def compute_gradients(self, data): + pass + + def compute_forces(self, data): + #print(">>>>> hey!") + #elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal) + + """ + elems = self.as_tensor(data.elems).type(torch.int64) + 1 + #z_vals = self.species_set[elems+1] + pair_i = self.as_tensor(data.pair_i).type(torch.int64) + pair_j = self.as_tensor(data.pair_j).type(torch.int64) + rij = self.as_tensor(data.rij).type(torch.float64).requires_grad_(True) + nlocal = self.as_tensor(data.nlistatoms) + """ + + rij = data.rij + + #(total_energy, fij) = self.network(rij, None, None, None, nlocal, elems, pair_i, pair_j, "cpu", dtype=torch.float64, mode="lammps") + + test = self.model(rij) # result is unused: no energies or forces are passed back to LAMMPS yet + + #data.update_pair_forces(fij) + #data.energy = 
total_energy.item() + + pass + +def setup_LAMMPS(energy): + """ + Template only: TheModelClass, args, kwargs and torch are not defined or imported in this file. + :param energy: energy node for lammps interface (currently unused) + :return: graph for computing from lammps MLIAP unified inputs. + """ + + model = TheModelClass(*args, **kwargs) + + save_state_dict = torch.load("Ta_Pytorch.pt") + model.load_state_dict(save_state_dict["model_state_dict"]) + + + #model.load_state_dict(torch.load(PATH)) + model.eval() + + #model.eval() + return model \ No newline at end of file diff --git a/src/KOKKOS/mliap_unified_couple_kokkos.pyx b/src/KOKKOS/mliap_unified_couple_kokkos.pyx index 5e274d7f36..bd94b79eb4 100644 --- a/src/KOKKOS/mliap_unified_couple_kokkos.pyx +++ b/src/KOKKOS/mliap_unified_couple_kokkos.pyx @@ -346,6 +346,16 @@ cdef class MLIAPDataPy: return None return create_array(self.data.dev, self.data.rij, [self.npairs,3],False) + @property + def rij_max(self): + if self.data.rij is NULL: + return None + return create_array(self.data.dev, self.data.rij, [self.nneigh_max,3], False) + + @property + def nneigh_max(self): + return self.data.nneigh_max + @write_only_property def graddesc(self, value): if self.data.graddesc is NULL: diff --git a/src/KOKKOS/mliap_unified_kokkos.cpp b/src/KOKKOS/mliap_unified_kokkos.cpp index 1fdf039473..deb9cbc346 100644 --- a/src/KOKKOS/mliap_unified_kokkos.cpp +++ b/src/KOKKOS/mliap_unified_kokkos.cpp @@ -371,6 +371,25 @@ void LAMMPS_NS::update_pair_forces(MLIAPDataKokkosDevice *data, double *fij) } } +/* ---------------------------------------------------------------------- + set per-atom energy for the nlistatoms listed atoms and reduce the sum into data->energy + ---------------------------------------------------------------------- */ + +void LAMMPS_NS::update_atom_energy(MLIAPDataKokkosDevice *data, double *ei) +{ + auto d_eatoms = data->eatoms; + const auto nlistatoms = data->nlistatoms; + + Kokkos::parallel_reduce(nlistatoms, KOKKOS_LAMBDA(int i, double &local_sum){ + double e = ei[i]; + // defensive guard: the reduction range is nlistatoms, so i < nlistatoms is expected to hold + if (i < nlistatoms) { + d_eatoms[i] = e; + 
local_sum += e; + } + },*data->energy); +} + namespace LAMMPS_NS { template class MLIAPDummyModelKokkos; template class MLIAPDummyDescriptorKokkos; diff --git a/src/KOKKOS/mliap_unified_kokkos.h b/src/KOKKOS/mliap_unified_kokkos.h index aad25891b0..1af7715dad 100644 --- a/src/KOKKOS/mliap_unified_kokkos.h +++ b/src/KOKKOS/mliap_unified_kokkos.h @@ -60,6 +60,7 @@ template MLIAPBuildUnifiedKokkos_t build_unified(char *, MLIAPDataKokkos *, LAMMPS *, char * = NULL); void update_pair_energy(MLIAPDataKokkosDevice *, double *); void update_pair_forces(MLIAPDataKokkosDevice *, double *); +void update_atom_energy(MLIAPDataKokkosDevice *, double *); } // namespace LAMMPS_NS diff --git a/src/ML-IAP/mliap_unified_couple.pyx b/src/ML-IAP/mliap_unified_couple.pyx index 3fde99a25e..25852a1c5f 100644 --- a/src/ML-IAP/mliap_unified_couple.pyx +++ b/src/ML-IAP/mliap_unified_couple.pyx @@ -281,6 +281,16 @@ cdef class MLIAPDataPy: return None return np.asarray( &self.data.rij[0][0]) + @property + def rij_max(self): + if self.data.rij is NULL: + return None + return np.asarray(<double[:self.nneigh_max,:3]> &self.data.rij[0][0]) + + @property + def nneigh_max(self): + return self.data.nneigh_max + @write_only_property def graddesc(self, value): if self.data.graddesc is NULL: @@ -357,6 +367,7 @@ cdef public object mliap_unified_connect(char *fname, MLIAPDummyModel * model, unified_int.descriptor = descriptor unified.interface = unified_int + #print(unified_int) if unified.ndescriptors is None: raise ValueError("no descriptors set")