diff --git a/python/lammps/mliap/pytorch.py b/python/lammps/mliap/pytorch.py
index f7c3d05c76..71ce83d640 100644
--- a/python/lammps/mliap/pytorch.py
+++ b/python/lammps/mliap/pytorch.py
@@ -133,7 +133,7 @@ class TorchWrapper(torch.nn.Module):
         elems = torch.as_tensor(elems,dtype=torch.int32, device=self.device)
         elems=elems-1
         device = self.device
-        if (use_gpu_data and device == None and str(beta.device).find('CUDA') == 1):
+        if (use_gpu_data and (device is None) and (str(beta.device).find('CUDA') == 1)):
             device = 'cuda' #Override device as it wasn't defined in the model
         with torch.autograd.enable_grad():
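
For reference, a minimal standalone sketch of the device-override pattern this hunk touches, assuming a hypothetical resolve_device() helper, a torch tensor beta, and a use_gpu_data flag as in the surrounding method; it tests beta.device.type rather than the string find() used in the patch, which is one common way to detect a CUDA-resident tensor:

    import torch

    def resolve_device(beta, use_gpu_data, device=None):
        # If no device was configured but the incoming tensor already lives
        # on the GPU, fall back to 'cuda' (mirrors the override in the hunk
        # above). 'device is None' is the identity check the patch adopts.
        if use_gpu_data and device is None and beta.device.type == 'cuda':
            device = 'cuda'  # override: device wasn't defined in the model
        return device

    # Usage (hypothetical): a CPU tensor leaves device unset; a CUDA tensor
    # triggers the override when use_gpu_data is True.
    print(resolve_device(torch.zeros(4), use_gpu_data=True))        # None
    if torch.cuda.is_available():
        print(resolve_device(torch.zeros(4, device='cuda'), True))  # 'cuda'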