whitespace

This commit is contained in:
Axel Kohlmeyer
2022-09-24 15:34:41 -04:00
parent a02ab6eaa1
commit 59ca352e48
4 changed files with 12 additions and 12 deletions

View File

@ -80,10 +80,10 @@ class TorchWrapper(torch.nn.Module):
n_params : torch.nn.Module (None)
Number of NN model parameters
device : torch.nn.Module (None)
Accelerator device
dtype : torch.dtype (torch.float64)
Dtype to use on device
"""
@ -325,6 +325,6 @@ class ElemwiseModels(torch.nn.Module):
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype)
given_elems, elem_indices = torch.unique(elems, return_inverse=True)
for i, elem in enumerate(given_elems):
self.subnets[elem].to(self.dtype)
self.subnets[elem].to(self.dtype)
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
return per_atom_attributes