Clean up cupy vs. torch in exchanges
This commit is contained in:
@ -8,6 +8,10 @@ try:
|
||||
import cupy
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass
|
||||
from libc.stdint cimport uintptr_t
|
||||
|
||||
cimport cython
|
||||
@ -190,36 +194,29 @@ cdef class MLIAPDataPy:
|
||||
self.update_pair_forces_gpu(fij)
|
||||
|
||||
def forward_exchange(self, copy_from, copy_to, vec_len):
|
||||
cdef uintptr_t copy_from_ptr, copy_to_ptr;
|
||||
|
||||
cdef uintptr_t copy_from_ptr;
|
||||
try:
|
||||
copy_from_ptr = copy_from.data.ptr
|
||||
copy_from_dtype = copy_from.data.dtype
|
||||
except:
|
||||
copy_from_ptr = copy_from.data_ptr()
|
||||
copy_from_dtype = copy_from.dtype
|
||||
|
||||
cdef uintptr_t copy_to_ptr;
|
||||
try:
|
||||
copy_to_ptr = copy_to.data.ptr
|
||||
copy_to_dtype = copy_to.data.dtype
|
||||
except:
|
||||
copy_to_ptr = copy_to.data_ptr()
|
||||
copy_to_dtype = copy_to.dtype
|
||||
|
||||
copy_from_dtype = copy_from.dtype
|
||||
copy_to_dtype = copy_to.dtype
|
||||
if copy_from_dtype != copy_to_dtype:
|
||||
raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
|
||||
|
||||
try:
|
||||
import torch
|
||||
if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
|
||||
#Attempt assuming PyTorch data
|
||||
copy_from_ptr = copy_from.data_ptr()
|
||||
copy_to_ptr = copy_to.data_ptr()
|
||||
|
||||
if copy_from_dtype == torch.float32:
|
||||
self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
|
||||
elif copy_from_dtype == torch.float64:
|
||||
self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
except ModuleNotFoundError:
|
||||
#Torch not installed, just check numpy
|
||||
except:
|
||||
#Attempt assuming Numpy data
|
||||
copy_from_ptr = copy_from.data.ptr
|
||||
copy_to_ptr = copy_to.data.ptr
|
||||
|
||||
if copy_from_dtype == np.float32:
|
||||
self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == np.float64:
|
||||
@ -228,36 +225,29 @@ cdef class MLIAPDataPy:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
|
||||
def reverse_exchange(self, copy_from, copy_to, vec_len):
|
||||
|
||||
cdef uintptr_t copy_from_ptr;
|
||||
try:
|
||||
copy_from_ptr = copy_from.data.ptr
|
||||
copy_from_dtype = copy_from.data.dtype
|
||||
except:
|
||||
copy_from_ptr = copy_from.data_ptr()
|
||||
copy_from_dtype = copy_from.dtype
|
||||
|
||||
cdef uintptr_t copy_to_ptr;
|
||||
try:
|
||||
copy_to_ptr = copy_to.data.ptr
|
||||
copy_to_dtype = copy_to.data.dtype
|
||||
except:
|
||||
copy_to_ptr = copy_to.data_ptr()
|
||||
copy_to_dtype = copy_to.dtype
|
||||
cdef uintptr_t copy_from_ptr, copy_to_ptr;
|
||||
|
||||
copy_from_dtype = copy_from.dtype
|
||||
copy_to_dtype = copy_to.dtype
|
||||
if copy_from_dtype != copy_to_dtype:
|
||||
raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
|
||||
|
||||
try:
|
||||
import torch
|
||||
if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
|
||||
#Attempt assuming PyTorch data
|
||||
copy_from_ptr = copy_from.data_ptr()
|
||||
copy_to_ptr = copy_to.data_ptr()
|
||||
|
||||
if copy_from_dtype == torch.float32:
|
||||
self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
|
||||
elif copy_from_dtype == torch.float64:
|
||||
self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
except ModuleNotFoundError:
|
||||
#Torch not installed, just check numpy
|
||||
except:
|
||||
#Attempt assuming Numpy data
|
||||
copy_from_ptr = copy_from.data.ptr
|
||||
copy_to_ptr = copy_to.data.ptr
|
||||
|
||||
if copy_from_dtype == np.float32:
|
||||
self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == np.float64:
|
||||
|
||||
Reference in New Issue
Block a user