Make Pytorch optional
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
# cython: language_level=3
|
||||
# distutils: language = c++
|
||||
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
import lammps.mliap
|
||||
@ -191,8 +190,6 @@ cdef class MLIAPDataPy:
|
||||
self.update_pair_forces_gpu(fij)
|
||||
|
||||
def forward_exchange(self, copy_from, copy_to, vec_len):
|
||||
#(fglines): I copied this pattern from above -- why this duck typing?
|
||||
#The except clause is for Pytorch tensors, what is the try clause for?
|
||||
|
||||
cdef uintptr_t copy_from_ptr;
|
||||
try:
|
||||
@ -213,16 +210,24 @@ cdef class MLIAPDataPy:
|
||||
if copy_from_dtype != copy_to_dtype:
|
||||
raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
|
||||
|
||||
if copy_from_dtype == torch.float32:
|
||||
self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == torch.float64:
|
||||
self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
try:
|
||||
import torch
|
||||
if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
|
||||
self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
|
||||
self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
except ModuleNotFoundError:
|
||||
#Torch not installed, just check numpy
|
||||
if copy_from_dtype == np.float32:
|
||||
self.data.forward_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == np.float64:
|
||||
self.data.forward_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
|
||||
def reverse_exchange(self, copy_from, copy_to, vec_len):
|
||||
#(fglines): I copied this pattern from above -- why this duck typing?
|
||||
#The except clause is for Pytorch tensors, what is the try clause for?
|
||||
|
||||
cdef uintptr_t copy_from_ptr;
|
||||
try:
|
||||
@ -243,12 +248,22 @@ cdef class MLIAPDataPy:
|
||||
if copy_from_dtype != copy_to_dtype:
|
||||
raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch")
|
||||
|
||||
if copy_from_dtype == torch.float32:
|
||||
self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == torch.float64:
|
||||
self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
try:
|
||||
import torch
|
||||
if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32:
|
||||
self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64:
|
||||
self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
except ModuleNotFoundError:
|
||||
#Torch not installed, just check numpy
|
||||
if copy_from_dtype == np.float32:
|
||||
self.data.reverse_exchange( <float*>copy_from_ptr, <float*>copy_to_ptr, vec_len)
|
||||
elif copy_from_dtype == np.float64:
|
||||
self.data.reverse_exchange( <double*>copy_from_ptr, <double*>copy_to_ptr, vec_len)
|
||||
else:
|
||||
raise TypeError(f"Unsupported comms type: ({copy_from_dtype})")
|
||||
|
||||
@property
|
||||
def f(self):
|
||||
|
||||
Reference in New Issue
Block a user