Copy MLIAPUnified LJ example
This commit is contained in:
@ -1,32 +1,11 @@
|
||||
import numpy as np
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from write_unified import MLIAPInterface
|
||||
import lammps
|
||||
import lammps.mliap
|
||||
|
||||
class MyModel():
|
||||
def __init__(self,blah):
|
||||
"""
|
||||
coeffs = np.genfromtxt(file,skip_header=6)
|
||||
self.bias = coeffs[0]
|
||||
self.weights = coeffs[1:]
|
||||
"""
|
||||
self.blah = blah
|
||||
self.n_params = 3 #len(coeffs)
|
||||
self.n_descriptors = 1 #len(self.weights)
|
||||
self.n_elements = 1
|
||||
|
||||
def __call__(self,rij):
|
||||
print(rij)
|
||||
#energy[:] = bispectrum @ self.weights + self.bias
|
||||
#beta[:] = self.weights
|
||||
return 5
|
||||
|
||||
model = MyModel(1)
|
||||
|
||||
#unified = MLIAPInterface(model, ["Ta"], model_device="cpu")
|
||||
#from lammps.mliap.mliap_unified_lj import MLIAPUnifiedLJ
|
||||
from mliap_unified_jax import MLIAPUnifiedJAX
|
||||
|
||||
def create_pickle():
|
||||
unified = MLIAPInterface(model, ["Ta"])
|
||||
unified.pickle('mliap_jax.pkl')
|
||||
unified = MLIAPUnifiedJAX(["Ar"])
|
||||
unified.pickle('mliap_unified_jax_Ar.pkl')
|
||||
|
||||
create_pickle()
|
||||
Reference in New Issue
Block a user