address CodeQL issue
This commit is contained in:
@ -133,7 +133,7 @@ class TorchWrapper(torch.nn.Module):
|
||||
elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device)
|
||||
elems=elems-1
|
||||
device = self.device
|
||||
if (use_gpu_data and device == None and str(beta.device).find('CUDA') == 1):
|
||||
if (use_gpu_data and (device is None) and (str(beta.device).find('CUDA') == 1)):
|
||||
device = 'cuda' #Override device as it wasn't defined in the model
|
||||
with torch.autograd.enable_grad():
|
||||
|
||||
|
||||
Reference in New Issue
Block a user