Use jax functions
This commit is contained in:
@ -17,6 +17,7 @@ class MLIAPUnifiedJAX(MLIAPUnified):
|
|||||||
# pair_coeff * * 1 1
|
# pair_coeff * * 1 1
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
|
# TODO: Take this from the LAMMPS Cython side.
|
||||||
self.npair_max = 250000
|
self.npair_max = 250000
|
||||||
|
|
||||||
def compute_gradients(self, data):
|
def compute_gradients(self, data):
|
||||||
@ -48,7 +49,7 @@ class MLIAPUnifiedJAX(MLIAPUnified):
|
|||||||
@partial(jax.jit, static_argnums=(0,))
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
def compute_pair_ef(self, rij):
|
def compute_pair_ef(self, rij):
|
||||||
|
|
||||||
r2inv = 1.0 / np.sum(rij ** 2, axis=1)
|
r2inv = 1.0 / jnp.sum(rij ** 2, axis=1)
|
||||||
r6inv = r2inv * r2inv * r2inv
|
r6inv = r2inv * r2inv * r2inv
|
||||||
|
|
||||||
lj1 = 4.0 * self.epsilon * self.sigma**12
|
lj1 = 4.0 * self.epsilon * self.sigma**12
|
||||||
@ -56,5 +57,5 @@ class MLIAPUnifiedJAX(MLIAPUnified):
|
|||||||
|
|
||||||
eij = r6inv * (lj1 * r6inv - lj2)
|
eij = r6inv * (lj1 * r6inv - lj2)
|
||||||
fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv
|
fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv
|
||||||
fij = fij[:, np.newaxis] * rij
|
fij = fij[:, jnp.newaxis] * rij
|
||||||
return eij, fij
|
return eij, fij
|
||||||
|
|||||||
Reference in New Issue
Block a user