Added a check in case PyTorch didn't define the device

Matt Bettencourt
2023-08-14 09:36:56 +02:00
parent b6f7a27b09
commit cdbbe33933

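The change below adds a fallback for the case where the wrapped model never defined a device: if GPU-resident data is passed in while self.device is still None, the wrapper assumes 'cuda' instead of handing device=None to torch.as_tensor. A minimal standalone sketch of that fallback (the helper name resolve_device is hypothetical, and the substring test is a slightly looser form of the diff's find('CUDA') == 1 check):

def resolve_device(model_device, beta, use_gpu_data):
    # Start from whatever device the model declared; this may be None.
    device = model_device
    # If GPU-resident buffers were handed in but the model never set a device,
    # fall back to CUDA; a cupy array's device prints as '<CUDA Device 0>'.
    if use_gpu_data and device is None and 'CUDA' in str(beta.device):
        device = 'cuda'
    return device

With the device resolved up front, the use_gpu_data branches below can allocate energy_nn and beta_nn with torch.as_tensor(..., device=device) rather than self.device.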

@@ -18,6 +18,7 @@
 import numpy as np
 import torch
 def calc_n_params(model):
     """
     Returns the total number of parameters in a model.
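The body of calc_n_params is not visible in this hunk; a minimal sketch consistent with that docstring (an assumption, not necessarily the repository's exact implementation) is:

import torch

def calc_n_params(model):
    # Total element count over all parameter tensors of the model.
    return sum(p.nelement() for p in model.parameters())

# e.g. a Linear(10 -> 4) layer has 10*4 weights + 4 biases = 44 parameters
assert calc_n_params(torch.nn.Linear(10, 4)) == 44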
@@ -132,22 +133,26 @@ class TorchWrapper(torch.nn.Module):
         descriptors = torch.as_tensor(descriptors,dtype=self.dtype, device=self.device).requires_grad_(True)
         elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device)
         elems=elems-1
+        device = self.device
+        if (use_gpu_data and device == None and str(beta.device).find('CUDA') == 1):
+            device = 'cuda' #Override device as it wasn't defined in the model
         with torch.autograd.enable_grad():
             if (use_gpu_data):
-                energy_nn = torch.as_tensor(energy,dtype=self.dtype, device=self.device)
+                energy_nn = torch.as_tensor(energy,dtype=self.dtype, device=device)
                 energy_nn[:] = self.model(descriptors, elems).flatten()
             else:
                 energy_nn = self.model(descriptors, elems).flatten()
                 energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64)
             if (use_gpu_data):
-                beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=self.device)
+                beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=device)
                 beta_nn[:] = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
             else:
                 beta_nn = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
                 beta[:] = beta_nn.detach().cpu().numpy().astype(np.float64)

 class IgnoreElems(torch.nn.Module):
     """
     A class to represent a NN model agnostic of element typing.