From d850d93dad6ecd3a0f1fd8f449f32d9ce862bb4d Mon Sep 17 00:00:00 2001 From: Forrest Glines Date: Tue, 25 Mar 2025 13:00:14 -0600 Subject: [PATCH] Make Pytorch optional --- src/KOKKOS/mliap_unified_couple_kokkos.pyx | 49 ++++++++++++++-------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/src/KOKKOS/mliap_unified_couple_kokkos.pyx b/src/KOKKOS/mliap_unified_couple_kokkos.pyx index fac569ac08..e375deef44 100644 --- a/src/KOKKOS/mliap_unified_couple_kokkos.pyx +++ b/src/KOKKOS/mliap_unified_couple_kokkos.pyx @@ -1,7 +1,6 @@ # cython: language_level=3 # distutils: language = c++ -import torch import pickle import numpy as np import lammps.mliap @@ -191,8 +190,6 @@ cdef class MLIAPDataPy: self.update_pair_forces_gpu(fij) def forward_exchange(self, copy_from, copy_to, vec_len): - #(fglines): I copied this pattern from above -- why this duck typing? - #The except clause is for Pytorch tensors, what is the try clause for? cdef uintptr_t copy_from_ptr; try: @@ -213,16 +210,24 @@ cdef class MLIAPDataPy: if copy_from_dtype != copy_to_dtype: raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch") - if copy_from_dtype == torch.float32: - self.data.forward_exchange( copy_from_ptr, copy_to_ptr, vec_len) - elif copy_from_dtype == torch.float64: - self.data.forward_exchange( copy_from_ptr, copy_to_ptr, vec_len) - else: - raise TypeError(f"Unsupported comms type: ({copy_from_dtype})") + try: + import torch + if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32: + self.data.forward_exchange( copy_from_ptr, copy_to_ptr, vec_len) + elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64: + self.data.forward_exchange( copy_from_ptr, copy_to_ptr, vec_len) + else: + raise TypeError(f"Unsupported comms type: ({copy_from_dtype})") + except ModuleNotFoundError: + #Torch not installed, just check numpy + if copy_from_dtype == np.float32: + self.data.forward_exchange( copy_from_ptr, copy_to_ptr, vec_len) + elif copy_from_dtype == np.float64: + self.data.forward_exchange( copy_from_ptr, copy_to_ptr, vec_len) + else: + raise TypeError(f"Unsupported comms type: ({copy_from_dtype})") def reverse_exchange(self, copy_from, copy_to, vec_len): - #(fglines): I copied this pattern from above -- why this duck typing? - #The except clause is for Pytorch tensors, what is the try clause for? cdef uintptr_t copy_from_ptr; try: @@ -243,12 +248,22 @@ cdef class MLIAPDataPy: if copy_from_dtype != copy_to_dtype: raise TypeError(f"Types of ({copy_from_dtype})copy_from and ({copy_to_dtype})copy_to mismatch") - if copy_from_dtype == torch.float32: - self.data.reverse_exchange( copy_from_ptr, copy_to_ptr, vec_len) - elif copy_from_dtype == torch.float64: - self.data.reverse_exchange( copy_from_ptr, copy_to_ptr, vec_len) - else: - raise TypeError(f"Unsupported comms type: ({copy_from_dtype})") + try: + import torch + if copy_from_dtype == torch.float32 or copy_from_dtype == np.float32: + self.data.reverse_exchange( copy_from_ptr, copy_to_ptr, vec_len) + elif copy_from_dtype == torch.float64 or copy_from_dtype == np.float64: + self.data.reverse_exchange( copy_from_ptr, copy_to_ptr, vec_len) + else: + raise TypeError(f"Unsupported comms type: ({copy_from_dtype})") + except ModuleNotFoundError: + #Torch not installed, just check numpy + if copy_from_dtype == np.float32: + self.data.reverse_exchange( copy_from_ptr, copy_to_ptr, vec_len) + elif copy_from_dtype == np.float64: + self.data.reverse_exchange( copy_from_ptr, copy_to_ptr, vec_len) + else: + raise TypeError(f"Unsupported comms type: ({copy_from_dtype})") @property def f(self):