Make PyTorch optional

Forrest Glines
2025-03-25 13:00:14 -06:00
parent 32e4a0d36b
commit d850d93dad

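The change below drops the unconditional module-level `import torch` from the Cython source and instead imports torch lazily inside the two halo-exchange methods, falling back to numpy-only dtype checks when the import fails, so `lammps.mliap` stays importable without PyTorch installed. A minimal pure-Python sketch of the pattern, assuming only that callers hand in numpy arrays or torch tensors (the function name is illustrative, not part of the LAMMPS API):

```python
import numpy as np

def element_kind(dtype):
    """Classify a buffer dtype as 'float' or 'double' without requiring torch."""
    try:
        import torch  # deferred: resolved only when this function runs
        if dtype == torch.float32 or dtype == np.float32:
            return "float"
        if dtype == torch.float64 or dtype == np.float64:
            return "double"
    except ModuleNotFoundError:
        # Torch not installed: only numpy dtypes can appear
        if dtype == np.float32:
            return "float"
        if dtype == np.float64:
            return "double"
    raise TypeError(f"Unsupported comms type: ({dtype})")
```

With torch installed, `element_kind(torch.float32)` and `element_kind(np.float32)` both return `"float"`; with torch absent, the numpy case still resolves, which is the point of the change.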

@@ -1,7 +1,6 @@
 # cython: language_level=3
 # distutils: language = c++
-import torch
 import pickle
 import numpy as np
 import lammps.mliap
@@ -191,8 +190,6 @@ cdef class MLIAPDataPy:
         self.update_pair_forces_gpu(fij)
     def forward_exchange(self, copy_from, copy_to, vec_len):
-        #(fglines): I copied this pattern from above -- why this duck typing?
-        #The except clause is for Pytorch tensors, what is the try clause for?
         cdef uintptr_t copy_from_ptr;
         try:
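The removed `(fglines)` comments above question the try/except duck typing that the (truncated) `try:` blocks use to pull a raw address out of `copy_from` and `copy_to`. The exact order in the original code is not shown here, but the usual shape of such a helper is: numpy arrays expose their base address as an attribute, torch tensors as a method, and an `AttributeError` steers between the two. A hedged sketch under those assumptions (the helper name is hypothetical):

```python
import numpy as np

def _buffer_address(buf):
    """Integer address of a buffer's first element, for numpy or torch.

    Sketch only: numpy ndarrays expose the address as buf.ctypes.data,
    while torch tensors expose it via buf.data_ptr().
    """
    try:
        return buf.ctypes.data   # numpy ndarray path
    except AttributeError:
        return buf.data_ptr()    # torch.Tensor path
```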
@@ -213,16 +210,24 @@ cdef class MLIAPDataPy:
         if copy_from_dtype != copy_to_dtype:
             raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
-        if copy_from_dtype == torch.float32:
-            self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
-        elif copy_from_dtype == torch.float64:
-            self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
-        else:
-            raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
+        try:
+            import torch
+            if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
+                self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
+            elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
+                self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
+            else:
+                raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
+        except ModuleNotFoundError:
+            #Torch not installed, just check numpy
+            if copy_from_dtype == np.float32:
+                self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
+            elif copy_from_dtype == np.float64:
+                self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
+            else:
+                raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
     def reverse_exchange(self, copy_from, copy_to, vec_len):
-        #(fglines): I copied this pattern from above -- why this duck typing?
-        #The except clause is for Pytorch tensors, what is the try clause for?
         cdef uintptr_t copy_from_ptr;
         try:
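`reverse_exchange` below gets the identical treatment as `forward_exchange`. One side effect of the pattern is that `import torch` is re-attempted on every exchange call: when torch is installed this is a cheap `sys.modules` lookup after the first import, but when it is absent every call repeats the failing search of `sys.path`. A hedged sketch of a variant that hoists the availability probe to module scope (not what this commit does):

```python
import numpy as np

# Probe for torch once at import time instead of on every exchange call.
try:
    import torch
    _HAVE_TORCH = True
except ModuleNotFoundError:
    _HAVE_TORCH = False

def _float32_dtypes():
    """Dtypes that should dispatch to the <float*> overload."""
    return (torch.float32, np.float32) if _HAVE_TORCH else (np.float32,)

def _float64_dtypes():
    """Dtypes that should dispatch to the <double*> overload."""
    return (torch.float64, np.float64) if _HAVE_TORCH else (np.float64,)
```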
@@ -243,12 +248,22 @@ cdef class MLIAPDataPy:
         if copy_from_dtype != copy_to_dtype:
             raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
-        if copy_from_dtype == torch.float32:
-            self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
-        elif copy_from_dtype == torch.float64:
-            self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
-        else:
-            raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
+        try:
+            import torch
+            if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
+                self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
+            elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
+                self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
+            else:
+                raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
+        except ModuleNotFoundError:
+            #Torch not installed, just check numpy
+            if copy_from_dtype == np.float32:
+                self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
+            elif copy_from_dtype == np.float64:
+                self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
+            else:
+                raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
     @property
     def f(self):