Add input padded rij from LAMMPS Cython side

This commit is contained in:
rohskopf
2023-05-25 01:35:08 -06:00
parent 29ba0e3f18
commit 35a55068d7
2 changed files with 19 additions and 18 deletions

View File

@ -5,10 +5,9 @@ import jax.numpy as jnp
from jax import jit
from functools import partial
# /- ensure epsilon and sigma are treated as compile-time constants
@partial(jit, static_argnums=(0, 1))
def lj_potential(epsilon: float, sigma: float, rij):
# a pure function we can differentiate:
@jax.jit
def lj_potential(epsilon, sigma, rij):
# A pure function we can differentiate:
def _tot_e(rij):
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
r6inv = r2inv * r2inv * r2inv
@ -17,8 +16,8 @@ def lj_potential(epsilon: float, sigma: float, rij):
lj2 = 4.0 * epsilon * sigma**6
eij = r6inv * (lj1 * r6inv - lj2)
return jnp.sum(eij)
# /- construct a function computing _tot_e and its derivative
return 0.5 * jnp.sum(eij)
# Construct a function computing _tot_e and its derivative
tot_e, fij = jax.value_and_grad(_tot_e)(rij)
return tot_e, fij
@ -37,8 +36,6 @@ class MLIAPUnifiedJAX(MLIAPUnified):
# pair_coeff * * 1 1
self.epsilon = epsilon
self.sigma = sigma
# TODO: Take this from the LAMMPS Cython side.
self.npair_max = 250000
def compute_gradients(self, data):
"""Test compute_gradients."""
@ -48,18 +45,11 @@ class MLIAPUnifiedJAX(MLIAPUnified):
def compute_forces(self, data):
"""Test compute_forces."""
rij = data.rij
# TODO: Take max npairs from the LAMMPS Cython side.
if (data.npairs > self.npair_max):
self.npair_max = data.npairs
# NOTE: Use data.rij_max with JAX.
rij = data.rij_max
npad = self.npair_max - data.npairs
# TODO: Take pre-padded rij from the LAMMPS Cython side.
# This might account for ~2-3x slowdown compared to original LJ.
rij = np.pad(rij, ((0,npad), (0,0)), 'constant')
e_tot, fij = lj_potential(rij)
e_tot, fij = lj_potential(self.epsilon, self.sigma, rij)
data.energy = e_tot.item()
data.update_pair_forces(np.array(fij, dtype=np.float64))