Add MKL support
This commit is contained in:
@ -46,7 +46,10 @@ FFT3dKokkos<DeviceType>::FFT3dKokkos(LAMMPS *lmp, MPI_Comm comm, int nfast, int
|
||||
int ngpus = lmp->kokkos->ngpus;
|
||||
ExecutionSpace execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
|
||||
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
if (ngpus > 0 && execution_space == Device)
|
||||
lmp->error->all(FLERR,"Cannot use the MKL library with Kokkos CUDA on GPUs");
|
||||
#elif defined(FFT_FFTW3)
|
||||
if (ngpus > 0 && execution_space == Device)
|
||||
lmp->error->all(FLERR,"Cannot use the FFTW library with Kokkos CUDA on GPUs");
|
||||
#elif defined(FFT_CUFFT)
|
||||
@ -131,27 +134,31 @@ void FFT3dKokkos<DeviceType>::timing1d(typename AT::t_FFT_SCALAR_1d d_in, int ns
|
||||
plan plan returned by previous call to fft_3d_create_plan
|
||||
------------------------------------------------------------------------- */
|
||||
|
||||
#ifdef FFT_CUFFT
|
||||
template<class DeviceType>
|
||||
struct cufft_norm_functor {
|
||||
struct norm_functor {
|
||||
public:
|
||||
typedef DeviceType device_type;
|
||||
typedef ArrayTypes<DeviceType> AT;
|
||||
typename AT::t_FFT_SCALAR_1d_um d_out;
|
||||
typename AT::t_FFT_DATA_1d_um d_out;
|
||||
int norm;
|
||||
|
||||
cufft_norm_functor(typename AT::t_FFT_SCALAR_1d &d_out_, int norm_):
|
||||
d_out(d_out_)
|
||||
{
|
||||
norm = norm_;
|
||||
}
|
||||
norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
|
||||
d_out(d_out_),norm(norm_) {}
|
||||
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void operator() (const int &i) const {
|
||||
d_out(i) *= norm;
|
||||
#if defined(FFT_FFTW3) || defined(FFT_CUFFT)
|
||||
FFT_SCALAR* out_ptr = (FFT_SCALAR *)d_out(i);
|
||||
*(out_ptr++) *= norm;
|
||||
*(out_ptr++) *= norm;
|
||||
#elif defined(FFT_MKL)
|
||||
d_out(i) *= norm;
|
||||
#else
|
||||
d_out(i,0) *= norm;
|
||||
d_out(i,1) *= norm;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef FFT_KISSFFT
|
||||
template<class DeviceType>
|
||||
@ -179,27 +186,6 @@ public:
|
||||
KissFFTKokkos<DeviceType>::kiss_fft_kokkos(st,d_data,d_tmp,offset);
|
||||
}
|
||||
};
|
||||
|
||||
template<class DeviceType>
|
||||
struct kiss_norm_functor {
|
||||
public:
|
||||
typedef DeviceType device_type;
|
||||
typedef ArrayTypes<DeviceType> AT;
|
||||
typename AT::t_FFT_DATA_1d_um d_out;
|
||||
int norm;
|
||||
|
||||
kiss_norm_functor(typename AT::t_FFT_DATA_1d &d_out_, int norm_):
|
||||
d_out(d_out_)
|
||||
{
|
||||
norm = norm_;
|
||||
}
|
||||
|
||||
KOKKOS_INLINE_FUNCTION
|
||||
void operator() (const int &i) const {
|
||||
d_out(i,0) *= norm;
|
||||
d_out(i,1) *= norm;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<class DeviceType>
|
||||
@ -231,13 +217,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
|
||||
total = plan->total1;
|
||||
length = plan->length1;
|
||||
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
if (flag == -1)
|
||||
FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
DftiComputeForward(plan->handle_fast,(FFT_DATA *)d_data.data());
|
||||
else
|
||||
FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
DftiComputeBackward(plan->handle_fast,(FFT_DATA *)d_data.data());
|
||||
#elif defined(FFT_FFTW3)
|
||||
if (flag == -1)
|
||||
fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
else
|
||||
fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
#elif defined(FFT_CUFFT)
|
||||
cufftExec(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
|
||||
cufftExecZ2Z(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
|
||||
#else
|
||||
typename AT::t_FFT_DATA_1d d_tmp =
|
||||
typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
|
||||
@ -272,13 +263,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
|
||||
total = plan->total2;
|
||||
length = plan->length2;
|
||||
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
if (flag == -1)
|
||||
FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
DftiComputeForward(plan->handle_mid,(FFT_DATA *)d_data.data());
|
||||
else
|
||||
FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
DftiComputeBackward(plan->handle_mid,(FFT_DATA *)d_data.data());
|
||||
#elif defined(FFT_FFTW3)
|
||||
if (flag == -1)
|
||||
fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
else
|
||||
fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
#elif defined(FFT_CUFFT)
|
||||
cufftExec(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
|
||||
cufftExecZ2Z(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
|
||||
#else
|
||||
if (flag == -1)
|
||||
f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
|
||||
@ -309,13 +305,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
|
||||
total = plan->total3;
|
||||
length = plan->length3;
|
||||
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
if (flag == -1)
|
||||
FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
DftiComputeForward(plan->handle_slow,(FFT_DATA *)d_data.data());
|
||||
else
|
||||
FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
DftiComputeBackward(plan->handle_slow,(FFT_DATA *)d_data.data());
|
||||
#elif defined(FFT_FFTW3)
|
||||
if (flag == -1)
|
||||
fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
else
|
||||
fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
|
||||
#elif defined(FFT_CUFFT)
|
||||
cufftExec(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
|
||||
cufftExecZ2Z(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
|
||||
#else
|
||||
if (flag == -1)
|
||||
f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
|
||||
@ -341,17 +342,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
|
||||
// scaling if required
|
||||
|
||||
if (flag == 1 && plan->scaled) {
|
||||
int norm = plan->norm;
|
||||
FFT_SCALAR num = plan->normnum;
|
||||
#if defined(FFT_CUFFT)
|
||||
typename AT::t_FFT_SCALAR_1d d_norm_scalar =
|
||||
typename AT::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
|
||||
cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
|
||||
FFT_SCALAR norm = plan->norm;
|
||||
int num = plan->normnum;
|
||||
|
||||
norm_functor<DeviceType> f(d_out,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
#elif defined(FFT_KISSFFT)
|
||||
kiss_norm_functor<DeviceType> f(d_out,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -604,48 +599,83 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
|
||||
// system specific pre-computation of 1d FFT coeffs
|
||||
// and scaling normalization
|
||||
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
DftiCreateDescriptor( &(plan->handle_fast), FFT_MKL_PREC, DFTI_COMPLEX, 1,
|
||||
(MKL_LONG)nfast);
|
||||
DftiSetValue(plan->handle_fast, DFTI_NUMBER_OF_TRANSFORMS,
|
||||
(MKL_LONG)plan->total1/nfast);
|
||||
DftiSetValue(plan->handle_fast, DFTI_PLACEMENT,DFTI_INPLACE);
|
||||
DftiSetValue(plan->handle_fast, DFTI_INPUT_DISTANCE, (MKL_LONG)nfast);
|
||||
DftiSetValue(plan->handle_fast, DFTI_OUTPUT_DISTANCE, (MKL_LONG)nfast);
|
||||
DftiSetValue(plan->handle_fast, DFTI_NUMBER_OF_USER_THREADS, nthreads);
|
||||
DftiCommitDescriptor(plan->handle_fast);
|
||||
|
||||
#if defined(FFT_FFTW_THREADS)
|
||||
if (nthreads > 1) {
|
||||
fftw_init_threads();
|
||||
fftw_plan_with_nthreads(nthreads);
|
||||
DftiCreateDescriptor( &(plan->handle_mid), FFT_MKL_PREC, DFTI_COMPLEX, 1,
|
||||
(MKL_LONG)nmid);
|
||||
DftiSetValue(plan->handle_mid, DFTI_NUMBER_OF_TRANSFORMS,
|
||||
(MKL_LONG)plan->total2/nmid);
|
||||
DftiSetValue(plan->handle_mid, DFTI_PLACEMENT,DFTI_INPLACE);
|
||||
DftiSetValue(plan->handle_mid, DFTI_INPUT_DISTANCE, (MKL_LONG)nmid);
|
||||
DftiSetValue(plan->handle_mid, DFTI_OUTPUT_DISTANCE, (MKL_LONG)nmid);
|
||||
DftiSetValue(plan->handle_mid, DFTI_NUMBER_OF_USER_THREADS, nthreads);
|
||||
DftiCommitDescriptor(plan->handle_mid);
|
||||
|
||||
DftiCreateDescriptor( &(plan->handle_slow), FFT_MKL_PREC, DFTI_COMPLEX, 1,
|
||||
(MKL_LONG)nslow);
|
||||
DftiSetValue(plan->handle_slow, DFTI_NUMBER_OF_TRANSFORMS,
|
||||
(MKL_LONG)plan->total3/nslow);
|
||||
DftiSetValue(plan->handle_slow, DFTI_PLACEMENT,DFTI_INPLACE);
|
||||
DftiSetValue(plan->handle_slow, DFTI_INPUT_DISTANCE, (MKL_LONG)nslow);
|
||||
DftiSetValue(plan->handle_slow, DFTI_OUTPUT_DISTANCE, (MKL_LONG)nslow);
|
||||
DftiSetValue(plan->handle_slow, DFTI_NUMBER_OF_USER_THREADS, nthreads);
|
||||
DftiCommitDescriptor(plan->handle_slow);
|
||||
|
||||
if (scaled == 0)
|
||||
plan->scaled = 0;
|
||||
else {
|
||||
plan->scaled = 1;
|
||||
plan->norm = 1.0/(nfast*nmid*nslow);
|
||||
plan->normnum = (out_ihi-out_ilo+1) * (out_jhi-out_jlo+1) *
|
||||
(out_khi-out_klo+1);
|
||||
}
|
||||
#endif
|
||||
|
||||
#elif defined(FFT_FFTW3)
|
||||
if (nthreads > 1)
|
||||
fftw_plan_with_nthreads(nthreads);
|
||||
|
||||
plan->plan_fast_forward =
|
||||
FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
|
||||
fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
|
||||
NULL,&nfast,1,plan->length1,
|
||||
NULL,&nfast,1,plan->length1,
|
||||
FFTW_FORWARD,FFTW_ESTIMATE);
|
||||
|
||||
plan->plan_fast_backward =
|
||||
FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
|
||||
fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1,
|
||||
NULL,&nfast,1,plan->length1,
|
||||
NULL,&nfast,1,plan->length1,
|
||||
FFTW_BACKWARD,FFTW_ESTIMATE);
|
||||
|
||||
plan->plan_mid_forward =
|
||||
FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
|
||||
fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
|
||||
NULL,&nmid,1,plan->length2,
|
||||
NULL,&nmid,1,plan->length2,
|
||||
FFTW_FORWARD,FFTW_ESTIMATE);
|
||||
|
||||
plan->plan_mid_backward =
|
||||
FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
|
||||
fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2,
|
||||
NULL,&nmid,1,plan->length2,
|
||||
NULL,&nmid,1,plan->length2,
|
||||
FFTW_BACKWARD,FFTW_ESTIMATE);
|
||||
|
||||
|
||||
plan->plan_slow_forward =
|
||||
FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
|
||||
fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
|
||||
NULL,&nslow,1,plan->length3,
|
||||
NULL,&nslow,1,plan->length3,
|
||||
FFTW_FORWARD,FFTW_ESTIMATE);
|
||||
|
||||
plan->plan_slow_backward =
|
||||
FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
|
||||
fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3,
|
||||
NULL,&nslow,1,plan->length3,
|
||||
NULL,&nslow,1,plan->length3,
|
||||
FFTW_BACKWARD,FFTW_ESTIMATE);
|
||||
@ -653,17 +683,17 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
|
||||
cufftPlanMany(&(plan->plan_fast), 1, &nfast,
|
||||
&nfast,1,plan->length1,
|
||||
&nfast,1,plan->length1,
|
||||
CUFFT_TYPE,plan->total1/plan->length1);
|
||||
CUFFT_Z2Z,plan->total1/plan->length1);
|
||||
|
||||
cufftPlanMany(&(plan->plan_mid), 1, &nmid,
|
||||
&nmid,1,plan->length2,
|
||||
&nmid,1,plan->length2,
|
||||
CUFFT_TYPE,plan->total2/plan->length2);
|
||||
CUFFT_Z2Z,plan->total2/plan->length2);
|
||||
|
||||
cufftPlanMany(&(plan->plan_slow), 1, &nslow,
|
||||
&nslow,1,plan->length3,
|
||||
&nslow,1,plan->length3,
|
||||
CUFFT_TYPE,plan->total3/plan->length3);
|
||||
CUFFT_Z2Z,plan->total3/plan->length3);
|
||||
#else
|
||||
kissfftKK = new KissFFTKokkos<DeviceType>();
|
||||
|
||||
@ -717,14 +747,23 @@ void FFT3dKokkos<DeviceType>::fft_3d_destroy_plan_kokkos(struct fft_plan_3d_kokk
|
||||
if (plan->mid2_plan) remapKK->remap_3d_destroy_plan_kokkos(plan->mid2_plan);
|
||||
if (plan->post_plan) remapKK->remap_3d_destroy_plan_kokkos(plan->post_plan);
|
||||
|
||||
delete plan;
|
||||
delete remapKK;
|
||||
|
||||
#if defined (FFT_FFTW_THREADS)
|
||||
fftw_cleanup_threads();
|
||||
#if defined(FFT_MKL)
|
||||
DftiFreeDescriptor(&(plan->handle_fast));
|
||||
DftiFreeDescriptor(&(plan->handle_mid));
|
||||
DftiFreeDescriptor(&(plan->handle_slow));
|
||||
#elif defined(FFT_FFTW3)
|
||||
FFTW_API(destroy_plan)(plan->plan_slow_forward);
|
||||
FFTW_API(destroy_plan)(plan->plan_slow_backward);
|
||||
FFTW_API(destroy_plan)(plan->plan_mid_forward);
|
||||
FFTW_API(destroy_plan)(plan->plan_mid_backward);
|
||||
FFTW_API(destroy_plan)(plan->plan_fast_forward);
|
||||
FFTW_API(destroy_plan)(plan->plan_fast_backward);
|
||||
#elif defined (FFT_KISSFFT)
|
||||
delete kissfftKK;
|
||||
#endif
|
||||
|
||||
delete plan;
|
||||
delete remapKK;
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------------------
|
||||
@ -777,7 +816,10 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
|
||||
|
||||
// fftw3 and Dfti in MKL encode the number of transforms
|
||||
// into the plan, so we cannot operate on a smaller data set
|
||||
|
||||
#if defined(FFT_MKL) || defined(FFT_FFTW3)
|
||||
if ((total1 > nsize) || (total2 > nsize) || (total3 > nsize))
|
||||
return;
|
||||
#endif
|
||||
if (total1 > nsize) total1 = (nsize/length1) * length1;
|
||||
if (total2 > nsize) total2 = (nsize/length2) * length2;
|
||||
if (total3 > nsize) total3 = (nsize/length3) * length3;
|
||||
@ -785,20 +827,30 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
|
||||
// perform 1d FFTs in each of 3 dimensions
|
||||
// data is just an array of 0.0
|
||||
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
if (flag == -1) {
|
||||
FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
DftiComputeForward(plan->handle_fast,data);
|
||||
DftiComputeForward(plan->handle_mid,data);
|
||||
DftiComputeForward(plan->handle_slow,data);
|
||||
} else {
|
||||
FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
DftiComputeBackward(plan->handle_fast,data);
|
||||
DftiComputeBackward(plan->handle_mid,data);
|
||||
DftiComputeBackward(plan->handle_slow,data);
|
||||
}
|
||||
#elif defined(FFT_FFTW3)
|
||||
if (flag == -1) {
|
||||
fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
} else {
|
||||
fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
|
||||
}
|
||||
#elif defined(FFT_CUFFT)
|
||||
cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
|
||||
cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
|
||||
cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
|
||||
cufftExecZ2Z(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
|
||||
cufftExecZ2Z(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
|
||||
cufftExecZ2Z(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
|
||||
#else
|
||||
kiss_fft_functor<DeviceType> f;
|
||||
typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
|
||||
@ -829,15 +881,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
|
||||
if (flag == 1 && plan->scaled) {
|
||||
FFT_SCALAR norm = plan->norm;
|
||||
int num = MIN(plan->normnum,nsize);
|
||||
#if defined(FFT_CUFFT)
|
||||
typename AT::t_FFT_SCALAR_1d d_norm_scalar =
|
||||
typename AT::t_FFT_SCALAR_1d(d_data.data(),d_data.size());
|
||||
cufft_norm_functor<DeviceType> f(d_norm_scalar,norm);
|
||||
|
||||
norm_functor<DeviceType> f(d_out,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
#elif defined(FFT_KISSFFT)
|
||||
kiss_norm_functor<DeviceType> f(d_data,norm);
|
||||
Kokkos::parallel_for(num,f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -29,7 +29,16 @@
|
||||
#define FFT_FFTW3
|
||||
#endif
|
||||
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
#include "mkl_dfti.h"
|
||||
#if defined(FFT_SINGLE)
|
||||
typedef float _Complex FFT_DATA;
|
||||
#define FFT_MKL_PREC DFTI_SINGLE
|
||||
#else
|
||||
typedef double _Complex FFT_DATA;
|
||||
#define FFT_MKL_PREC DFTI_DOUBLE
|
||||
#endif
|
||||
#elif defined(FFT_FFTW3)
|
||||
#include "fftw3.h"
|
||||
#if defined(FFT_SINGLE)
|
||||
typedef fftwf_complex FFT_DATA;
|
||||
@ -82,7 +91,11 @@ struct fft_plan_3d_kokkos {
|
||||
double norm; // normalization factor for rescaling
|
||||
|
||||
// system specific 1d FFT info
|
||||
#if defined(FFT_FFTW3)
|
||||
#if defined(FFT_MKL)
|
||||
DFTI_DESCRIPTOR *handle_fast;
|
||||
DFTI_DESCRIPTOR *handle_mid;
|
||||
DFTI_DESCRIPTOR *handle_slow;
|
||||
#elif defined(FFT_FFTW3)
|
||||
FFTW_API(plan) plan_fast_forward;
|
||||
FFTW_API(plan) plan_fast_backward;
|
||||
FFTW_API(plan) plan_mid_forward;
|
||||
|
||||
Reference in New Issue
Block a user