diff --git a/cmake/Modules/Packages/KSPACE.cmake b/cmake/Modules/Packages/KSPACE.cmake index de7e7e5b20..119c8fa867 100644 --- a/cmake/Modules/Packages/KSPACE.cmake +++ b/cmake/Modules/Packages/KSPACE.cmake @@ -46,6 +46,22 @@ else() target_compile_definitions(lammps PRIVATE -DFFT_KISS) endif() +option(FFT_HEFFTE "Use heFFTe as the distributed FFT engine." OFF) +if(FFT_HEFFTE) + # if FFT_HEFFTE is enabled, switch the builtin FFT engine with Heffte + if(FFT STREQUAL "FFTW3") # respect the backend choice, FFTW or MKL + set(HEFFTE_COMPONENTS "FFTW") + elseif(FFT STREQUAL "MKL") + set(HEFFTE_COMPONENTS "MKL") + else() + message(FATAL_ERROR "Using -DFFT_HEFFTE=ON, requires FFT either FFTW or MKL") + endif() + + find_package(Heffte 2.3.0 REQUIRED ${HEFFTE_COMPONENTS}) + target_compile_definitions(lammps PRIVATE -DHEFFTE) + target_link_libraries(lammps PRIVATE Heffte::Heffte) +endif() + set(FFT_PACK "array" CACHE STRING "Optimization for FFT") set(FFT_PACK_VALUES array pointer memcpy) set_property(CACHE FFT_PACK PROPERTY STRINGS ${FFT_PACK_VALUES}) diff --git a/src/KSPACE/fft3d_wrap.cpp b/src/KSPACE/fft3d_wrap.cpp index 478cf6fc9d..a6b4167d71 100644 --- a/src/KSPACE/fft3d_wrap.cpp +++ b/src/KSPACE/fft3d_wrap.cpp @@ -27,30 +27,66 @@ FFT3d::FFT3d(LAMMPS *lmp, MPI_Comm comm, int nfast, int nmid, int nslow, int out_klo, int out_khi, int scaled, int permute, int *nbuf, int usecollective) : Pointers(lmp) { + #ifndef HEFFTE plan = fft_3d_create_plan(comm,nfast,nmid,nslow, in_ilo,in_ihi,in_jlo,in_jhi,in_klo,in_khi, out_ilo,out_ihi,out_jlo,out_jhi,out_klo,out_khi, scaled,permute,nbuf,usecollective); if (plan == nullptr) error->one(FLERR,"Could not create 3d FFT plan"); + #else + heffte::plan_options options = heffte::default_options(); + options.algorithm = (usecollective == 0) ? + heffte::reshape_algorithm::p2p_plined + : heffte::reshape_algorithm::alltoallv; + options.use_reorder = (permute != 0); + hscale = (scaled == 0) ? heffte::scale::none : heffte::scale::full; + + heffte_plan = std::unique_ptr>( + new heffte::fft3d( + heffte::box3d<>({in_ilo,in_jlo,in_klo}, {in_ihi, in_jhi, in_khi}), + heffte::box3d<>({out_ilo,out_jlo,out_klo}, {out_ihi, out_jhi, out_khi}), + comm, options) + ); + *nbuf = heffte_plan->size_workspace(); + heffte_workspace.resize(heffte_plan->size_workspace()); + #endif } /* ---------------------------------------------------------------------- */ FFT3d::~FFT3d() { + #ifndef HEFFTE fft_3d_destroy_plan(plan); + #endif } /* ---------------------------------------------------------------------- */ void FFT3d::compute(FFT_SCALAR *in, FFT_SCALAR *out, int flag) { + #ifndef HEFFTE fft_3d((FFT_DATA *) in,(FFT_DATA *) out,flag,plan); + #else + if (flag == 1) + heffte_plan->forward(reinterpret_cast*>(in), + reinterpret_cast*>(out), + reinterpret_cast*>(heffte_workspace.data()) + ); + else + heffte_plan->backward(reinterpret_cast*>(in), + reinterpret_cast*>(out), + reinterpret_cast*>(heffte_workspace.data()), + hscale + ); + #endif } /* ---------------------------------------------------------------------- */ void FFT3d::timing1d(FFT_SCALAR *in, int nsize, int flag) { + #ifndef HEFFTE fft_1d_only((FFT_DATA *) in,nsize,flag,plan); + #endif } diff --git a/src/KSPACE/fft3d_wrap.h b/src/KSPACE/fft3d_wrap.h index f72cfd4622..f34680d682 100644 --- a/src/KSPACE/fft3d_wrap.h +++ b/src/KSPACE/fft3d_wrap.h @@ -17,6 +17,17 @@ #include "fft3d.h" // IWYU pragma: export #include "pointers.h" +#ifdef HEFFTE +#include "heffte.h" +// select the backend +#if defined(FFT_FFTW3) +using heffte_backend = heffte::backend::fftw; +#elif defined(FFT_MKL) +using heffte_backend = heffte::backend::mkl; +#endif + +#endif // HEFFTE + namespace LAMMPS_NS { class FFT3d : protected Pointers { @@ -30,7 +41,14 @@ class FFT3d : protected Pointers { void timing1d(FFT_SCALAR *, int, int); private: + #ifdef HEFFTE + // the heFFTe plan supersedes the internal fft_plan_3d + std::unique_ptr> heffte_plan; + std::vector> heffte_workspace; + heffte::scale hscale; + #else struct fft_plan_3d *plan; + #endif }; } // namespace LAMMPS_NS