diff --git a/examples/mliap/jax/README.md b/examples/mliap/jax/README.md index 14a049221a..320d87e465 100644 --- a/examples/mliap/jax/README.md +++ b/examples/mliap/jax/README.md @@ -35,4 +35,8 @@ First define JAX model in `deploy_script.py`, which will wrap model with `write_ python deploy_script.py -Then load model in LAMMPS and run: \ No newline at end of file +This creates `.pkl` file to be loaded by LAMMPS ML-IAP Unified. + +Run LAMMPS with the model: + + mpirun -np P lmp -in in.run \ No newline at end of file diff --git a/examples/mliap/jax/deploy_script.py b/examples/mliap/jax/deploy_script.py index 58af797766..5e73995565 100644 --- a/examples/mliap/jax/deploy_script.py +++ b/examples/mliap/jax/deploy_script.py @@ -1,32 +1,11 @@ -import numpy as np -import pickle -from pathlib import Path -from write_unified import MLIAPInterface +import lammps +import lammps.mliap -class MyModel(): - def __init__(self,blah): - """ - coeffs = np.genfromtxt(file,skip_header=6) - self.bias = coeffs[0] - self.weights = coeffs[1:] - """ - self.blah = blah - self.n_params = 3 #len(coeffs) - self.n_descriptors = 1 #len(self.weights) - self.n_elements = 1 - - def __call__(self,rij): - print(rij) - #energy[:] = bispectrum @ self.weights + self.bias - #beta[:] = self.weights - return 5 - -model = MyModel(1) - -#unified = MLIAPInterface(model, ["Ta"], model_device="cpu") +#from lammps.mliap.mliap_unified_lj import MLIAPUnifiedLJ +from mliap_unified_jax import MLIAPUnifiedJAX def create_pickle(): - unified = MLIAPInterface(model, ["Ta"]) - unified.pickle('mliap_jax.pkl') + unified = MLIAPUnifiedJAX(["Ar"]) + unified.pickle('mliap_unified_jax_Ar.pkl') create_pickle() \ No newline at end of file diff --git a/examples/mliap/jax/in.run b/examples/mliap/jax/in.run index 1ffc2d3c0f..354dfd769f 100644 --- a/examples/mliap/jax/in.run +++ b/examples/mliap/jax/in.run @@ -1,48 +1,35 @@ -# Initialize simulation +# 3d Lennard-Jones melt -variable nsteps index 10000 -units metal +units lj +atom_style atomic -# generate the box and atom positions using a BCC lattice - -#boundary p p p -#read_data DATA - -variable nrep equal 2 #10 -variable a equal 3.316 -variable nx equal ${nrep} -variable ny equal ${nrep} -variable nz equal ${nrep} -lattice bcc $a -region box block 0 ${nx} 0 ${ny} 0 ${nz} +lattice fcc 0.8442 +region box block 0 10 0 10 0 10 create_box 1 box create_atoms 1 box -mass 1 180.88 +mass 1 1.0 -pair_style mliap unified mliap_jax.pkl 0 -pair_coeff * * Ta +velocity all create 3.0 87287 loop geom -compute eatom all pe/atom -compute energy all reduce sum c_eatom +pair_style mliap unified mliap_unified_jax_Ar.pkl 0 +pair_coeff * * Ar -compute satom all stress/atom NULL -compute str all reduce sum c_satom[1] c_satom[2] c_satom[3] -variable press equal (c_str[1]+c_str[2]+c_str[3])/(3*vol) +neighbor 0.3 bin +neigh_modify every 20 delay 0 check no -thermo_style custom step temp epair c_energy etotal press v_press -thermo 10 -thermo_modify norm yes +fix 1 all nve -# Set up NVE run +#dump id all atom 50 dump.melt -timestep 0.5e-3 -neighbor 1.0 bin -# is this neigh modify every 1 slow? -neigh_modify once no every 1 delay 0 check yes +#dump 2 all image 25 image.*.jpg type type & +# axes yes 0.8 0.02 view 60 -30 +#dump_modify 2 pad 3 -# Run MD +#dump 3 all movie 1 movie.mpg type type & +# axes yes 0.8 0.02 view 60 -30 +#dump_modify 3 pad 3 -velocity all create 3200.0 4928459 loop geom -dump 1 all xyz 10 dump.xyz -fix 1 all nve -run ${nsteps} \ No newline at end of file +#dump 4 all custom 1 forces.xyz fx fy fz + +thermo 50 +run 250 \ No newline at end of file diff --git a/examples/mliap/jax/mliap_unified_jax.py b/examples/mliap/jax/mliap_unified_jax.py index e69de29bb2..a2447d1cdb 100644 --- a/examples/mliap/jax/mliap_unified_jax.py +++ b/examples/mliap/jax/mliap_unified_jax.py @@ -0,0 +1,41 @@ +from lammps.mliap.mliap_unified_abc import MLIAPUnified +import numpy as np + + +class MLIAPUnifiedJAX(MLIAPUnified): + """Test implementation for MLIAPUnified.""" + + def __init__(self, element_types, epsilon=1.0, sigma=1.0, rcutfac=1.25): + # ARGS: interface, element_types, ndescriptors, nparams, rcutfac + super().__init__(None, element_types, 1, 3, rcutfac) + # Mimicking the LJ pair-style: + # pair_style lj/cut 2.5 + # pair_coeff * * 1 1 + self.epsilon = epsilon + self.sigma = sigma + + def compute_gradients(self, data): + """Test compute_gradients.""" + + def compute_descriptors(self, data): + """Test compute_descriptors.""" + + def compute_forces(self, data): + """Test compute_forces.""" + eij, fij = self.compute_pair_ef(data) + data.update_pair_energy(eij) + data.update_pair_forces(fij) + + def compute_pair_ef(self, data): + rij = data.rij + + r2inv = 1.0 / np.sum(rij ** 2, axis=1) + r6inv = r2inv * r2inv * r2inv + + lj1 = 4.0 * self.epsilon * self.sigma**12 + lj2 = 4.0 * self.epsilon * self.sigma**6 + + eij = r6inv * (lj1 * r6inv - lj2) + fij = r6inv * (3.0 * lj2 - 6.0 * lj2 * r6inv) * r2inv + fij = fij[:, np.newaxis] * rij + return eij, fij \ No newline at end of file diff --git a/examples/mliap/jax/mliap_unified_jax_Ar.pkl b/examples/mliap/jax/mliap_unified_jax_Ar.pkl new file mode 100644 index 0000000000..04b1730315 Binary files /dev/null and b/examples/mliap/jax/mliap_unified_jax_Ar.pkl differ diff --git a/examples/mliap/jax/write_unified.py b/examples/mliap/jax/write_unified.py index 55f46a0af4..af2627dfef 100644 --- a/examples/mliap/jax/write_unified.py +++ b/examples/mliap/jax/write_unified.py @@ -6,6 +6,7 @@ import pickle import numpy as np from lammps.mliap.mliap_unified_abc import MLIAPUnified +#from deploy_script import MyModel class MLIAPInterface(MLIAPUnified): """