Ensure all subnets are proper dtype

This commit is contained in:
rohskopf
2022-09-06 08:40:43 -06:00
parent bd7a7d4f74
commit 341f5cf40d

View File

@ -322,24 +322,12 @@ class ElemwiseModels(torch.nn.Module):
Per atom attribute computed by the network model
"""
#print("^^^^^^")
#print(elems)
#print(elems.size(dim=0))
#print(descriptors.dtype)
self.dtype=dtype
self.to(self.dtype)
self.subnets[0].to(torch.float64)
self.subnets[1].to(torch.float64)
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=torch.float64)
#print("^^^^^ -----")
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype)
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
#print(per_atom_attributes.size())
#print(elem_indices.size())
for i, elem in enumerate(given_elems):
#print(descriptors.size())
#print(self.subnets[elem](descriptors).flatten().size())
self.subnets[elem].to(self.dtype)
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
return per_atom_attributes