This commit is contained in:
cjknight
2024-08-30 14:05:36 -05:00
parent 0c753d92ba
commit d9e6dff93b
3 changed files with 10 additions and 13 deletions

View File

@ -155,14 +155,12 @@ public:
KOKKOS_INLINE_FUNCTION
void operator() (const int &i) const {
#if defined(FFT_KOKKOS_FFTW3) || defined(FFT_KOKKOS_CUFFT) || defined(FFT_KOKKOS_HIPFFT)
#if defined(FFT_KOKKOS_FFTW3) || defined(FFT_KOKKOS_CUFFT) || defined(FFT_KOKKOS_HIPFFT) || defined(FFT_KOKKOS_MKL_GPU)
FFT_SCALAR* out_ptr = (FFT_SCALAR *)(d_out.data()+i);
*(out_ptr++) *= norm;
*(out_ptr++) *= norm;
#elif defined(FFT_KOKKOS_MKL)
d_out(i) *= norm;
#elif defined(FFT_KOKKOS_MKL_GPU)
d_out(i) *= norm;
#else // FFT_KOKKOS_KISS
d_out(i).re *= norm;
d_out(i).im *= norm;
@ -635,24 +633,21 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
sycl::queue queue = LMPDeviceType().sycl_queue(); // is this the correct queue?
plan->desc_fast = new descriptor_t (nfast);
plan->desc_fast->set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (FFT_SCALAR)(1.0/nfast));
plan->desc_fast->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, plan->total1/nfast);
plan->desc_fast->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nfast);
plan->desc_fast->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nfast);
plan->desc_fast->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, plan->length1);
plan->desc_fast->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, plan->length1);
plan->desc_fast->commit(queue);
plan->desc_mid = new descriptor_t (nmid);
plan->desc_mid->set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (FFT_SCALAR)(1.0/nmid));
plan->desc_mid->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, plan->total2/nmid);
plan->desc_mid->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nmid);
plan->desc_mid->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nmid);
plan->desc_mid->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, plan->length2);
plan->desc_mid->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, plan->length2);
plan->desc_mid->commit(queue);
plan->desc_slow = new descriptor_t (nslow);
plan->desc_slow->set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (FFT_SCALAR)(1.0/nslow));
plan->desc_slow->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, plan->total3/nslow);
plan->desc_slow->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nslow);
plan->desc_slow->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nslow);
plan->desc_slow->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, plan->length3);
plan->desc_slow->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, plan->length3);
plan->desc_slow->commit(queue);
#elif defined(FFT_KOKKOS_MKL)

View File

@ -21,11 +21,13 @@
namespace LAMMPS_NS {
#if defined(FFT_KOKKOS_MKL_GPU)
#ifdef FFT_SINGLE
typedef oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::COMPLEX> descriptor_t;
#else
typedef oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::COMPLEX> descriptor_t;
#endif
#endif
// -------------------------------------------------------------------------

View File

@ -110,7 +110,7 @@
#if defined(FFT_KOKKOS_MKL_GPU)
#include "CL/sycl.hpp"
#include "oneapi/mkl/dfti.hpp" // conflict between PRECISION macro in dfti.hpp and kokkos_type.h
#include "oneapi/mkl/dfti.hpp"
#include "mkl.h"
#if defined(FFT_SINGLE)
typedef float FFT_KOKKOS_DATA;