Copy MLIAPUnified LJ example

This commit is contained in:
rohskopf
2023-05-20 14:08:20 -06:00
parent 6977f71eb0
commit 28c9c274be
6 changed files with 76 additions and 64 deletions

View File

@ -35,4 +35,8 @@ First define JAX model in `deploy_script.py`, which will wrap model with `write_
python deploy_script.py
Then load model in LAMMPS and run:
This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
Run LAMMPS with the model:
mpirun -np P lmp -in in.run

View File

@ -1,32 +1,11 @@
import numpy as np
import pickle
from pathlib import Path
from write_unified import MLIAPInterface
import lammps
import lammps.mliap
class MyModel():
def __init__(self,blah):
"""
coeffs = np.genfromtxt(file,skip_header=6)
self.bias = coeffs[0]
self.weights = coeffs[1:]
"""
self.blah = blah
self.n_params = 3 #len(coeffs)
self.n_descriptors = 1 #len(self.weights)
self.n_elements = 1
def __call__(self,rij):
print(rij)
#energy[:] = bispectrum @ self.weights + self.bias
#beta[:] = self.weights
return 5
model = MyModel(1)
#unified = MLIAPInterface(model, ["Ta"], model_device="cpu")
#from lammps.mliap.mliap_unified_lj import MLIAPUnifiedLJ
from mliap_unified_jax import MLIAPUnifiedJAX
def create_pickle():
unified = MLIAPInterface(model, ["Ta"])
unified.pickle('mliap_jax.pkl')
unified = MLIAPUnifiedJAX(["Ar"])
unified.pickle('mliap_unified_jax_Ar.pkl')
create_pickle()

View File

@ -1,48 +1,35 @@
# Initialize simulation
# 3d Lennard-Jones melt
variable nsteps index 10000
units metal
units lj
atom_style atomic
# generate the box and atom positions using a BCC lattice
#boundary p p p
#read_data DATA
variable nrep equal 2 #10
variable a equal 3.316
variable nx equal ${nrep}
variable ny equal ${nrep}
variable nz equal ${nrep}
lattice bcc $a
region box block 0 ${nx} 0 ${ny} 0 ${nz}
lattice fcc 0.8442
region box block 0 10 0 10 0 10
create_box 1 box
create_atoms 1 box
mass 1 180.88
mass 1 1.0
pair_style mliap unified mliap_jax.pkl 0
pair_coeff * * Ta
velocity all create 3.0 87287 loop geom
compute eatom all pe/atom
compute energy all reduce sum c_eatom
pair_style mliap unified mliap_unified_jax_Ar.pkl 0
pair_coeff * * Ar
compute satom all stress/atom NULL
compute str all reduce sum c_satom[1] c_satom[2] c_satom[3]
variable press equal (c_str[1]+c_str[2]+c_str[3])/(3*vol)
neighbor 0.3 bin
neigh_modify every 20 delay 0 check no
thermo_style custom step temp epair c_energy etotal press v_press
thermo 10
thermo_modify norm yes
# Set up NVE run
timestep 0.5e-3
neighbor 1.0 bin
# is this neigh modify every 1 slow?
neigh_modify once no every 1 delay 0 check yes
# Run MD
velocity all create 3200.0 4928459 loop geom
dump 1 all xyz 10 dump.xyz
fix 1 all nve
run ${nsteps}
#dump id all atom 50 dump.melt
#dump 2 all image 25 image.*.jpg type type &
# axes yes 0.8 0.02 view 60 -30
#dump_modify 2 pad 3
#dump 3 all movie 1 movie.mpg type type &
# axes yes 0.8 0.02 view 60 -30
#dump_modify 3 pad 3
#dump 4 all custom 1 forces.xyz fx fy fz
thermo 50
run 250

View File

@ -0,0 +1,41 @@
from lammps.mliap.mliap_unified_abc import MLIAPUnified
import numpy as np
class MLIAPUnifiedJAX(MLIAPUnified):
"""Test implementation for MLIAPUnified."""
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
super().__init__(None, element_types, 1, 3, rcutfac)
# Mimicking the LJ pair-style:
# pair_style lj/cut 2.5
# pair_coeff * * 1 1
self.epsilon = epsilon
self.sigma = sigma
def compute_gradients(self, data):
"""Test compute_gradients."""
def compute_descriptors(self, data):
"""Test compute_descriptors."""
def compute_forces(self, data):
"""Test compute_forces."""
eij, fij = self.compute_pair_ef(data)
data.update_pair_energy(eij)
data.update_pair_forces(fij)
def compute_pair_ef(self, data):
rij = data.rij
r2inv = 1.0 / np.sum(rij ** 2, axis=1)
r6inv = r2inv * r2inv * r2inv
lj1 = 4.0 * self.epsilon * self.sigma**12
lj2 = 4.0 * self.epsilon * self.sigma**6
eij = r6inv * (lj1 * r6inv - lj2)
fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv
fij = fij[:, np.newaxis] * rij
return eij, fij

Binary file not shown.

View File

@ -6,6 +6,7 @@ import pickle
import numpy as np
from lammps.mliap.mliap_unified_abc import MLIAPUnified
#from deploy_script import MyModel
class MLIAPInterface(MLIAPUnified):
"""