diff --git a/python/lammps/mliap/pytorch.py b/python/lammps/mliap/pytorch.py index 93df96d2e0..442494cb0c 100644 --- a/python/lammps/mliap/pytorch.py +++ b/python/lammps/mliap/pytorch.py @@ -18,6 +18,7 @@ import numpy as np import torch + def calc_n_params(model): """ Returns the sum of two decimal numbers in binary digits. @@ -132,22 +133,26 @@ class TorchWrapper(torch.nn.Module): descriptors = torch.as_tensor(descriptors,dtype=self.dtype, device=self.device).requires_grad_(True) elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device) elems=elems-1 + device = self.device + if (use_gpu_data and device == None and str(beta.device).find('CUDA') == 1): + device = 'cuda' #Override device as it wasn't defined in the model with torch.autograd.enable_grad(): if (use_gpu_data): - energy_nn = torch.as_tensor(energy,dtype=self.dtype, device=self.device) + energy_nn = torch.as_tensor(energy,dtype=self.dtype, device=device) energy_nn[:] = self.model(descriptors, elems).flatten() else: energy_nn = self.model(descriptors, elems).flatten() energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64) - + if (use_gpu_data): - beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=self.device) + beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=device) beta_nn[:] = torch.autograd.grad(energy_nn.sum(), descriptors)[0] else: beta_nn = torch.autograd.grad(energy_nn.sum(), descriptors)[0] beta[:] = beta_nn.detach().cpu().numpy().astype(np.float64) + class IgnoreElems(torch.nn.Module): """ A class to represent a NN model agnostic of element typing.