Working MD with 2 atom types
This commit is contained in:
@ -134,6 +134,9 @@ class TorchWrapper(torch.nn.Module):
|
|||||||
descriptors = torch.from_numpy(descriptors).to(dtype=self.dtype, device=self.device).requires_grad_(True)
|
descriptors = torch.from_numpy(descriptors).to(dtype=self.dtype, device=self.device).requires_grad_(True)
|
||||||
elems = torch.from_numpy(elems).to(dtype=torch.long, device=self.device) - 1
|
elems = torch.from_numpy(elems).to(dtype=torch.long, device=self.device) - 1
|
||||||
|
|
||||||
|
#print(self.model)
|
||||||
|
#print("ASDFASDF")
|
||||||
|
|
||||||
with torch.autograd.enable_grad():
|
with torch.autograd.enable_grad():
|
||||||
|
|
||||||
energy_nn = self.model(descriptors, elems)
|
energy_nn = self.model(descriptors, elems)
|
||||||
@ -300,7 +303,7 @@ class ElemwiseModels(torch.nn.Module):
|
|||||||
self.subnets = subnets
|
self.subnets = subnets
|
||||||
self.n_types = n_types
|
self.n_types = n_types
|
||||||
|
|
||||||
def forward(self, descriptors, elems):
|
def forward(self, descriptors, elems, dtype=torch.float64):
|
||||||
"""
|
"""
|
||||||
Feeds descriptors to network model after adding zeros into
|
Feeds descriptors to network model after adding zeros into
|
||||||
descriptor columns relating to different atom types
|
descriptor columns relating to different atom types
|
||||||
@ -319,8 +322,24 @@ class ElemwiseModels(torch.nn.Module):
|
|||||||
Per atom attribute computed by the network model
|
Per atom attribute computed by the network model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
per_atom_attributes = torch.zeros(elems.size[0])
|
#print("^^^^^^")
|
||||||
|
#print(elems)
|
||||||
|
#print(elems.size(dim=0))
|
||||||
|
#print(descriptors.dtype)
|
||||||
|
|
||||||
|
self.dtype=dtype
|
||||||
|
self.to(self.dtype)
|
||||||
|
self.subnets[0].to(torch.float64)
|
||||||
|
self.subnets[1].to(torch.float64)
|
||||||
|
|
||||||
|
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=torch.float64)
|
||||||
|
#print("^^^^^ -----")
|
||||||
|
|
||||||
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
|
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
|
||||||
|
#print(per_atom_attributes.size())
|
||||||
|
#print(elem_indices.size())
|
||||||
for i, elem in enumerate(given_elems):
|
for i, elem in enumerate(given_elems):
|
||||||
per_atom_attribute[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i])
|
#print(descriptors.size())
|
||||||
|
#print(self.subnets[elem](descriptors).flatten().size())
|
||||||
|
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
|
||||||
return per_atom_attributes
|
return per_atom_attributes
|
||||||
|
|||||||
Reference in New Issue
Block a user