send eij to LAMMPS
This commit is contained in:
@ -15,11 +15,11 @@ def lj_potential(epsilon, sigma, rij):
|
|||||||
lj1 = 4.0 * epsilon * sigma**12
|
lj1 = 4.0 * epsilon * sigma**12
|
||||||
lj2 = 4.0 * epsilon * sigma**6
|
lj2 = 4.0 * epsilon * sigma**6
|
||||||
|
|
||||||
eij = r6inv * (lj1 * r6inv - lj2)
|
eij = 0.5 * r6inv * (lj1 * r6inv - lj2)
|
||||||
return 0.5 * jnp.sum(eij)
|
return jnp.sum(eij), eij
|
||||||
# Construct a function computing _tot_e and its derivative
|
# Construct a function computing _tot_e and its derivative
|
||||||
tot_e, fij = jax.value_and_grad(_tot_e)(rij)
|
(_, eij), fij = jax.value_and_grad(_tot_e, has_aux=True)(rij)
|
||||||
return tot_e, fij
|
return eij, fij
|
||||||
|
|
||||||
|
|
||||||
class MLIAPUnifiedJAX(MLIAPUnified):
|
class MLIAPUnifiedJAX(MLIAPUnified):
|
||||||
@ -49,7 +49,7 @@ class MLIAPUnifiedJAX(MLIAPUnified):
|
|||||||
# NOTE: Use data.rij_max with JAX.
|
# NOTE: Use data.rij_max with JAX.
|
||||||
rij = data.rij_max
|
rij = data.rij_max
|
||||||
|
|
||||||
e_tot, fij = lj_potential(self.epsilon, self.sigma, rij)
|
eij, fij = lj_potential(self.epsilon, self.sigma, rij)
|
||||||
|
|
||||||
data.energy = e_tot.item()
|
data.update_pair_energy(np.array(eij, dtype=np.float64))
|
||||||
data.update_pair_forces(np.array(fij, dtype=np.float64))
|
data.update_pair_forces(np.array(fij, dtype=np.float64))
|
||||||
|
|||||||
Reference in New Issue
Block a user