Running JAX from LAMMPS

Getting started

First, create a Python environment and install the dependencies:

conda create --name jax python=3.10
conda activate jax
# Upgrade pip
python -m pip install --upgrade pip
# Install JAX:
python -m pip install --upgrade "jax[cpu]"
# Install other dependencies:
python -m pip install numpy scipy torch scikit-learn virtualenv psutil tabulate mpi4py Cython
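Before building LAMMPS, it can help to confirm that the environment actually contains the packages installed above. The following is a minimal sketch using only the Python standard library; the helper name missing_modules and the list REQUIRED_MODULES are illustrative and not part of LAMMPS or JAX.

```python
import importlib.util

# Import names for the packages installed above. Note that scikit-learn
# is imported as "sklearn"; the other names match their pip package names.
REQUIRED_MODULES = [
    "jax", "numpy", "scipy", "torch", "sklearn",
    "virtualenv", "psutil", "tabulate", "mpi4py", "Cython",
]


def missing_modules(names):
    """Return the names in `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

Run it with the jax environment activated. Any name it reports can be installed with pip before continuing.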

Build LAMMPS with the ML-IAP and Python packages enabled, then install its Python module:

cd /path/to/lammps
mkdir build-jax; cd build-jax
cmake ../cmake -DLAMMPS_EXCEPTIONS=yes \
               -DBUILD_SHARED_LIBS=yes \
               -DMLIAP_ENABLE_PYTHON=yes \
               -DPKG_PYTHON=yes \
               -DPKG_ML-SNAP=yes \
               -DPKG_ML-IAP=yes \
               -DPYTHON_EXECUTABLE:FILEPATH=`which python`
make -j4
make install-python

Wrapping JAX code

The FitSNAP ML-IAP wrapper is a useful reference for wrapping a model: https://github.com/rohskopf/FitSNAP/blob/mliap-unified/fitsnap3lib/tools/write_unified.py

First, define the JAX model in deploy_script.py, which wraps the model with write_unified. Then run the script:

python deploy_script.py

This creates a .pkl file that LAMMPS loads through the ML-IAP Unified interface.
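The sketch below illustrates the save and load round trip behind such a file, assuming the .pkl file is an ordinary Python pickle of the wrapped model object. ToyModel, save_model, and load_model are hypothetical stand-ins for illustration only; they are not the class or functions that write_unified or LAMMPS provide, and the attribute names are assumptions.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in for a wrapped model; not the ML-IAP Unified class."""

    def __init__(self, element_types, rcutfac, params):
        self.element_types = element_types  # e.g. ["Ar"]
        self.rcutfac = rcutfac              # cutoff parameter (illustrative)
        self.params = params                # plain Python data


def save_model(model, path):
    """Serialize the model object to `path` with pickle."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    """Load a model object previously written by save_model."""
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "toy_model.pkl")
    save_model(ToyModel(["Ar"], 3.0, {"epsilon": 1.0, "sigma": 1.0}), path)
    restored = load_model(path)
    print(restored.element_types, restored.rcutfac, restored.params)
```

Pickle stores a class by reference rather than by value, so the module that defines the model class must be importable in the Python environment that loads the file.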

Run LAMMPS with the model, replacing P with the number of MPI processes:

mpirun -np P lmp -in in.run