diff --git a/examples/mliap/jax/mliap_unified_jax.py b/examples/mliap/jax/mliap_unified_jax.py index ee09392ee8..96f9a067cc 100644 --- a/examples/mliap/jax/mliap_unified_jax.py +++ b/examples/mliap/jax/mliap_unified_jax.py @@ -15,11 +15,11 @@ def lj_potential(epsilon, sigma, rij): lj1 = 4.0 * epsilon * sigma**12 lj2 = 4.0 * epsilon * sigma**6 - eij = r6inv * (lj1 * r6inv - lj2) - return 0.5 * jnp.sum(eij) + eij = 0.5 * r6inv * (lj1 * r6inv - lj2) + return jnp.sum(eij), eij # Construct a function computing _tot_e and its derivative - tot_e, fij = jax.value_and_grad(_tot_e)(rij) - return tot_e, fij + (_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij) + return eij, fij class MLIAPUnifiedJAX(MLIAPUnified): @@ -49,7 +49,7 @@ class MLIAPUnifiedJAX(MLIAPUnified): # NOTE: Use data.rij_max with JAX. rij = data.rij_max - e_tot, fij = lj_potential(self.epsilon, self.sigma, rij) + eij, fij = lj_potential(self.epsilon, self.sigma, rij) - data.energy = e_tot.item() + data.update_pair_energy(np.array(eij, dtype=np.float64)) data.update_pair_forces(np.array(fij, dtype=np.float64))