whitespace
This commit is contained in:
@ -80,10 +80,10 @@ class TorchWrapper(torch.nn.Module):
|
||||
|
||||
n_params : torch.nn.Module (None)
|
||||
Number of NN model parameters
|
||||
|
||||
|
||||
device : torch.nn.Module (None)
|
||||
Accelerator device
|
||||
|
||||
|
||||
dtype : torch.dtype (torch.float64)
|
||||
Dtype to use on device
|
||||
"""
|
||||
@ -325,6 +325,6 @@ class ElemwiseModels(torch.nn.Module):
|
||||
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype)
|
||||
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
|
||||
for i, elem in enumerate(given_elems):
|
||||
self.subnets[elem].to(self.dtype)
|
||||
self.subnets[elem].to(self.dtype)
|
||||
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
|
||||
return per_atom_attributes
|
||||
|
||||
Reference in New Issue
Block a user