diff --git a/examples/mliap/jax/mliap_unified_jax.py b/examples/mliap/jax/mliap_unified_jax.py index 96f9a067cc..42134c19d2 100644 --- a/examples/mliap/jax/mliap_unified_jax.py +++ b/examples/mliap/jax/mliap_unified_jax.py @@ -15,8 +15,8 @@ def lj_potential(epsilon, sigma, rij): lj1 = 4.0 * epsilon * sigma**12 lj2 = 4.0 * epsilon * sigma**6 - eij = 0.5 * r6inv * (lj1 * r6inv - lj2) - return jnp.sum(eij), eij + eij = r6inv * (lj1 * r6inv - lj2) + return 0.5 * jnp.sum(eij), eij # Construct a function computing _tot_e and its derivative (_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij) return eij, fij