Working MD with 2 atom types

This commit is contained in:
rohskopf
2022-09-06 08:21:01 -06:00
parent dcf6bca3ad
commit bd7a7d4f74

View File

@ -134,6 +134,9 @@ class TorchWrapper(torch.nn.Module):
descriptors = torch.from_numpy(descriptors).to(dtype=self.dtype, device=self.device).requires_grad_(True) descriptors = torch.from_numpy(descriptors).to(dtype=self.dtype, device=self.device).requires_grad_(True)
elems = torch.from_numpy(elems).to(dtype=torch.long, device=self.device) - 1 elems = torch.from_numpy(elems).to(dtype=torch.long, device=self.device) - 1
#print(self.model)
#print("ASDFASDF")
with torch.autograd.enable_grad(): with torch.autograd.enable_grad():
energy_nn = self.model(descriptors, elems) energy_nn = self.model(descriptors, elems)
@ -300,7 +303,7 @@ class ElemwiseModels(torch.nn.Module):
self.subnets = subnets self.subnets = subnets
self.n_types = n_types self.n_types = n_types
def forward(self, descriptors, elems): def forward(self, descriptors, elems, dtype=torch.float64):
""" """
Feeds descriptors to network model after adding zeros into Feeds descriptors to network model after adding zeros into
descriptor columns relating to different atom types descriptor columns relating to different atom types
@ -319,8 +322,24 @@ class ElemwiseModels(torch.nn.Module):
Per atom attribute computed by the network model Per atom attribute computed by the network model
""" """
per_atom_attributes = torch.zeros(elems.size[0]) #print("^^^^^^")
#print(elems)
#print(elems.size(dim=0))
#print(descriptors.dtype)
self.dtype=dtype
self.to(self.dtype)
self.subnets[0].to(torch.float64)
self.subnets[1].to(torch.float64)
per_atom_attributes = torch.zeros(elems.size(dim=0), dtype=torch.float64)
#print("^^^^^ -----")
given_elems, elem_indices = torch.unique(elems, return_inverse=True) given_elems, elem_indices = torch.unique(elems, return_inverse=True)
#print(per_atom_attributes.size())
#print(elem_indices.size())
for i, elem in enumerate(given_elems): for i, elem in enumerate(given_elems):
per_atom_attribute[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]) #print(descriptors.size())
#print(self.subnets[elem](descriptors).flatten().size())
per_atom_attributes[elem_indices == i] = self.subnets[elem](descriptors[elem_indices == i]).flatten()
return per_atom_attributes return per_atom_attributes