Made check incase PyTorch didn't define the device
This commit is contained in:
@ -18,6 +18,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def calc_n_params(model):
|
def calc_n_params(model):
|
||||||
"""
|
"""
|
||||||
Returns the sum of two decimal numbers in binary digits.
|
Returns the sum of two decimal numbers in binary digits.
|
||||||
@ -132,22 +133,26 @@ class TorchWrapper(torch.nn.Module):
|
|||||||
descriptors = torch.as_tensor(descriptors,dtype=self.dtype, device=self.device).requires_grad_(True)
|
descriptors = torch.as_tensor(descriptors,dtype=self.dtype, device=self.device).requires_grad_(True)
|
||||||
elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device)
|
elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device)
|
||||||
elems=elems-1
|
elems=elems-1
|
||||||
|
device = self.device
|
||||||
|
if (use_gpu_data and device == None and str(beta.device).find('CUDA') == 1):
|
||||||
|
device = 'cuda' #Override device as it wasn't defined in the model
|
||||||
with torch.autograd.enable_grad():
|
with torch.autograd.enable_grad():
|
||||||
|
|
||||||
if (use_gpu_data):
|
if (use_gpu_data):
|
||||||
energy_nn = torch.as_tensor(energy,dtype=self.dtype, device=self.device)
|
energy_nn = torch.as_tensor(energy,dtype=self.dtype, device=device)
|
||||||
energy_nn[:] = self.model(descriptors, elems).flatten()
|
energy_nn[:] = self.model(descriptors, elems).flatten()
|
||||||
else:
|
else:
|
||||||
energy_nn = self.model(descriptors, elems).flatten()
|
energy_nn = self.model(descriptors, elems).flatten()
|
||||||
energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64)
|
energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64)
|
||||||
|
|
||||||
if (use_gpu_data):
|
if (use_gpu_data):
|
||||||
beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=self.device)
|
beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=device)
|
||||||
beta_nn[:] = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
|
beta_nn[:] = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
|
||||||
else:
|
else:
|
||||||
beta_nn = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
|
beta_nn = torch.autograd.grad(energy_nn.sum(), descriptors)[0]
|
||||||
beta[:] = beta_nn.detach().cpu().numpy().astype(np.float64)
|
beta[:] = beta_nn.detach().cpu().numpy().astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
class IgnoreElems(torch.nn.Module):
|
class IgnoreElems(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A class to represent a NN model agnostic of element typing.
|
A class to represent a NN model agnostic of element typing.
|
||||||
|
|||||||
Reference in New Issue
Block a user