Copy MLIAPUnified LJ example
This commit is contained in:
@ -35,4 +35,8 @@ First define JAX model in `deploy_script.py`, which will wrap model with `write_
|
||||
|
||||
python deploy_script.py
|
||||
|
||||
Then load model in LAMMPS and run:
|
||||
This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified.
|
||||
|
||||
Run LAMMPS with the model:
|
||||
|
||||
mpirun -np P lmp -in in.run
|
||||
@ -1,32 +1,11 @@
|
||||
import numpy as np
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from write_unified import MLIAPInterface
|
||||
import lammps
|
||||
import lammps.mliap
|
||||
|
||||
class MyModel():
|
||||
def __init__(self,blah):
|
||||
"""
|
||||
coeffs = np.genfromtxt(file,skip_header=6)
|
||||
self.bias = coeffs[0]
|
||||
self.weights = coeffs[1:]
|
||||
"""
|
||||
self.blah = blah
|
||||
self.n_params = 3 #len(coeffs)
|
||||
self.n_descriptors = 1 #len(self.weights)
|
||||
self.n_elements = 1
|
||||
|
||||
def __call__(self,rij):
|
||||
print(rij)
|
||||
#energy[:] = bispectrum @ self.weights + self.bias
|
||||
#beta[:] = self.weights
|
||||
return 5
|
||||
|
||||
model = MyModel(1)
|
||||
|
||||
#unified = MLIAPInterface(model, ["Ta"], model_device="cpu")
|
||||
#from lammps.mliap.mliap_unified_lj import MLIAPUnifiedLJ
|
||||
from mliap_unified_jax import MLIAPUnifiedJAX
|
||||
|
||||
def create_pickle():
|
||||
unified = MLIAPInterface(model, ["Ta"])
|
||||
unified.pickle('mliap_jax.pkl')
|
||||
unified = MLIAPUnifiedJAX(["Ar"])
|
||||
unified.pickle('mliap_unified_jax_Ar.pkl')
|
||||
|
||||
create_pickle()
|
||||
@ -1,48 +1,35 @@
|
||||
# Initialize simulation
|
||||
# 3d Lennard-Jones melt
|
||||
|
||||
variable nsteps index 10000
|
||||
units metal
|
||||
units lj
|
||||
atom_style atomic
|
||||
|
||||
# generate the box and atom positions using a BCC lattice
|
||||
|
||||
#boundary p p p
|
||||
#read_data DATA
|
||||
|
||||
variable nrep equal 2 #10
|
||||
variable a equal 3.316
|
||||
variable nx equal ${nrep}
|
||||
variable ny equal ${nrep}
|
||||
variable nz equal ${nrep}
|
||||
lattice bcc $a
|
||||
region box block 0 ${nx} 0 ${ny} 0 ${nz}
|
||||
lattice fcc 0.8442
|
||||
region box block 0 10 0 10 0 10
|
||||
create_box 1 box
|
||||
create_atoms 1 box
|
||||
mass 1 180.88
|
||||
mass 1 1.0
|
||||
|
||||
pair_style mliap unified mliap_jax.pkl 0
|
||||
pair_coeff * * Ta
|
||||
velocity all create 3.0 87287 loop geom
|
||||
|
||||
compute eatom all pe/atom
|
||||
compute energy all reduce sum c_eatom
|
||||
pair_style mliap unified mliap_unified_jax_Ar.pkl 0
|
||||
pair_coeff * * Ar
|
||||
|
||||
compute satom all stress/atom NULL
|
||||
compute str all reduce sum c_satom[1] c_satom[2] c_satom[3]
|
||||
variable press equal (c_str[1]+c_str[2]+c_str[3])/(3*vol)
|
||||
neighbor 0.3 bin
|
||||
neigh_modify every 20 delay 0 check no
|
||||
|
||||
thermo_style custom step temp epair c_energy etotal press v_press
|
||||
thermo 10
|
||||
thermo_modify norm yes
|
||||
fix 1 all nve
|
||||
|
||||
# Set up NVE run
|
||||
#dump id all atom 50 dump.melt
|
||||
|
||||
timestep 0.5e-3
|
||||
neighbor 1.0 bin
|
||||
# is this neigh modify every 1 slow?
|
||||
neigh_modify once no every 1 delay 0 check yes
|
||||
#dump 2 all image 25 image.*.jpg type type &
|
||||
# axes yes 0.8 0.02 view 60 -30
|
||||
#dump_modify 2 pad 3
|
||||
|
||||
# Run MD
|
||||
#dump 3 all movie 1 movie.mpg type type &
|
||||
# axes yes 0.8 0.02 view 60 -30
|
||||
#dump_modify 3 pad 3
|
||||
|
||||
velocity all create 3200.0 4928459 loop geom
|
||||
dump 1 all xyz 10 dump.xyz
|
||||
fix 1 all nve
|
||||
run ${nsteps}
|
||||
#dump 4 all custom 1 forces.xyz fx fy fz
|
||||
|
||||
thermo 50
|
||||
run 250
|
||||
@ -0,0 +1,41 @@
|
||||
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MLIAPUnifiedJAX(MLIAPUnified):
|
||||
"""Test implementation for MLIAPUnified."""
|
||||
|
||||
def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25):
|
||||
# ARGS: interface, element_types, ndescriptors, nparams, rcutfac
|
||||
super().__init__(None, element_types, 1, 3, rcutfac)
|
||||
# Mimicking the LJ pair-style:
|
||||
# pair_style lj/cut 2.5
|
||||
# pair_coeff * * 1 1
|
||||
self.epsilon = epsilon
|
||||
self.sigma = sigma
|
||||
|
||||
def compute_gradients(self, data):
|
||||
"""Test compute_gradients."""
|
||||
|
||||
def compute_descriptors(self, data):
|
||||
"""Test compute_descriptors."""
|
||||
|
||||
def compute_forces(self, data):
|
||||
"""Test compute_forces."""
|
||||
eij, fij = self.compute_pair_ef(data)
|
||||
data.update_pair_energy(eij)
|
||||
data.update_pair_forces(fij)
|
||||
|
||||
def compute_pair_ef(self, data):
|
||||
rij = data.rij
|
||||
|
||||
r2inv = 1.0 / np.sum(rij ** 2, axis=1)
|
||||
r6inv = r2inv * r2inv * r2inv
|
||||
|
||||
lj1 = 4.0 * self.epsilon * self.sigma**12
|
||||
lj2 = 4.0 * self.epsilon * self.sigma**6
|
||||
|
||||
eij = r6inv * (lj1 * r6inv - lj2)
|
||||
fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv
|
||||
fij = fij[:, np.newaxis] * rij
|
||||
return eij, fij
|
||||
BIN
examples/mliap/jax/mliap_unified_jax_Ar.pkl
Normal file
BIN
examples/mliap/jax/mliap_unified_jax_Ar.pkl
Normal file
Binary file not shown.
@ -6,6 +6,7 @@ import pickle
|
||||
import numpy as np
|
||||
|
||||
from lammps.mliap.mliap_unified_abc import MLIAPUnified
|
||||
#from deploy_script import MyModel
|
||||
|
||||
class MLIAPInterface(MLIAPUnified):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user