From bd7a7d4f74714a79c822c5fe1f4181599aa06df1 Mon Sep 17 00:00:00 2001 From: rohskopf Date: Tue, 6 Sep 2022 08:21:01 -0600 Subject: [PATCH] Working MD with 2 atom types --- python/lammps/mliap/pytorch.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/lammps/mliap/pytorch.py b/python/lammps/mliap/pytorch.py index 9aa2da80f4..04602926ff 100644 --- a/python/lammps/mliap/pytorch.py +++ b/python/lammps/mliap/pytorch.py @@ -134,6 +134,9 @@ class TorchWrapper(torch.nn.Module): descriptors = torch.from_numpy(descriptors).to(dtype=self.dtype, device=self.device).requires_grad_(True) elems = torch.from_numpy(elems).to(dtype=torch.long, device=self.device) - 1 + #print(self.model) + #print("ASDFASDF") + with torch.autograd.enable_grad(): energy_nn = self.model(descriptors, elems) @@ -300,7 +303,7 @@ class ElemwiseModels(torch.nn.Module): self.subnets = subnets self.n_types = n_types - def forward(self, descriptors, elems): + def forward(self, descriptors, elems, dtype=torch.float64): """ Feeds descriptors to network model after adding zeros into descriptor columns relating to different atom types @@ -319,8 +322,24 @@ class ElemwiseModels(torch.nn.Module): Per atom attribute computed by the network model """ - per_atom_attributes = torch.zeros(elems.size[0]) + #print("^^^^^^") + #print(elems) + #print(elems.size(dim=0)) + #print(descriptors.dtype) + + self.dtype=dtype + self.to(self.dtype) + self.subnets[0].to(torch.float64) + self.subnets[1].to(torch.float64) + + per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=torch.float64) + #print("^^^^^ -----") + given_elems, elem_indices = torch.unique(elems, return_inverse=True) + #print(per_atom_attributes.size()) + #print(elem_indices.size()) for i, elem in enumerate(given_elems): - per_atom_attribute[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]) + #print(descriptors.size()) + #print(self.subnets[elem](descriptors).flatten().size()) + per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten() return per_atom_attributes