Ensure all subnets are proper dtype
This commit is contained in:
@ -322,24 +322,12 @@ class ElemwiseModels(torch.nn.Module):
|
||||
Per atom attribute computed by the network model
|
||||
"""
|
||||
|
||||
#print("^^^^^^")
|
||||
#print(elems)
|
||||
#print(elems.size(dim=0))
|
||||
#print(descriptors.dtype)
|
||||
|
||||
self.dtype=dtype
|
||||
self.to(self.dtype)
|
||||
self.subnets[0].to(torch.float64)
|
||||
self.subnets[1].to(torch.float64)
|
||||
|
||||
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=torch.float64)
|
||||
#print("^^^^^ -----")
|
||||
|
||||
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype)
|
||||
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
|
||||
#print(per_atom_attributes.size())
|
||||
#print(elem_indices.size())
|
||||
for i, elem in enumerate(given_elems):
|
||||
#print(descriptors.size())
|
||||
#print(self.subnets[elem](descriptors).flatten().size())
|
||||
self.subnets[elem].to(self.dtype)
|
||||
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
|
||||
return per_atom_attributes
|
||||
|
||||
Reference in New Issue
Block a user