Add cuFFT norm functor
This commit is contained in:
@ -103,6 +103,28 @@ void FFT3dKokkos<DeviceType>::timing1d(typename AT::t_FFT_SCALAR_1d d_in, int ns
|
||||
plan plan returned by previous call to fft_3d_create_plan
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
#ifdef FFT_CUFFT
|
||||
template<class DeviceType>
|
||||
struct cufft_norm_functor {
|
||||
public:
|
||||
typedef DeviceType device_type;
|
||||
typedef ArrayTypes<DeviceType> AT;
|
||||
typename AT::t_FFT_SCALAR_1d_um d_out;
|
||||
int norm;
|
||||
|
||||
cufft_norm_functor(typename AT::t_FFT_SCALAR_1d &d_out_, int norm_):
|
||||
d_out(d_out_)
|
||||
{
|
||||
norm = norm_;
|
||||
}
|
||||
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void operator() (const int &i) const {
|
||||
d_out(i) *= norm;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef FFT_KISSFFT
|
||||
template<class DeviceType>
|
||||
struct kiss_fft_functor {
|
||||
@ -131,14 +153,14 @@ public:
|
||||
};
|
||||
|
||||
template<class DeviceType>
|
||||
struct norm_functor {
|
||||
struct kiss_norm_functor {
|
||||
public:
|
||||
typedef DeviceType device_type;
|
||||
typedef ArrayTypes<DeviceType> AT;
|
||||
typename AT::t_FFT_DATA_1d_um d_out;
|
||||
int norm;
|
||||
|
||||
norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
|
||||
kiss_norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
|
||||
d_out(d_out_)
|
||||
{
|
||||
norm = norm_;
|
||||
@ -297,9 +319,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename ArrayTypes<DeviceType>::t_F
|
||||
norm = plan->norm;
|
||||
num = plan->normnum;
|
||||
#if defined(FFT_CUFFT)
|
||||
//scale(ptr, norm, num); //////////////////////
|
||||
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar =
|
||||
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
|
||||
cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
DeviceType::fence();
|
||||
#else
|
||||
norm_functor<DeviceType> f(d_out,norm);
|
||||
kiss_norm_functor<DeviceType> f(d_out,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
DeviceType::fence();
|
||||
#endif
|
||||
@ -757,9 +783,13 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename ArrayTypes<DeviceTy
|
||||
FFT_SCALAR norm = plan->norm;
|
||||
num = MIN(plan->normnum,nsize);
|
||||
#if defined(FFT_CUFFT)
|
||||
//scale(ptr, norm, num); ///////////////
|
||||
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d d_norm_scalar =
|
||||
typename ArrayTypes<DeviceType>::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
|
||||
cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
DeviceType::fence();
|
||||
#else
|
||||
norm_functor<DeviceType> f(d_data,norm);
|
||||
kiss_norm_functor<DeviceType> f(d_data,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
DeviceType::fence();
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user