diff --git a/examples/mliap/jax/mliap_unified_jax.py b/examples/mliap/jax/mliap_unified_jax.py index b986b1247e..4625fbf7f7 100644 --- a/examples/mliap/jax/mliap_unified_jax.py +++ b/examples/mliap/jax/mliap_unified_jax.py @@ -17,6 +17,7 @@ class MLIAPUnifiedJAX(MLIAPUnified): # pair_coeff * * 1 1 self.epsilon = epsilon self.sigma = sigma + # TODO: Take this from the LAMMPS Cython side. self.npair_max = 250000 def compute_gradients(self, data): @@ -48,7 +49,7 @@ class MLIAPUnifiedJAX(MLIAPUnified): @partial(jax.jit, static_argnums=(0,)) def compute_pair_ef(self, rij): - r2inv = 1.0 / np.sum(rij ** 2, axis=1) + r2inv = 1.0 / jnp.sum(rij ** 2, axis=1) r6inv = r2inv * r2inv * r2inv lj1 = 4.0 * self.epsilon * self.sigma**12 @@ -56,5 +57,5 @@ class MLIAPUnifiedJAX(MLIAPUnified): eij = r6inv * (lj1 * r6inv - lj2) fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv - fij = fij[:, np.newaxis] * rij - return eij, fij \ No newline at end of file + fij = fij[:, jnp.newaxis] * rij + return eij, fij