Add cuFFT norm functor

This commit is contained in:
Stan Moore
2019-06-17 13:18:26 -06:00
parent 0322ebd093
commit ca5aa1f907

View File

@ -103,6 +103,28 @@ void FFT3dKokkos<DeviceType>::timing1d(typename AT::t_FFT_SCALAR_1d d_in, int ns
plan plan returned by previous call to fft_3d_create_plan
------------------------------------------------------------------------- */
#ifdef FFT_CUFFT
template<class DeviceType>
struct cufft_norm_functor {
public:
typedef DeviceType device_type;
typedef ArrayTypes<DeviceType> AT;
typename AT::t_FFT_SCALAR_1d_um d_out;
int norm;
cufft_norm_functor(typename AT::t_FFT_SCALAR_1d &d_out_, int norm_):
d_out(d_out_)
{
norm = norm_;
}
KOKKOS_INLINE_FUNCTION
void operator() (const int &i) const {
d_out(i) *= norm;
}
};
#endif
#ifdef FFT_KISSFFT
template<class DeviceType>
struct kiss_fft_functor {
@ -131,14 +153,14 @@ public:
};
template<class DeviceType>
struct norm_functor {
struct kiss_norm_functor {
public:
typedef DeviceType device_type;
typedef ArrayTypes<DeviceType> AT;
typename AT::t_FFT_DATA_1d_um d_out;
int norm;
norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
kiss_norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
d_out(d_out_)
{
norm = norm_;
@ -297,9 +319,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
norm = plan->norm;
num = plan->normnum;
#if defined(FFT_CUFFT)
//scale(ptr, norm, num); //////////////////////
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar =
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
Kokkos::parallel_for(num,f);
DeviceType::fence();
#else
norm_functor<DeviceType> f(d_out,norm);
kiss_norm_functor<DeviceType> f(d_out,norm);
Kokkos::parallel_for(num,f);
DeviceType::fence();
#endif
@ -757,9 +783,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
FFT_SCALAR norm = plan->norm;
num = MIN(plan->normnum,nsize);
#if defined(FFT_CUFFT)
//scale(ptr, norm, num); ///////////////
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar =
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
Kokkos::parallel_for(num,f);
DeviceType::fence();
#else
norm_functor<DeviceType> f(d_data,norm);
kiss_norm_functor<DeviceType> f(d_data,norm);
Kokkos::parallel_for(num,f);
DeviceType::fence();
#endif