whitespace

This commit is contained in:
Axel Kohlmeyer
2023-08-18 06:40:51 -04:00
parent cdbbe33933
commit 2af8842877

View File

@ -18,7 +18,6 @@
import numpy as np import numpy as np
import torch import torch
def calc_n_params(model): def calc_n_params(model):
""" """
Returns the sum of two decimal numbers in binary digits. Returns the sum of two decimal numbers in binary digits.
@ -144,7 +143,7 @@ class TorchWrapper(torch.nn.Module):
else: else:
energy_nn = self.model(descriptors, elems).flatten() energy_nn = self.model(descriptors, elems).flatten()
energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64) energy[:] = energy_nn.detach().cpu().numpy().astype(np.float64)
if (use_gpu_data): if (use_gpu_data):
beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=device) beta_nn = torch.as_tensor(beta,dtype=self.dtype, device=device)
beta_nn[:] = torch.autograd.grad(energy_nn.sum(), descriptors)[0] beta_nn[:] = torch.autograd.grad(energy_nn.sum(), descriptors)[0]