address CodeQL issue

This commit is contained in:
Axel Kohlmeyer
2023-08-23 03:03:19 -04:00
parent 351a9dd11f
commit 9999f775cc

View File

@ -133,7 +133,7 @@ class TorchWrapper(torch.nn.Module):
elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device)
elems=elems-1
device = self.device
if (use_gpu_data and device == None and str(beta.device).find('CUDA') == 1):
if (use_gpu_data and (device is None) and (str(beta.device).find('CUDA') == 1)):
device = 'cuda' #Override device as it wasn't defined in the model
with torch.autograd.enable_grad():