Working MD with 2 atom types

This commit is contained in:
rohskopf
2022-09-06 08:21:01 -06:00
parent dcf6bca3ad
commit bd7a7d4f74

View File

@ -134,6 +134,9 @@ class TorchWrapper(torch.nn.Module):
descriptors = torch.from_numpy(descriptors).to(dtype=self.dtype, device=self.device).requires_grad_(True)
elems = torch.from_numpy(elems).to(dtype=torch.long, device=self.device) - 1
#print(self.model)
#print("ASDFASDF")
with torch.autograd.enable_grad():
energy_nn = self.model(descriptors, elems)
@ -300,7 +303,7 @@ class ElemwiseModels(torch.nn.Module):
self.subnets = subnets
self.n_types = n_types
def forward(self, descriptors, elems):
def forward(self, descriptors, elems, dtype=torch.float64):
"""
Feeds descriptors to network model after adding zeros into
descriptor columns relating to different atom types
@ -319,8 +322,24 @@ class ElemwiseModels(torch.nn.Module):
Per atom attribute computed by the network model
"""
per_atom_attributes = torch.zeros(elems.size[0])
#print("^^^^^^")
#print(elems)
#print(elems.size(dim=0))
#print(descriptors.dtype)
self.dtype=dtype
self.to(self.dtype)
self.subnets[0].to(torch.float64)
self.subnets[1].to(torch.float64)
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=torch.float64)
#print("^^^^^ -----")
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
#print(per_atom_attributes.size())
#print(elem_indices.size())
for i, elem in enumerate(given_elems):
per_atom_attribute[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i])
#print(descriptors.size())
#print(self.subnets[elem](descriptors).flatten().size())
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
return per_atom_attributes