diff --git a/python/lammps/mliap/pytorch.py b/python/lammps/mliap/pytorch.py index 04602926ff..9753c4bcd8 100644 --- a/python/lammps/mliap/pytorch.py +++ b/python/lammps/mliap/pytorch.py @@ -322,24 +322,12 @@ class ElemwiseModels(torch.nn.Module): Per atom attribute computed by the network model """ - #print("^^^^^^") - #print(elems) - #print(elems.size(dim=0)) - #print(descriptors.dtype) - self.dtype=dtype self.to(self.dtype) - self.subnets[0].to(torch.float64) - self.subnets[1].to(torch.float64) - - per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=torch.float64) - #print("^^^^^ -----") + per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=self.dtype) given_elems, elem_indices = torch.unique(elems, return_inverse=True) - #print(per_atom_attributes.size()) - #print(elem_indices.size()) for i, elem in enumerate(given_elems): - #print(descriptors.size()) - #print(self.subnets[elem](descriptors).flatten().size()) + self.subnets[elem].to(self.dtype) per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten() return per_atom_attributes