diff --git a/examples/mliap/jax/in.run b/examples/mliap/jax/in.run index 354dfd769f..0685d58f08 100644 --- a/examples/mliap/jax/in.run +++ b/examples/mliap/jax/in.run @@ -32,4 +32,4 @@ fix 1 all nve #dump 4 all custom 1 forces.xyz fx fy fz thermo 50 -run 250 \ No newline at end of file +run 100 \ No newline at end of file diff --git a/examples/mliap/jax/mliap_unified_jax.py b/examples/mliap/jax/mliap_unified_jax.py index a2447d1cdb..b986b1247e 100644 --- a/examples/mliap/jax/mliap_unified_jax.py +++ b/examples/mliap/jax/mliap_unified_jax.py @@ -1,5 +1,9 @@ from lammps.mliap.mliap_unified_abc import MLIAPUnified import numpy as np +import jax +import jax.numpy as jnp +from jax import jit +from functools import partial class MLIAPUnifiedJAX(MLIAPUnified): @@ -13,6 +17,7 @@ class MLIAPUnifiedJAX(MLIAPUnified): # pair_coeff * * 1 1 self.epsilon = epsilon self.sigma = sigma + self.npair_max = 250000 def compute_gradients(self, data): """Test compute_gradients.""" @@ -22,12 +27,26 @@ class MLIAPUnifiedJAX(MLIAPUnified): def compute_forces(self, data): """Test compute_forces.""" - eij, fij = self.compute_pair_ef(data) - data.update_pair_energy(eij) - data.update_pair_forces(fij) + rij = data.rij - def compute_pair_ef(self, data): - rij = data.rij + # TODO: Take max npairs from the LAMMPS Cython side. + if (data.npairs > self.npair_max): + self.npair_max = data.npairs + + npad = self.npair_max - data.npairs + # TODO: Take pre-padded rij from the LAMMPS Cython side. + # This might account for ~2-3x slowdown compared to original LJ. + rij = np.pad(rij, ((0,npad), (0,0)), 'constant') + + eij, fij = self.compute_pair_ef(rij) + + data.update_pair_energy(np.array(np.double(eij))) + data.update_pair_forces(np.array(np.double(fij))) + + #@jax.jit # <-- This will error! See https://github.com/google/jax/issues/1251 + # @partial takes a function (e.g. jax.jit) as an arg. + @partial(jax.jit, static_argnums=(0,)) + def compute_pair_ef(self, rij): r2inv = 1.0 / np.sum(rij ** 2, axis=1) r6inv = r2inv * r2inv * r2inv