Update MLIAP JAX example to use jax.grad
This commit is contained in:
@ -5,10 +5,30 @@ import jax.numpy as jnp
|
|||||||
from jax import jit
|
from jax import jit
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
# /- ensure epsilon and sigma are treated as compile-time constants
|
||||||
|
@partial(jit, static_argnums=(0, 1))
|
||||||
|
def lj_potential(epsilon: float, sigma: float, rij):
|
||||||
|
# a pure function we can differentiate:
|
||||||
|
def _tot_e(rij):
|
||||||
|
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
|
||||||
|
r6inv = r2inv * r2inv * r2inv
|
||||||
|
|
||||||
|
lj1 = 4.0 * epsilon * sigma**12
|
||||||
|
lj2 = 4.0 * epsilon * sigma**6
|
||||||
|
|
||||||
|
eij = r6inv * (lj1 * r6inv - lj2)
|
||||||
|
return jnp.sum(eij)
|
||||||
|
# /- construct a function computing _tot_e and its derivative
|
||||||
|
tot_e, fij = jax.value_and_grad(_tot_e)(rij)
|
||||||
|
return tot_e, fij
|
||||||
|
|
||||||
|
|
||||||
class MLIAPUnifiedJAX(MLIAPUnified):
|
class MLIAPUnifiedJAX(MLIAPUnified):
|
||||||
"""Test implementation for MLIAPUnified."""
|
"""Test implementation for MLIAPUnified."""
|
||||||
|
|
||||||
|
epsilon: float
|
||||||
|
sigma: float
|
||||||
|
|
||||||
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
|
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
|
||||||
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
|
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
|
||||||
super().__init__(None, element_types, 1, 3, rcutfac)
|
super().__init__(None, element_types, 1, 3, rcutfac)
|
||||||
@ -39,23 +59,7 @@ class MLIAPUnifiedJAX(MLIAPUnified):
|
|||||||
# This might account for ~2-3x slowdown compared to original LJ.
|
# This might account for ~2-3x slowdown compared to original LJ.
|
||||||
rij = np.pad(rij, ((0,npad), (0,0)), 'constant')
|
rij = np.pad(rij, ((0,npad), (0,0)), 'constant')
|
||||||
|
|
||||||
eij, fij = self.compute_pair_ef(rij)
|
e_tot, fij = lj_potential(rij)
|
||||||
|
|
||||||
data.update_pair_energy(np.array(np.double(eij)))
|
data.energy = e_tot.item()
|
||||||
data.update_pair_forces(np.array(np.double(fij)))
|
data.update_pair_forces(np.array(fij, dtype=np.float64))
|
||||||
|
|
||||||
#@jax.jit # <-- This will error! See https://github.com/google/jax/issues/1251
|
|
||||||
# @partial takes a function (e.g. jax.jit) as an arg.
|
|
||||||
@partial(jax.jit, static_argnums=(0,))
|
|
||||||
def compute_pair_ef(self, rij):
|
|
||||||
|
|
||||||
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
|
|
||||||
r6inv = r2inv * r2inv * r2inv
|
|
||||||
|
|
||||||
lj1 = 4.0 * self.epsilon * self.sigma**12
|
|
||||||
lj2 = 4.0 * self.epsilon * self.sigma**6
|
|
||||||
|
|
||||||
eij = r6inv * (lj1 * r6inv - lj2)
|
|
||||||
fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv
|
|
||||||
fij = fij[:, jnp.newaxis] * rij
|
|
||||||
return eij, fij
|
|
||||||
|
|||||||
Reference in New Issue
Block a user