Properly decorate energy/force compute
This commit is contained in:
@ -32,4 +32,4 @@ fix 1 all nve
|
|||||||
#dump 4 all custom 1 forces.xyz fx fy fz
|
#dump 4 all custom 1 forces.xyz fx fy fz
|
||||||
|
|
||||||
thermo 50
|
thermo 50
|
||||||
run 250
|
run 100
|
||||||
@ -1,5 +1,9 @@
|
|||||||
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import jit
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
class MLIAPUnifiedJAX(MLIAPUnified):
|
class MLIAPUnifiedJAX(MLIAPUnified):
|
||||||
@ -13,6 +17,7 @@ class MLIAPUnifiedJAX(MLIAPUnified):
|
|||||||
# pair_coeff * * 1 1
|
# pair_coeff * * 1 1
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
|
self.npair_max = 250000
|
||||||
|
|
||||||
def compute_gradients(self, data):
|
def compute_gradients(self, data):
|
||||||
"""Test compute_gradients."""
|
"""Test compute_gradients."""
|
||||||
@ -22,12 +27,26 @@ class MLIAPUnifiedJAX(MLIAPUnified):
|
|||||||
|
|
||||||
def compute_forces(self, data):
|
def compute_forces(self, data):
|
||||||
"""Test compute_forces."""
|
"""Test compute_forces."""
|
||||||
eij, fij = self.compute_pair_ef(data)
|
rij = data.rij
|
||||||
data.update_pair_energy(eij)
|
|
||||||
data.update_pair_forces(fij)
|
|
||||||
|
|
||||||
def compute_pair_ef(self, data):
|
# TODO: Take max npairs from the LAMMPS Cython side.
|
||||||
rij = data.rij
|
if (data.npairs > self.npair_max):
|
||||||
|
self.npair_max = data.npairs
|
||||||
|
|
||||||
|
npad = self.npair_max - data.npairs
|
||||||
|
# TODO: Take pre-padded rij from the LAMMPS Cython side.
|
||||||
|
# This might account for ~2-3x slowdown compared to original LJ.
|
||||||
|
rij = np.pad(rij, ((0,npad), (0,0)), 'constant')
|
||||||
|
|
||||||
|
eij, fij = self.compute_pair_ef(rij)
|
||||||
|
|
||||||
|
data.update_pair_energy(np.array(np.double(eij)))
|
||||||
|
data.update_pair_forces(np.array(np.double(fij)))
|
||||||
|
|
||||||
|
#@jax.jit # <-- This will error! See https://github.com/google/jax/issues/1251
|
||||||
|
# @partial takes a function (e.g. jax.jit) as an arg.
|
||||||
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
|
def compute_pair_ef(self, rij):
|
||||||
|
|
||||||
r2inv = 1.0 / np.sum(rij ** 2, axis=1)
|
r2inv = 1.0 / np.sum(rij ** 2, axis=1)
|
||||||
r6inv = r2inv * r2inv * r2inv
|
r6inv = r2inv * r2inv * r2inv
|
||||||
|
|||||||
Reference in New Issue
Block a user