Copy MLIAPUnified LJ example

This commit is contained in:
rohskopf
2023-05-20 14:08:20 -06:00
parent 6977f71eb0
commit 28c9c274be
6 changed files with 76 additions and 64 deletions

View File

@ -1,32 +1,11 @@
import numpy as np
import pickle
from pathlib import Path
from write_unified import MLIAPInterface
import lammps
import lammps.mliap
class MyModel():
def __init__(self,blah):
"""
coeffs = np.genfromtxt(file,skip_header=6)
self.bias = coeffs[0]
self.weights = coeffs[1:]
"""
self.blah = blah
self.n_params = 3 #len(coeffs)
self.n_descriptors = 1 #len(self.weights)
self.n_elements = 1
def __call__(self,rij):
print(rij)
#energy[:] = bispectrum @ self.weights + self.bias
#beta[:] = self.weights
return 5
model = MyModel(1)
#unified = MLIAPInterface(model, ["Ta"], model_device="cpu")
#from lammps.mliap.mliap_unified_lj import MLIAPUnifiedLJ
from mliap_unified_jax import MLIAPUnifiedJAX
def create_pickle():
unified = MLIAPInterface(model, ["Ta"])
unified.pickle('mliap_jax.pkl')
unified = MLIAPUnifiedJAX(["Ar"])
unified.pickle('mliap_unified_jax_Ar.pkl')
create_pickle()