diff --git a/examples/mliap/jax/mliap_unified_jax.py b/examples/mliap/jax/mliap_unified_jax.py index 4625fbf7f7..d4e75ee5c0 100644 --- a/examples/mliap/jax/mliap_unified_jax.py +++ b/examples/mliap/jax/mliap_unified_jax.py @@ -5,10 +5,30 @@ import jax.numpy as jnp from jax import jit from functools import partial +# /- ensure epsilon and sigma are treated as compile-time constants +@partial(jit, static_argnums=(0, 1)) +def lj_potential(epsilon: float, sigma: float, rij): + # a pure function we can differentiate: + def _tot_e(rij): + r2inv = 1.0 / jnp.sum(rij ** 2, axis=1) + r6inv = r2inv * r2inv * r2inv + + lj1 = 4.0 * epsilon * sigma**12 + lj2 = 4.0 * epsilon * sigma**6 + + eij = r6inv * (lj1 * r6inv - lj2) + return jnp.sum(eij) + # /- construct a function computing _tot_e and its derivative + tot_e, fij = jax.value_and_grad(_tot_e)(rij) + return tot_e, fij + class MLIAPUnifiedJAX(MLIAPUnified): """Test implementation for MLIAPUnified.""" + epsilon: float + sigma: float + def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25): # ARGS: interface, element_types, ndescriptors, nparams, rcutfac super().__init__(None, element_types, 1, 3, rcutfac) @@ -39,23 +59,7 @@ class MLIAPUnifiedJAX(MLIAPUnified): # This might account for ~2-3x slowdown compared to original LJ. rij = np.pad(rij, ((0,npad), (0,0)), 'constant') - eij, fij = self.compute_pair_ef(rij) + e_tot, fij = lj_potential(rij) - data.update_pair_energy(np.array(np.double(eij))) - data.update_pair_forces(np.array(np.double(fij))) - - #@jax.jit # <-- This will error! See https://github.com/google/jax/issues/1251 - # @partial takes a function (e.g. jax.jit) as an arg. - @partial(jax.jit, static_argnums=(0,)) - def compute_pair_ef(self, rij): - - r2inv = 1.0 / jnp.sum(rij ** 2, axis=1) - r6inv = r2inv * r2inv * r2inv - - lj1 = 4.0 * self.epsilon * self.sigma**12 - lj2 = 4.0 * self.epsilon * self.sigma**6 - - eij = r6inv * (lj1 * r6inv - lj2) - fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv - fij = fij[:, jnp.newaxis] * rij - return eij, fij + data.energy = e_tot.item() + data.update_pair_forces(np.array(fij, dtype=np.float64))