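Fix issues: list FFTW3 (not FFTW) in the KSPACE FFT_VALUES option list; in FFT3dKokkos replace the precision-specific fftw_*/cufftExecZ2Z/CUFFT_Z2Z names with FFTW_API()/cufftExec/CUFFT_TYPE, call FFTW init_threads before plan_with_nthreads and cleanup_threads on plan destruction, and apply norm_functor to d_data instead of d_out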

This commit is contained in:
Stan Moore
2019-07-30 09:25:24 -06:00
parent f96609a046
commit 9a43a6824c
2 changed files with 38 additions and 31 deletions

View File

@ -11,7 +11,7 @@ if(PKG_KSPACE)
else() else()
set(FFT "KISS" CACHE STRING "FFT library for KSPACE package") set(FFT "KISS" CACHE STRING "FFT library for KSPACE package")
endif() endif()
set(FFT_VALUES KISS FFTW MKL CUFFT) set(FFT_VALUES KISS FFTW3 MKL CUFFT)
set_property(CACHE FFT PROPERTY STRINGS ${FFT_VALUES}) set_property(CACHE FFT PROPERTY STRINGS ${FFT_VALUES})
validate_option(FFT FFT_VALUES) validate_option(FFT FFT_VALUES)
string(TOUPPER ${FFT} FFT) string(TOUPPER ${FFT} FFT)

View File

@ -224,11 +224,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
DftiComputeBackward(plan->handle_fast,(FFT_DATA *)d_data.data()); DftiComputeBackward(plan->handle_fast,(FFT_DATA *)d_data.data());
#elif defined(FFT_FFTW3) #elif defined(FFT_FFTW3)
if (flag == -1) if (flag == -1)
fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data()); FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
else else
fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data()); FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
#elif defined(FFT_CUFFT) #elif defined(FFT_CUFFT)
cufftExecZ2Z(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag); cufftExec(plan->plan_fast,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
#else #else
typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d d_tmp =
typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0()); typename AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
@ -270,11 +270,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
DftiComputeBackward(plan->handle_mid,(FFT_DATA *)d_data.data()); DftiComputeBackward(plan->handle_mid,(FFT_DATA *)d_data.data());
#elif defined(FFT_FFTW3) #elif defined(FFT_FFTW3)
if (flag == -1) if (flag == -1)
fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data()); FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
else else
fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data()); FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
#elif defined(FFT_CUFFT) #elif defined(FFT_CUFFT)
cufftExecZ2Z(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag); cufftExec(plan->plan_mid,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
#else #else
if (flag == -1) if (flag == -1)
f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length); f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
@ -312,11 +312,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename AT::t_FFT_DATA_1d d_in, typ
DftiComputeBackward(plan->handle_slow,(FFT_DATA *)d_data.data()); DftiComputeBackward(plan->handle_slow,(FFT_DATA *)d_data.data());
#elif defined(FFT_FFTW3) #elif defined(FFT_FFTW3)
if (flag == -1) if (flag == -1)
fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data()); FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
else else
fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data()); FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data());
#elif defined(FFT_CUFFT) #elif defined(FFT_CUFFT)
cufftExecZ2Z(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag); cufftExec(plan->plan_slow,(FFT_DATA *)d_data.data(),(FFT_DATA *)d_data.data(),flag);
#else #else
if (flag == -1) if (flag == -1)
f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length); f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
@ -640,42 +640,44 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
} }
#elif defined(FFT_FFTW3) #elif defined(FFT_FFTW3)
if (nthreads > 1) if (nthreads > 1) {
fftw_plan_with_nthreads(nthreads); FFTW_API(init_threads)();
FFTW_API(plan_with_nthreads)(nthreads);
}
plan->plan_fast_forward = plan->plan_fast_forward =
fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1, FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
NULL,&nfast,1,plan->length1, NULL,&nfast,1,plan->length1,
NULL,&nfast,1,plan->length1, NULL,&nfast,1,plan->length1,
FFTW_FORWARD,FFTW_ESTIMATE); FFTW_FORWARD,FFTW_ESTIMATE);
plan->plan_fast_backward = plan->plan_fast_backward =
fftw_plan_many_dft(1, &nfast,plan->total1/plan->length1, FFTW_API(plan_many_dft)(1, &nfast,plan->total1/plan->length1,
NULL,&nfast,1,plan->length1, NULL,&nfast,1,plan->length1,
NULL,&nfast,1,plan->length1, NULL,&nfast,1,plan->length1,
FFTW_BACKWARD,FFTW_ESTIMATE); FFTW_BACKWARD,FFTW_ESTIMATE);
plan->plan_mid_forward = plan->plan_mid_forward =
fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2, FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
NULL,&nmid,1,plan->length2, NULL,&nmid,1,plan->length2,
NULL,&nmid,1,plan->length2, NULL,&nmid,1,plan->length2,
FFTW_FORWARD,FFTW_ESTIMATE); FFTW_FORWARD,FFTW_ESTIMATE);
plan->plan_mid_backward = plan->plan_mid_backward =
fftw_plan_many_dft(1, &nmid,plan->total2/plan->length2, FFTW_API(plan_many_dft)(1, &nmid,plan->total2/plan->length2,
NULL,&nmid,1,plan->length2, NULL,&nmid,1,plan->length2,
NULL,&nmid,1,plan->length2, NULL,&nmid,1,plan->length2,
FFTW_BACKWARD,FFTW_ESTIMATE); FFTW_BACKWARD,FFTW_ESTIMATE);
plan->plan_slow_forward = plan->plan_slow_forward =
fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3, FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
NULL,&nslow,1,plan->length3, NULL,&nslow,1,plan->length3,
NULL,&nslow,1,plan->length3, NULL,&nslow,1,plan->length3,
FFTW_FORWARD,FFTW_ESTIMATE); FFTW_FORWARD,FFTW_ESTIMATE);
plan->plan_slow_backward = plan->plan_slow_backward =
fftw_plan_many_dft(1, &nslow,plan->total3/plan->length3, FFTW_API(plan_many_dft)(1, &nslow,plan->total3/plan->length3,
NULL,&nslow,1,plan->length3, NULL,&nslow,1,plan->length3,
NULL,&nslow,1,plan->length3, NULL,&nslow,1,plan->length3,
FFTW_BACKWARD,FFTW_ESTIMATE); FFTW_BACKWARD,FFTW_ESTIMATE);
@ -683,17 +685,17 @@ struct fft_plan_3d_kokkos<DeviceType>* FFT3dKokkos<DeviceType>::fft_3d_create_pl
cufftPlanMany(&(plan->plan_fast), 1, &nfast, cufftPlanMany(&(plan->plan_fast), 1, &nfast,
&nfast,1,plan->length1, &nfast,1,plan->length1,
&nfast,1,plan->length1, &nfast,1,plan->length1,
CUFFT_Z2Z,plan->total1/plan->length1); CUFFT_TYPE,plan->total1/plan->length1);
cufftPlanMany(&(plan->plan_mid), 1, &nmid, cufftPlanMany(&(plan->plan_mid), 1, &nmid,
&nmid,1,plan->length2, &nmid,1,plan->length2,
&nmid,1,plan->length2, &nmid,1,plan->length2,
CUFFT_Z2Z,plan->total2/plan->length2); CUFFT_TYPE,plan->total2/plan->length2);
cufftPlanMany(&(plan->plan_slow), 1, &nslow, cufftPlanMany(&(plan->plan_slow), 1, &nslow,
&nslow,1,plan->length3, &nslow,1,plan->length3,
&nslow,1,plan->length3, &nslow,1,plan->length3,
CUFFT_Z2Z,plan->total3/plan->length3); CUFFT_TYPE,plan->total3/plan->length3);
#else #else
kissfftKK = new KissFFTKokkos<DeviceType>(); kissfftKK = new KissFFTKokkos<DeviceType>();
@ -758,6 +760,11 @@ void FFT3dKokkos<DeviceType>::fft_3d_destroy_plan_kokkos(struct fft_plan_3d_kokk
FFTW_API(destroy_plan)(plan->plan_mid_backward); FFTW_API(destroy_plan)(plan->plan_mid_backward);
FFTW_API(destroy_plan)(plan->plan_fast_forward); FFTW_API(destroy_plan)(plan->plan_fast_forward);
FFTW_API(destroy_plan)(plan->plan_fast_backward); FFTW_API(destroy_plan)(plan->plan_fast_backward);
FFTW_API(cleanup_threads)();
#elif defined (FFT_KISSFFT) #elif defined (FFT_KISSFFT)
delete kissfftKK; delete kissfftKK;
#endif #endif
@ -839,18 +846,18 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
} }
#elif defined(FFT_FFTW3) #elif defined(FFT_FFTW3)
if (flag == -1) { if (flag == -1) {
fftw_execute_dft(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data()); FFTW_API(execute_dft)(plan->plan_fast_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
fftw_execute_dft(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data()); FFTW_API(execute_dft)(plan->plan_mid_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
fftw_execute_dft(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data()); FFTW_API(execute_dft)(plan->plan_slow_forward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
} else { } else {
fftw_execute_dft(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data()); FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
fftw_execute_dft(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data()); FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
fftw_execute_dft(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data()); FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
} }
#elif defined(FFT_CUFFT) #elif defined(FFT_CUFFT)
cufftExecZ2Z(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag); cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExecZ2Z(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag); cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExecZ2Z(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag); cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
#else #else
kiss_fft_functor<DeviceType> f; kiss_fft_functor<DeviceType> f;
typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0()); typename AT::t_FFT_DATA_1d d_tmp = typename AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());
@ -882,7 +889,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename AT::t_FFT_DATA_1d d
FFT_SCALAR norm = plan->norm; FFT_SCALAR norm = plan->norm;
int num = MIN(plan->normnum,nsize); int num = MIN(plan->normnum,nsize);
norm_functor<DeviceType> f(d_out,norm); norm_functor<DeviceType> f(d_data,norm);
Kokkos::parallel_for(num,f); Kokkos::parallel_for(num,f);
} }
} }