enabled the use of heffte for the cpu backend

This commit is contained in:
Miroslav Stoyanov
2023-02-02 17:11:28 -05:00
parent 97a0885145
commit 5bbdfe5b4f
3 changed files with 70 additions and 0 deletions

View File

@ -27,30 +27,66 @@ FFT3d::FFT3d(LAMMPS *lmp, MPI_Comm comm, int nfast, int nmid, int nslow,
int out_klo, int out_khi,
int scaled, int permute, int *nbuf, int usecollective) : Pointers(lmp)
{
#ifndef HEFFTE
plan = fft_3d_create_plan(comm,nfast,nmid,nslow,
in_ilo,in_ihi,in_jlo,in_jhi,in_klo,in_khi,
out_ilo,out_ihi,out_jlo,out_jhi,out_klo,out_khi,
scaled,permute,nbuf,usecollective);
if (plan == nullptr) error->one(FLERR,"Could not create 3d FFT plan");
#else
heffte::plan_options options = heffte::default_options<heffte_backend>();
options.algorithm = (usecollective == 0) ?
heffte::reshape_algorithm::p2p_plined
: heffte::reshape_algorithm::alltoallv;
options.use_reorder = (permute != 0);
hscale = (scaled == 0) ? heffte::scale::none : heffte::scale::full;
heffte_plan = std::unique_ptr<heffte::fft3d<heffte_backend>>(
new heffte::fft3d<heffte_backend>(
heffte::box3d<>({in_ilo,in_jlo,in_klo}, {in_ihi, in_jhi, in_khi}),
heffte::box3d<>({out_ilo,out_jlo,out_klo}, {out_ihi, out_jhi, out_khi}),
comm, options)
);
*nbuf = heffte_plan->size_workspace();
heffte_workspace.resize(heffte_plan->size_workspace());
#endif
}
/* ---------------------------------------------------------------------- */
FFT3d::~FFT3d()
{
#ifndef HEFFTE
fft_3d_destroy_plan(plan);
#endif
}
/* ---------------------------------------------------------------------- */
void FFT3d::compute(FFT_SCALAR *in, FFT_SCALAR *out, int flag)
{
#ifndef HEFFTE
fft_3d((FFT_DATA *) in,(FFT_DATA *) out,flag,plan);
#else
if (flag == 1)
heffte_plan->forward(reinterpret_cast<std::complex<FFT_SCALAR>*>(in),
reinterpret_cast<std::complex<FFT_SCALAR>*>(out),
reinterpret_cast<std::complex<FFT_SCALAR>*>(heffte_workspace.data())
);
else
heffte_plan->backward(reinterpret_cast<std::complex<FFT_SCALAR>*>(in),
reinterpret_cast<std::complex<FFT_SCALAR>*>(out),
reinterpret_cast<std::complex<FFT_SCALAR>*>(heffte_workspace.data()),
hscale
);
#endif
}
/* ---------------------------------------------------------------------- */
void FFT3d::timing1d(FFT_SCALAR *in, int nsize, int flag)
{
#ifndef HEFFTE
fft_1d_only((FFT_DATA *) in,nsize,flag,plan);
#endif
}