whitespace
@@ -1622,8 +1622,8 @@ class lammps(object):
        """Return a string with detailed information about any devices that are
        usable by the GPU package.

        This is a wrapper around the :cpp:func:`lammps_get_gpu_device_info`
        function of the C-library interface.

        :return: GPU device info string
        :rtype: string
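For orientation, a minimal usage sketch of the docstring above, assuming it documents a get_gpu_device_info() method on the lammps Python class (the method name is an assumption; the hunk only confirms the underlying C-library call):

import_example.py:
from lammps import lammps

lmp = lammps()
# Assumption: get_gpu_device_info() is the Python wrapper around
# lammps_get_gpu_device_info(); it returns a human-readable summary
# of devices usable by the GPU package.
print(lmp.get_gpu_device_info())
lmp.close()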
@@ -29,7 +29,7 @@ from ctypes import pythonapi, c_int, c_void_p, py_object
class DynamicLoader(importlib.abc.Loader):
    def __init__(self,module_name,library,api_version=1013):
        self.api_version = api_version

        attr = "PyInit_"+module_name
        initfunc = getattr(library,attr)
        # c_void_p is standin for PyModuleDef *
@@ -44,7 +44,7 @@ class DynamicLoader(importlib.abc.Loader):
        createfunc.restype = py_object
        module = createfunc(self.module_def, spec, self.api_version)
        return module

    def exec_module(self, module):
        execfunc = pythonapi.PyModule_ExecDef
        # c_void_p is standin for PyModuleDef *
@@ -59,12 +59,12 @@ def activate_mliappy(lmp):
    library = lmp.lib
    module_names = ["mliap_model_python_couple", "mliap_unified_couple"]
    api_version = library.lammps_python_api_version()

    for module_name in module_names:
        # Make Machinery
        loader = DynamicLoader(module_name,library,api_version)
        spec = importlib.util.spec_from_loader(module_name,loader)

        # Do the import
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
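For context, how the activation function above is typically driven from a user script, following the pattern used in the LAMMPS ML-IAP examples (the call site below is a sketch, not part of this diff):

from lammps import lammps
import lammps.mliap

lmp = lammps()
# activate_mliappy() must run before a mliappy pair style is defined:
# it imports mliap_model_python_couple and mliap_unified_couple from the
# LAMMPS shared library via the DynamicLoader above and registers them
# in sys.modules.
lammps.mliap.activate_mliappy(lmp)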
@@ -19,16 +19,16 @@ class MLIAPUnifiedLJ(MLIAPUnified):

    def compute_gradients(self, data):
        """Test compute_gradients."""

    def compute_descriptors(self, data):
        """Test compute_descriptors."""

    def compute_forces(self, data):
        """Test compute_forces."""
        eij, fij = self.compute_pair_ef(data)
        data.update_pair_energy(eij)
        data.update_pair_forces(fij)

    def compute_pair_ef(self, data):
        rij = data.rij

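The hunk truncates compute_pair_ef() after rij = data.rij. As a hedged, standalone sketch (not the repository code), a Lennard-Jones pair energy/force computation of this kind can look as follows; the epsilon and sigma defaults are illustrative only:

import numpy as np

def lj_pair_ef(rij, epsilon=1.0, sigma=1.0):
    # rij: (npairs, 3) array of neighbor displacement vectors (data.rij)
    r2inv = 1.0 / np.sum(rij * rij, axis=1)
    r6inv = (sigma * sigma * r2inv) ** 3
    # per-pair energies: 4*eps*[(sigma/r)^12 - (sigma/r)^6]
    eij = 4.0 * epsilon * r6inv * (r6inv - 1.0)
    # per-pair force vectors: 24*eps/r^2 * [2*(sigma/r)^12 - (sigma/r)^6] * rij
    fij = (24.0 * epsilon * r2inv * r6inv * (2.0 * r6inv - 1.0))[:, np.newaxis] * rij
    return eij, fij   # passed to data.update_pair_energy / data.update_pair_forces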
@@ -80,10 +80,10 @@ class TorchWrapper(torch.nn.Module):

    n_params : torch.nn.Module (None)
        Number of NN model parameters

    device : torch.nn.Module (None)
        Accelerator device

    dtype : torch.dtype (torch.float64)
        Dtype to use on device
    """
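A small, self-contained illustration of the device and dtype attributes documented above: the wrapped network and its descriptor inputs need to live on the same accelerator device and use the same dtype (torch.float64 being the documented default). The linear layer stands in for the wrapped model; TorchWrapper itself is not exercised here:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float64
net = torch.nn.Linear(8, 1).to(device=device, dtype=dtype)    # stand-in for the wrapped model
descriptors = torch.zeros(16, 8, device=device, dtype=dtype)  # per-atom descriptor batch
per_atom_energies = net(descriptors)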
@@ -325,6 +325,6 @@ class ElemwiseModels(torch.nn.Module):
        per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype)
        given_elems, elem_indices = torch.unique(elems, return_inverse=True)
        for i, elem in enumerate(given_elems):
            self.subnets[elem].to(self.dtype)
            per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
        return per_atom_attributes
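A hedged, standalone illustration of the dispatch pattern in the hunk above: torch.unique(..., return_inverse=True) groups atoms by element, each per-element subnet evaluates only its own rows, and the outputs are scattered back into one per-atom tensor. The two toy subnets are placeholders for self.subnets:

import torch

elems = torch.tensor([0, 1, 0, 1, 1])                  # per-atom element indices
descriptors = torch.randn(5, 4, dtype=torch.float64)   # per-atom descriptors
subnets = [torch.nn.Linear(4, 1).double() for _ in range(2)]

per_atom = torch.zeros(elems.size(dim=0), dtype=torch.float64)
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
with torch.no_grad():
    for i, elem in enumerate(given_elems):
        # evaluate the subnet for this element on its rows only,
        # then scatter the results back by the inverse indices
        per_atom[elem_indices == i] = subnets[int(elem)](descriptors[elem_indices == i]).flatten()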