# Files: lammps/examples/mliap/jax/mliap_unified_jax.py
# Snapshot: 2023-05-25 01:35:08 -06:00 (56 lines, 1.6 KiB, Python)

from lammps.mliap.mliap_unified_abc import MLIAPUnified
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
@jax.jit
def lj_potential(epsilon, sigma, rij):
    """Return the Lennard-Jones energy of a set of pair vectors and its gradient.

    The energy is ``0.5 * sum_k 4*epsilon*((sigma/r_k)**12 - (sigma/r_k)**6)``,
    where ``r_k`` is the Euclidean length of row ``k`` of ``rij``. The factor
    0.5 presumably compensates for each pair appearing twice (i-j and j-i) in
    the neighbor list -- TODO confirm against the caller.

    Args:
        epsilon: LJ well depth.
        sigma: LJ length scale.
        rij: 2-D array whose rows are pair displacement vectors (presumably
            shape (npairs, 3)). A zero-length row yields an infinite energy.

    Returns:
        ``(energy, grad)``: the scalar total energy and its gradient with
        respect to ``rij`` (same shape as ``rij``).
    """
    # Prefactors of the r^-12 and r^-6 terms. They do not depend on rij, so
    # they can be built once outside the function being differentiated.
    repulsive = 4.0 * epsilon * sigma**12
    attractive = 4.0 * epsilon * sigma**6

    def half_pair_energy_sum(vectors):
        # Pure function of the pair vectors only, so JAX can differentiate it.
        inv_r2 = 1.0 / jnp.sum(vectors ** 2, axis=1)
        inv_r6 = inv_r2 * inv_r2 * inv_r2
        pair_energies = inv_r6 * (repulsive * inv_r6 - attractive)
        return 0.5 * jnp.sum(pair_energies)

    # value_and_grad yields the energy and dE/d(rij) from a single trace.
    energy_and_grad = jax.value_and_grad(half_pair_energy_sum)
    return energy_and_grad(rij)


class MLIAPUnifiedJAX(MLIAPUnified):
    """Test implementation of MLIAPUnified: a Lennard-Jones model evaluated with JAX.

    Only ``compute_forces`` does any work. The descriptor and gradient
    hooks are empty in this test model.
    """

    # LJ well depth and length scale, forwarded to lj_potential.
    epsilon: float
    sigma: float

    def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
        """Register with the unified interface and store the LJ parameters.

        Args:
            element_types: element types passed through to MLIAPUnified.
            epsilon: LJ well depth.
            sigma: LJ length scale.
            rcutfac: cutoff factor passed through to MLIAPUnified.
        """
        # Positional arguments: interface, element_types, ndescriptors,
        # nparams, rcutfac.
        super().__init__(None, element_types, 1, 3, rcutfac)
        # Intended to mimic the LAMMPS LJ setup:
        #   pair_style lj/cut 2.5
        #   pair_coeff * * 1 1
        self.epsilon, self.sigma = epsilon, sigma

    def compute_gradients(self, data):
        """Test compute_gradients; a no-op for this model."""

    def compute_descriptors(self, data):
        """Test compute_descriptors; a no-op for this model."""

    def compute_forces(self, data):
        """Test compute_forces: evaluate LJ energy and pair forces via JAX.

        Reads ``data.rij_max``, stores the total energy in ``data.energy`` and
        hands the per-pair energy gradient to ``data.update_pair_forces``.
        """
        # NOTE: use data.rij_max (not data.rij) with JAX -- presumably a buffer
        # of stable size so the jitted lj_potential is not retraced every step.
        # NOTE(review): every row of rij_max contributes to the energy; if the
        # buffer holds more rows than real pairs, confirm those rows are benign.
        displacements = data.rij_max
        energy, pair_grad = lj_potential(self.epsilon, self.sigma, displacements)
        data.energy = energy.item()
        # NOTE(review): unless jax_enable_x64 is enabled elsewhere, JAX computes
        # in float32; the float64 cast below does not restore lost precision.
        forces = np.array(pair_grad, dtype=np.float64)
        data.update_pair_forces(forces)