Clean up cupy vs. torch in exchanges

This commit is contained in:
Forrest Glines
2025-03-25 15:39:20 -06:00
parent d850d93dad
commit 3c055fe93b

View File

@ -8,6 +8,10 @@ try:
import cupy
except ImportError:
pass
try:
import torch
except ImportError:
pass
from libc.stdint cimport uintptr_t
cimport cython
@ -190,36 +194,29 @@ cdef class MLIAPDataPy:
self.update_pair_forces_gpu(fij)
def forward_exchange(self, copy_from, copy_to, vec_len):
cdef uintptr_t copy_from_ptr, copy_to_ptr;
cdef uintptr_t copy_from_ptr;
try:
copy_from_ptr = copy_from.data.ptr
copy_from_dtype = copy_from.data.dtype
except:
copy_from_ptr = copy_from.data_ptr()
copy_from_dtype = copy_from.dtype
cdef uintptr_t copy_to_ptr;
try:
copy_to_ptr = copy_to.data.ptr
copy_to_dtype = copy_to.data.dtype
except:
copy_to_ptr = copy_to.data_ptr()
copy_to_dtype = copy_to.dtype
copy_from_dtype = copy_from.dtype
copy_to_dtype = copy_to.dtype
if copy_from_dtype != copy_to_dtype:
raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
try:
import torch
if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
#Attempt assuming PyTorch data
copy_from_ptr = copy_from.data_ptr()
copy_to_ptr = copy_to.data_ptr()
if copy_from_dtype == torch.float32:
self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
elif copy_from_dtype == torch.float64:
self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
else:
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
except ModuleNotFoundError:
#Torch not installed, just check numpy
except:
#Attempt assuming Numpy data
copy_from_ptr = copy_from.data.ptr
copy_to_ptr = copy_to.data.ptr
if copy_from_dtype == np.float32:
self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
elif copy_from_dtype == np.float64:
@ -228,36 +225,29 @@ cdef class MLIAPDataPy:
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
def reverse_exchange(self, copy_from, copy_to, vec_len):
cdef uintptr_t copy_from_ptr;
try:
copy_from_ptr = copy_from.data.ptr
copy_from_dtype = copy_from.data.dtype
except:
copy_from_ptr = copy_from.data_ptr()
copy_from_dtype = copy_from.dtype
cdef uintptr_t copy_to_ptr;
try:
copy_to_ptr = copy_to.data.ptr
copy_to_dtype = copy_to.data.dtype
except:
copy_to_ptr = copy_to.data_ptr()
copy_to_dtype = copy_to.dtype
cdef uintptr_t copy_from_ptr, copy_to_ptr;
copy_from_dtype = copy_from.dtype
copy_to_dtype = copy_to.dtype
if copy_from_dtype != copy_to_dtype:
raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
try:
import torch
if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
#Attempt assuming PyTorch data
copy_from_ptr = copy_from.data_ptr()
copy_to_ptr = copy_to.data_ptr()
if copy_from_dtype == torch.float32:
self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
elif copy_from_dtype == torch.float64:
self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
else:
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
except ModuleNotFoundError:
#Torch not installed, just check numpy
except:
#Attempt assuming Numpy data
copy_from_ptr = copy_from.data.ptr
copy_to_ptr = copy_to.data.ptr
if copy_from_dtype == np.float32:
self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
elif copy_from_dtype == np.float64: