enabled the use of heffte for the cpu backend
This commit is contained in:
@ -46,6 +46,22 @@ else()
|
||||
target_compile_definitions(lammps PRIVATE -DFFT_KISS)
|
||||
endif()
|
||||
|
||||
option(FFT_HEFFTE "Use heFFTe as the distributed FFT engine." OFF)
|
||||
if(FFT_HEFFTE)
|
||||
# if FFT_HEFFTE is enabled, switch the builtin FFT engine with Heffte
|
||||
if(FFT STREQUAL "FFTW3") # respect the backend choice, FFTW or MKL
|
||||
set(HEFFTE_COMPONENTS "FFTW")
|
||||
elseif(FFT STREQUAL "MKL")
|
||||
set(HEFFTE_COMPONENTS "MKL")
|
||||
else()
|
||||
message(FATAL_ERROR "Using -DFFT_HEFFTE=ON, requires FFT either FFTW or MKL")
|
||||
endif()
|
||||
|
||||
find_package(Heffte 2.3.0 REQUIRED ${HEFFTE_COMPONENTS})
|
||||
target_compile_definitions(lammps PRIVATE -DHEFFTE)
|
||||
target_link_libraries(lammps PRIVATE Heffte::Heffte)
|
||||
endif()
|
||||
|
||||
set(FFT_PACK "array" CACHE STRING "Optimization for FFT")
|
||||
set(FFT_PACK_VALUES array pointer memcpy)
|
||||
set_property(CACHE FFT_PACK PROPERTY STRINGS ${FFT_PACK_VALUES})
|
||||
|
||||
@ -27,30 +27,66 @@ FFT3d::FFT3d(LAMMPS *lmp, MPI_Comm comm, int nfast, int nmid, int nslow,
|
||||
int out_klo, int out_khi,
|
||||
int scaled, int permute, int *nbuf, int usecollective) : Pointers(lmp)
|
||||
{
|
||||
#ifndef HEFFTE
|
||||
plan = fft_3d_create_plan(comm,nfast,nmid,nslow,
|
||||
in_ilo,in_ihi,in_jlo,in_jhi,in_klo,in_khi,
|
||||
out_ilo,out_ihi,out_jlo,out_jhi,out_klo,out_khi,
|
||||
scaled,permute,nbuf,usecollective);
|
||||
if (plan == nullptr) error->one(FLERR,"Could not create 3d FFT plan");
|
||||
#else
|
||||
heffte::plan_options options = heffte::default_options<heffte_backend>();
|
||||
options.algorithm = (usecollective == 0) ?
|
||||
heffte::reshape_algorithm::p2p_plined
|
||||
: heffte::reshape_algorithm::alltoallv;
|
||||
options.use_reorder = (permute != 0);
|
||||
hscale = (scaled == 0) ? heffte::scale::none : heffte::scale::full;
|
||||
|
||||
heffte_plan = std::unique_ptr<heffte::fft3d<heffte_backend>>(
|
||||
new heffte::fft3d<heffte_backend>(
|
||||
heffte::box3d<>({in_ilo,in_jlo,in_klo}, {in_ihi, in_jhi, in_khi}),
|
||||
heffte::box3d<>({out_ilo,out_jlo,out_klo}, {out_ihi, out_jhi, out_khi}),
|
||||
comm, options)
|
||||
);
|
||||
*nbuf = heffte_plan->size_workspace();
|
||||
heffte_workspace.resize(heffte_plan->size_workspace());
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
FFT3d::~FFT3d()
|
||||
{
|
||||
#ifndef HEFFTE
|
||||
fft_3d_destroy_plan(plan);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
void FFT3d::compute(FFT_SCALAR *in, FFT_SCALAR *out, int flag)
|
||||
{
|
||||
#ifndef HEFFTE
|
||||
fft_3d((FFT_DATA *) in,(FFT_DATA *) out,flag,plan);
|
||||
#else
|
||||
if (flag == 1)
|
||||
heffte_plan->forward(reinterpret_cast<std::complex<FFT_SCALAR>*>(in),
|
||||
reinterpret_cast<std::complex<FFT_SCALAR>*>(out),
|
||||
reinterpret_cast<std::complex<FFT_SCALAR>*>(heffte_workspace.data())
|
||||
);
|
||||
else
|
||||
heffte_plan->backward(reinterpret_cast<std::complex<FFT_SCALAR>*>(in),
|
||||
reinterpret_cast<std::complex<FFT_SCALAR>*>(out),
|
||||
reinterpret_cast<std::complex<FFT_SCALAR>*>(heffte_workspace.data()),
|
||||
hscale
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------- */
|
||||
|
||||
void FFT3d::timing1d(FFT_SCALAR *in, int nsize, int flag)
|
||||
{
|
||||
#ifndef HEFFTE
|
||||
fft_1d_only((FFT_DATA *) in,nsize,flag,plan);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -17,6 +17,17 @@
|
||||
#include "fft3d.h" // IWYU pragma: export
|
||||
#include "pointers.h"
|
||||
|
||||
#ifdef HEFFTE
|
||||
#include "heffte.h"
|
||||
// select the backend
|
||||
#if defined(FFT_FFTW3)
|
||||
using heffte_backend = heffte::backend::fftw;
|
||||
#elif defined(FFT_MKL)
|
||||
using heffte_backend = heffte::backend::mkl;
|
||||
#endif
|
||||
|
||||
#endif // HEFFTE
|
||||
|
||||
namespace LAMMPS_NS {
|
||||
|
||||
class FFT3d : protected Pointers {
|
||||
@ -30,7 +41,14 @@ class FFT3d : protected Pointers {
|
||||
void timing1d(FFT_SCALAR *, int, int);
|
||||
|
||||
private:
|
||||
#ifdef HEFFTE
|
||||
// the heFFTe plan supersedes the internal fft_plan_3d
|
||||
std::unique_ptr<heffte::fft3d<heffte_backend>> heffte_plan;
|
||||
std::vector<std::complex<FFT_SCALAR>> heffte_workspace;
|
||||
heffte::scale hscale;
|
||||
#else
|
||||
struct fft_plan_3d *plan;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace LAMMPS_NS
|
||||
|
||||
Reference in New Issue
Block a user