Fused CG passes 1 & 2 for QEQ solver

Change-Id: I5fa396d8a2f2713712056a264d2bb05b7321dc1a
This commit is contained in:
Nicholas Curtis
2021-08-25 12:02:09 -04:00
committed by Nick Curtis
parent 2a1823f59d
commit fb379dab15
3 changed files with 771 additions and 9 deletions

View File

@ -310,6 +310,12 @@ void FixQEqReaxFFKokkos<DeviceType>::pre_force(int /*vflag*/)
else
ndup_o = Kokkos::Experimental::create_scatter_view<Kokkos::Experimental::ScatterSum, Kokkos::Experimental::ScatterNonDuplicated> (d_o);
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused cg solve over b_s, s & b_t, t
matvecs = cg_solve_fused();
#else
// 1st cg solve over b_s, s
matvecs = cg_solve1();
@ -318,6 +324,8 @@ void FixQEqReaxFFKokkos<DeviceType>::pre_force(int /*vflag*/)
matvecs += cg_solve2();
#endif
// calculate_Q();
k_s_hist.template sync<DeviceType>();
@ -378,9 +386,11 @@ void FixQEqReaxFFKokkos<DeviceType>::allocate_array()
if (atom->nmax > nmax) {
nmax = atom->nmax;
#ifndef HIP_OPT_CG_SOLVE_FUSED
k_o = DAT::tdual_ffloat_1d("qeq/kk:o",nmax);
d_o = k_o.template view<DeviceType>();
h_o = k_o.h_view;
#endif
d_Hdia_inv = typename AT::t_ffloat_1d("qeq/kk:Hdia_inv",nmax);
@ -396,6 +406,7 @@ void FixQEqReaxFFKokkos<DeviceType>::allocate_array()
d_t = k_t.template view<DeviceType>();
h_t = k_t.h_view;
#ifndef HIP_OPT_CG_SOLVE_FUSED
d_p = typename AT::t_ffloat_1d("qeq/kk:p",nmax);
d_r = typename AT::t_ffloat_1d("qeq/kk:r",nmax);
@ -406,6 +417,21 @@ void FixQEqReaxFFKokkos<DeviceType>::allocate_array()
memoryKK->create_kokkos(k_chi_field,chi_field,nmax,"qeq/kk:chi_field");
d_chi_field = k_chi_field.template view<DeviceType>();
#endif
#ifdef HIP_OPT_CG_SOLVE_FUSED
k_o_fused = DAT::tdual_ffloat2_1d("qeq/kk:o",nmax);
d_o_fused = k_o_fused.template view<DeviceType>();
h_o_fused = k_o_fused.h_view;
d_p_fused = typename AT::t_ffloat2_1d("qeq/kk:p",nmax);
d_r_fused = typename AT::t_ffloat2_1d("qeq/kk:r",nmax);
k_d_fused = DAT::tdual_ffloat2_1d("qeq/kk:d",nmax);
d_d_fused = k_d_fused.template view<DeviceType>();
h_d_fused = k_d_fused.h_view;
#endif
}
// init_storage
@ -431,10 +457,21 @@ void FixQEqReaxFFKokkos<DeviceType>::zero_item(int ii) const
d_b_t[i] = -1.0;
d_s[i] = 0.0;
d_t[i] = 0.0;
#ifndef HIP_OPT_CG_SOLVE_FUSED
d_p[i] = 0.0;
d_o[i] = 0.0;
d_r[i] = 0.0;
d_d[i] = 0.0;
#else
d_o_fused(i,0) = 0.0;
d_o_fused(i,1) = 0.0;
d_d_fused(i,0) = 0.0;
d_d_fused(i,1) = 0.0;
d_r_fused(i,0) = 0.0;
d_r_fused(i,1) = 0.0;
d_p_fused(i,0) = 0.0;
d_p_fused(i,1) = 0.0;
#endif
}
}
@ -764,6 +801,7 @@ template<class DeviceType>
int FixQEqReaxFFKokkos<DeviceType>::cg_solve1()
// b = b_s, x = s;
{
#ifndef HIP_OPT_CG_SOLVE_FUSED
const int inum = list->inum;
F_FLOAT tmp, sig_old, b_norm;
@ -917,6 +955,7 @@ template<class DeviceType>
int FixQEqReaxFFKokkos<DeviceType>::cg_solve2()
// b = b_t, x = t;
{
#ifndef HIP_OPT_CG_SOLVE_FUSED
const int inum = list->inum;
F_FLOAT tmp, sig_old, b_norm;
@ -1063,8 +1102,164 @@ int FixQEqReaxFFKokkos<DeviceType>::cg_solve2()
"{}", loop, update->ntimestep,
sqrt(sig_new)/b_norm));
return loop;
#endif
}
#ifdef HIP_OPT_CG_SOLVE_FUSED
template<class DeviceType>
int FixQEqReaxFFKokkos<DeviceType>::cg_solve_fused()
// fused CG: component 0 solves b = b_s, x = s; component 1 solves b = b_t, x = t
// returns the loop counter at exit (the caller stores it as the matvec count)
{
  // only full-neighbor-list matvec kernels exist for the fused solve,
  // so reject anything else once, before launching any kernel
  if (neighflag != FULL) {
    Kokkos::abort("Not implemented!");
  }

  // reset converged (bit 0: s system, bit 1: t system)
  converged = 0;

  const int inum = list->inum;
  F_FLOAT2 sig_old;
  F_FLOAT2 b_norm;

  int teamsize;
  int vectorsize;
  int leaguesize;
  if (execution_space == Host) {
    teamsize = 1;
    vectorsize = 1;
    leaguesize = inum;
  }
  else {
#ifdef HIP_OPT_SPMV
    teamsize = 16;
    vectorsize = 64;
    // one list entry per team thread, rounded up to whole teams
    leaguesize = (inum + teamsize - 1) / (teamsize);
#else
    teamsize = 128;
    vectorsize = 1;
    leaguesize = inum;
#endif
  }

  // sparse_matvec( &H, x, q );
  FixQEqReaxKokkosSparse12_32Functor<DeviceType> sparse_12_32_functor(this);
  Kokkos::parallel_for(inum,sparse_12_32_functor);
#ifdef HIP_OPT_SPMV
  Kokkos::parallel_for(Kokkos::TeamPolicy <DeviceType, TagSparseMatvec13Vector> (leaguesize, teamsize, vectorsize), *this);
#else
  Kokkos::parallel_for(Kokkos::TeamPolicy <DeviceType, TagSparseMatvec13> (inum, teamsize), *this);
#endif

  // vector_sum( r , 1., b, -1., q, nn );
  // preconditioning: d[j] = r[j] * Hdia_inv[j];
  // b_norm = parallel_norm( b, nn );
  F_FLOAT2 my_norm;
  FixQEqReaxKokkosNorm12Functor<DeviceType> norm12_functor(this);
  Kokkos::parallel_reduce(inum,norm12_functor,my_norm);
  F_FLOAT2 norm_sqr;
  MPI_Allreduce( &my_norm.v, &norm_sqr.v, 2, MPI_DOUBLE, MPI_SUM, world );
  b_norm.v[0] = sqrt(norm_sqr.v[0]);
  b_norm.v[1] = sqrt(norm_sqr.v[1]);

  // sig_new = parallel_dot( r, d, nn);
  F_FLOAT2 my_dot;
  FixQEqReaxKokkosDot11Functor<DeviceType> dot11_functor(this);
  Kokkos::parallel_reduce(inum,dot11_functor,my_dot);
  F_FLOAT2 dot_sqr;
  MPI_Allreduce( &my_dot.v, &dot_sqr.v, 2, MPI_DOUBLE, MPI_SUM, world );
  F_FLOAT2 sig_new;
  sig_new = dot_sqr;

  F_FLOAT residual[2] = {0, 0};
  int loop;
  for (loop = 1; (loop < imax); loop++) {
    // a system that has converged keeps its last residual and is skipped
    // by every kernel below (they all test the same bits of `converged`)
    if (!(converged & 1))
      residual[0] = sqrt(sig_new.v[0]) / b_norm.v[0];
    if (!(converged & 2))
      residual[1] = sqrt(sig_new.v[1]) / b_norm.v[1];
    converged = static_cast<int>(residual[0] <= tolerance) | (static_cast<int>(residual[1] <= tolerance) << 1);

    if (converged == 3) {
      // both cg solves have converged
      break;
    }

    // comm->forward_comm_fix(this); //Dist_vector( d );
    pack_flag = 1;
    // mark size 2 for fused
    comm->forward_comm_fix(this, 2);

    // sparse_matvec( &H, d, q );
    FixQEqReaxKokkosSparse22FusedFunctor<DeviceType> sparse22_functor(this);
    Kokkos::parallel_for(inum,sparse22_functor);
#ifdef HIP_OPT_SPMV
    Kokkos::parallel_for(Kokkos::TeamPolicy <DeviceType, TagSparseMatvec2FusedVector>(leaguesize, teamsize, vectorsize), *this);
#else
    Kokkos::parallel_for(Kokkos::TeamPolicy <DeviceType, TagSparseMatvec2Fused> (inum, teamsize), *this);
#endif

    // tmp = parallel_dot( d, q, nn);
    my_dot.init();
    dot_sqr.init();
    FixQEqReaxKokkosDot22Functor<DeviceType> dot22_functor(this);
    Kokkos::parallel_reduce(inum,dot22_functor,my_dot);
    MPI_Allreduce( &my_dot.v, &dot_sqr.v, 2, MPI_DOUBLE, MPI_SUM, world );

    if (!(converged & 1))
      alpha[0] = sig_new.v[0] / dot_sqr.v[0];
    if (!(converged & 2))
      alpha[1] = sig_new.v[1] / dot_sqr.v[1];

    sig_old = sig_new;

    // vector_add( s, alpha, d, nn );
    // vector_add( r, -alpha, q, nn );
    my_dot.init();
    dot_sqr.init();
    // the functor copies this object, so it must be built after alpha is set
    FixQEqReaxKokkosPrecon12Functor<DeviceType> precon12_functor(this);
    Kokkos::parallel_for(inum,precon12_functor);

    // preconditioning: p[j] = r[j] * Hdia_inv[j];
    // sig_new = parallel_dot( r, p, nn);
    FixQEqReaxKokkosPreconFusedFunctor<DeviceType> precon_functor(this);
    Kokkos::parallel_reduce(inum,precon_functor,my_dot);
    MPI_Allreduce( &my_dot.v, &dot_sqr.v, 2, MPI_DOUBLE, MPI_SUM, world );
    sig_new = dot_sqr;

    if (!(converged & 1))
      beta[0] = sig_new.v[0] / sig_old.v[0];
    if (!(converged & 2))
      beta[1] = sig_new.v[1] / sig_old.v[1];

    // vector_sum( d, 1., p, beta, d, nn );
    // built after beta is set, for the same copy-by-value reason as above
    FixQEqReaxKokkosVecSum2FusedFunctor<DeviceType> vecsum12_functor(this);
    Kokkos::parallel_for(inum,vecsum12_functor);
  }

  if (loop >= imax && comm->me == 0) {
    // bounded formatting: "%f" of a very large residual can be hundreds of
    // characters long, which would overrun a fixed buffer with sprintf
    char str[256];
    snprintf(str,sizeof(str),"Fix qeq/reax cg_solve_fused convergence failed after %d iterations "
            "at " BIGINT_FORMAT " step: (%f, %f)",loop,update->ntimestep,
            (sqrt(sig_new.v[0])/b_norm.v[0]), (sqrt(sig_new.v[1])/b_norm.v[1]));
    error->warning(FLERR,str);
    //error->all(FLERR,str);
  }

  return loop;
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
@ -1115,6 +1310,25 @@ void FixQEqReaxFFKokkos<DeviceType>::sparse12_item(int ii) const
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused operator
// Fused counterpart of the per-entry matvec start: for list entry ii, sets
// o(i,0) = eta(type) * s(i) and o(i,1) = eta(type) * t(i), skipping any
// component whose bit is set in `converged`.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::sparse12_32_item(int ii) const
{
  const int i = d_ilist[ii];
  const int itype = type(i);
  if (mask[i] & groupbit) {
    if (!(converged & 1))
      d_o_fused(i,0) = params(itype).eta * d_s[i];
    if (!(converged & 2))
      d_o_fused(i,1) = params(itype).eta * d_t[i];
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
template<int NEIGHFLAG>
KOKKOS_INLINE_FUNCTION
@ -1153,6 +1367,35 @@ void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec1, const membert
}
}
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused matvec accumulation over s and t, one team per list entry:
// o(i,c) += sum over the neighbor range of i of val(jj) * x_c(jlist(jj)),
// with x_0 = s and x_1 = t. Components flagged in `converged` are skipped.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec13, const membertype13 &team) const
{
  const int i = d_ilist[team.league_rank()];
  if (mask[i] & groupbit) {
    F_FLOAT2 doitmp;
    // team-wide reduction over the neighbor range [firstnbr, firstnbr + numnbrs)
    Kokkos::parallel_reduce(Kokkos::TeamThreadRange(team, d_firstnbr[i], d_firstnbr[i] + d_numnbrs[i]), [&] (const int &jj, F_FLOAT2& doi) {
      const int j = d_jlist(jj);
      if (!(converged & 1))
        doi.v[0] += d_val(jj) * d_s[j];
      if (!(converged & 2))
        doi.v[1] += d_val(jj) * d_t[j];
    }, doitmp);
    // a single team member adds the reduced pair to the output
    Kokkos::single(Kokkos::PerTeam(team), [&] () {
      if (!(converged & 1))
        d_o_fused(i,0) += doitmp.v[0];
      if (!(converged & 2))
        d_o_fused(i,1) += doitmp.v[1];
    });
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec1Vector, const membertype1vec &team) const
@ -1165,12 +1408,41 @@ void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec1Vector, const m
const int j = d_jlist(jj);
doi += d_val(jj) * d_s[j];
}, doitmp);
Kokkos::single(Kokkos::PerThread(team), [&] () {d_o[i] += doitmp; });
Kokkos::single(Kokkos::PerThread(team), [&] () {d_o[i] += doitmp;
});
}
}
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused matvec accumulation over s and t, one team *thread* per list entry
// with the neighbor range spread over that thread's vector lanes:
// o(i,c) += sum_jj val(jj) * x_c(jlist(jj)), with x_0 = s and x_1 = t.
// Components flagged in `converged` are skipped.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec13Vector, const membertype13vec &team) const
{
  // NOTE(review): the launch rounds the league size up to whole teams, so k
  // can be >= inum when inum is not a multiple of the team size. Confirm that
  // d_ilist is valid past inum or add a bound check against the entry count.
  int k = team.league_rank () * team.team_size () + team.team_rank ();
  const int i = d_ilist[k];
  if (mask[i] & groupbit) {
    F_FLOAT2 doitmp;
    Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, d_firstnbr[i], d_firstnbr[i] + d_numnbrs[i]), [&] (const int &jj, F_FLOAT2& doi) {
      const int j = d_jlist(jj);
      if (!(converged & 1))
        doi.v[0] += d_val(jj) * d_s[j];
      if (!(converged & 2))
        doi.v[1] += d_val(jj) * d_t[j];
    }, doitmp);
    // one lane per thread adds the reduced pair to the output
    Kokkos::single(Kokkos::PerThread(team), [&] () {
      if (!(converged & 1))
        d_o_fused(i,0) += doitmp.v[0];
      if (!(converged & 2))
        d_o_fused(i,1) += doitmp.v[1];
    });
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::sparse22_item(int ii) const
@ -1184,6 +1456,24 @@ void FixQEqReaxFFKokkos<DeviceType>::sparse22_item(int ii) const
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// For list entry ii, sets o(i,c) = eta(type) * d(i,c) for both fused
// components, skipping any component whose bit is set in `converged`.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::sparse22_fused_item(int ii) const
{
  const int i = d_ilist[ii];
  const int itype = type(i);
  if (mask[i] & groupbit) {
    if (!(converged & 1))
      d_o_fused(i,0) = params(itype).eta * d_d_fused(i,0);
    if (!(converged & 2))
      d_o_fused(i,1) = params(itype).eta * d_d_fused(i,1);
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
template<int NEIGHFLAG>
KOKKOS_INLINE_FUNCTION
@ -1207,7 +1497,6 @@ void FixQEqReaxFFKokkos<DeviceType>::sparse23_item(int ii) const
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec2Vector, const membertype2vec &team) const
@ -1224,6 +1513,8 @@ void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec2Vector, const m
}
}
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec2, const membertype2 &team) const
@ -1239,6 +1530,62 @@ void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec2, const membert
}
}
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused matvec accumulation over the two columns of d, one team *thread* per
// list entry with the neighbor range spread over that thread's vector lanes:
// o(i,c) += sum_jj val(jj) * d(jlist(jj),c). Components flagged in
// `converged` are skipped.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec2FusedVector, const membertype2fusedvec &team) const
{
  // NOTE(review): the launch rounds the league size up to whole teams, so k
  // can be >= inum when inum is not a multiple of the team size. Confirm that
  // d_ilist is valid past inum or add a bound check against the entry count.
  int k = team.league_rank () * team.team_size () + team.team_rank ();
  const int i = d_ilist[k];
  if (mask[i] & groupbit) {
    F_FLOAT2 doitmp;
    Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, d_firstnbr[i], d_firstnbr[i] + d_numnbrs[i]), [&] (const int &jj, F_FLOAT2& doi) {
      const int j = d_jlist(jj);
      if (!(converged & 1))
        doi.v[0] += d_val(jj) * d_d_fused(j,0);
      if (!(converged & 2))
        doi.v[1] += d_val(jj) * d_d_fused(j,1);
    }, doitmp);
    // one lane per thread adds the reduced pair to the output
    Kokkos::single(Kokkos::PerThread(team), [&] () {
      if (!(converged & 1))
        d_o_fused(i,0) += doitmp.v[0];
      if (!(converged & 2))
        d_o_fused(i,1) += doitmp.v[1];
    });
  }
}
#endif
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused matvec accumulation over the two columns of d, one team per list
// entry: o(i,c) += sum_jj val(jj) * d(jlist(jj),c). Components flagged in
// `converged` are skipped.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagSparseMatvec2Fused, const membertype2fused &team) const
{
  const int i = d_ilist[team.league_rank()];
  if (mask[i] & groupbit) {
    F_FLOAT2 doitmp;
    // team-wide reduction over the neighbor range [firstnbr, firstnbr + numnbrs)
    Kokkos::parallel_reduce(Kokkos::TeamThreadRange(team, d_firstnbr[i], d_firstnbr[i] + d_numnbrs[i]), [&] (const int &jj, F_FLOAT2& doi) {
      const int j = d_jlist(jj);
      if (!(converged & 1))
        doi.v[0] += d_val(jj) * d_d_fused(j,0);
      if (!(converged & 2))
        doi.v[1] += d_val(jj) * d_d_fused(j,1);
    }, doitmp);
    // a single team member adds the reduced pair to the output
    Kokkos::single(Kokkos::PerTeam(team), [&] () {
      if (!(converged & 1))
        d_o_fused(i,0) += doitmp.v[0];
      if (!(converged & 2))
        d_o_fused(i,1) += doitmp.v[1];
    });
  }
}
#endif
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator() (TagZeroQGhosts, const int &i) const
@ -1321,13 +1668,32 @@ template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::vecsum2_item(int ii) const
{
  // Non-fused direction update d = p + beta * d for one list entry; the body
  // is compiled out when the fused solver is selected.
#ifndef HIP_OPT_CG_SOLVE_FUSED
  const int atom = d_ilist[ii];
  if (!(mask[atom] & groupbit)) return;
  d_d[atom] = 1.0 * d_p[atom] + beta * d_d[atom];
#endif
}
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused direction update for list entry ii: d(i,c) = p(i,c) + beta[c] * d(i,c)
// for each component c whose bit is not set in `converged`.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::vecsum2_fused_item(int ii) const
{
  const int i = d_ilist[ii];
  if (mask[i] & groupbit) {
    if (!(converged & 1))
      d_d_fused(i,0) = 1.0 * d_p_fused(i,0) + beta[0] * d_d_fused(i,0);
    if (!(converged & 2))
      d_d_fused(i,1) = 1.0 * d_p_fused(i,1) + beta[1] * d_d_fused(i,1);
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
double FixQEqReaxFFKokkos<DeviceType>::norm1_item(int ii) const
@ -1360,6 +1726,31 @@ double FixQEqReaxFFKokkos<DeviceType>::norm2_item(int ii) const
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused setup step for list entry ii. Per unconverged component c
// (b_0 = b_s, b_1 = b_t):
//   r(i,c) = b_c(i) - o(i,c),  d(i,c) = r(i,c) * Hdia_inv(i),
// and b_c(i)^2 is added to out.v[c] (the reduction yields the squared norm of b).
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::norm12_item(int ii, F_FLOAT2& out) const
{
  const int i = d_ilist[ii];
  if (mask[i] & groupbit) {
    if (!(converged & 1)) {
      d_r_fused(i,0) = 1.0*d_b_s[i] + -1.0*d_o_fused(i,0);
      d_d_fused(i,0) = d_r_fused(i,0) * d_Hdia_inv[i];
      out.v[0] += d_b_s[i] * d_b_s[i];
    }
    if (!(converged & 2)) {
      d_r_fused(i,1) = 1.0*d_b_t[i] + -1.0*d_o_fused(i,1);
      d_d_fused(i,1) = d_r_fused(i,1) * d_Hdia_inv[i];
      out.v[1] += d_b_t[i] * d_b_t[i];
    }
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
double FixQEqReaxFFKokkos<DeviceType>::dot1_item(int ii) const
@ -1387,15 +1778,47 @@ double FixQEqReaxFFKokkos<DeviceType>::dot2_item(int ii) const
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused dot-product contribution for list entry ii:
// out.v[c] += r(i,c) * d(i,c) for each component c not flagged in `converged`.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::dot11_item(int ii, F_FLOAT2& out) const
{
  const int i = d_ilist[ii];
  if (mask[i] & groupbit) {
    if (!(converged & 1))
      out.v[0] += d_r_fused(i,0) * d_d_fused(i,0);
    if (!(converged & 2))
      out.v[1] += d_r_fused(i,1) * d_d_fused(i,1);
  }
}
// Fused dot-product contribution for list entry ii:
// out.v[c] += d(i,c) * o(i,c) for each component c not flagged in `converged`.
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::dot22_item(int ii, F_FLOAT2& out) const
{
  const int i = d_ilist[ii];
  if (mask[i] & groupbit) {
    if (!(converged & 1))
      out.v[0] += d_d_fused(i,0) * d_o_fused(i,0);
    if (!(converged & 2))
      out.v[1] += d_d_fused(i,1) * d_o_fused(i,1);
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::precon1_item(int ii) const
{
  // Non-fused update for one list entry: s += alpha * d and r += -alpha * o.
  // The body is compiled out when the fused solver is selected.
#ifndef HIP_OPT_CG_SOLVE_FUSED
  const int atom = d_ilist[ii];
  if (!(mask[atom] & groupbit)) return;
  d_s[atom] += alpha * d_d[atom];
  d_r[atom] += -alpha * d_o[atom];
#endif
}
/* ---------------------------------------------------------------------- */
@ -1404,15 +1827,39 @@ template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::precon2_item(int ii) const
{
  // Non-fused update for one list entry: t += alpha * d and r += -alpha * o.
  // The body is compiled out when the fused solver is selected.
#ifndef HIP_OPT_CG_SOLVE_FUSED
  const int atom = d_ilist[ii];
  if (!(mask[atom] & groupbit)) return;
  d_t[atom] += alpha * d_d[atom];
  d_r[atom] += -alpha * d_o[atom];
#endif
}
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused operator
// Fused solution/residual update for list entry ii. Per unconverged component:
//   component 0: s(i) += alpha[0] * d(i,0),  r(i,0) -= alpha[0] * o(i,0)
//   component 1: t(i) += alpha[1] * d(i,1),  r(i,1) -= alpha[1] * o(i,1)
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::precon12_item(int ii) const
{
  const int i = d_ilist[ii];
  if (mask[i] & groupbit) {
    if (!(converged & 1)) {
      d_s[i] += alpha[0] * d_d_fused(i,0);
      d_r_fused(i,0) += -alpha[0] * d_o_fused(i,0);
    }
    if (!(converged & 2)) {
      d_t[i] += alpha[1] * d_d_fused(i,1);
      d_r_fused(i,1) += -alpha[1] * d_o_fused(i,1);
    }
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
double FixQEqReaxFFKokkos<DeviceType>::precon_item(int ii) const
@ -1428,6 +1875,27 @@ double FixQEqReaxFFKokkos<DeviceType>::precon_item(int ii) const
/* ---------------------------------------------------------------------- */
#ifdef HIP_OPT_CG_SOLVE_FUSED
// Fused preconditioning step for list entry ii. Per unconverged component c:
//   p(i,c) = r(i,c) * Hdia_inv(i), and r(i,c) * p(i,c) is added to out.v[c].
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::precon_fused_item(int ii, F_FLOAT2& out) const
{
  const int i = d_ilist[ii];
  if (mask[i] & groupbit) {
    if (!(converged & 1)) {
      d_p_fused(i,0) = d_r_fused(i,0) * d_Hdia_inv[i];
      out.v[0] += d_r_fused(i,0) * d_p_fused(i,0);
    }
    if (!(converged & 2)) {
      d_p_fused(i,1) = d_r_fused(i,1) * d_Hdia_inv[i];
      out.v[1] += d_r_fused(i,1) * d_p_fused(i,1);
    }
  }
}
#endif
/* ---------------------------------------------------------------------- */
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
double FixQEqReaxFFKokkos<DeviceType>::vecacc1_item(int ii) const
@ -1484,7 +1952,16 @@ int FixQEqReaxFFKokkos<DeviceType>::pack_forward_comm_fix_kokkos(int n, DAT::tdu
iswap = iswap_in;
d_buf = k_buf.view<DeviceType>();
Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagFixQEqReaxFFPackForwardComm>(0,n),*this);
#ifdef HIP_OPT_CG_SOLVE_FUSED
if (pack_flag == 1) {
// sending 2x the data
return 2*n;
} else {
return n;
}
#else
return n;
#endif
}
template<class DeviceType>
@ -1492,8 +1969,16 @@ KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator()(TagFixQEqReaxFFPackForwardComm, const int &i) const {
int j = d_sendlist(iswap, i);
if (pack_flag == 1)
if (pack_flag == 1) {
#ifndef HIP_OPT_CG_SOLVE_FUSED
d_buf[i] = d_d[j];
#else
if (!(converged & 1))
d_buf[i*2] = d_d_fused(j,0);
if (!(converged & 2))
d_buf[i*2+1] = d_d_fused(j,1);
#endif
}
else if (pack_flag == 2)
d_buf[i] = d_s[j];
else if (pack_flag == 3)
@ -1518,9 +2003,17 @@ void FixQEqReaxFFKokkos<DeviceType>::unpack_forward_comm_fix_kokkos(int n, int f
template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixQEqReaxFFKokkos<DeviceType>::operator()(TagFixQEqReaxFFUnpackForwardComm, const int &i) const {
if (pack_flag == 1)
if (pack_flag == 1){
#ifndef HIP_OPT_CG_SOLVE_FUSED
d_d[i + first] = d_buf[i];
else if (pack_flag == 2)
#else
if (!(converged & 1))
d_d_fused(i+first,0) = d_buf[i*2];
if (!(converged & 2))
d_d_fused(i+first,1) = d_buf[i*2+1];
#endif
}
else if ( pack_flag == 2)
d_s[i + first] = d_buf[i];
else if (pack_flag == 3)
d_t[i + first] = d_buf[i];

View File

@ -38,8 +38,13 @@ struct TagSparseMatvec1 {};
struct TagSparseMatvec1Vector {};
struct TagSparseMatvec2 {};
struct TagSparseMatvec2Vector {};
struct TagSparseMatvec2Fused {};
struct TagSparseMatvec2FusedVector {};
struct TagSparseMatvec3 {};
struct TagSparseMatvec3Vector {};
// fused operators
struct TagSparseMatvec13 {};
struct TagSparseMatvec13Vector {};
struct TagZeroQGhosts{};
struct TagFixQEqReaxFFPackForwardComm {};
struct TagFixQEqReaxFFUnpackForwardComm {};
@ -84,6 +89,11 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
KOKKOS_INLINE_FUNCTION
void sparse22_item(int) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
KOKKOS_INLINE_FUNCTION
void sparse22_fused_item(int) const;
#endif
template<int NEIGHFLAG>
KOKKOS_INLINE_FUNCTION
void sparse23_item(int) const;
@ -91,6 +101,11 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
KOKKOS_INLINE_FUNCTION
void sparse32_item(int) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
KOKKOS_INLINE_FUNCTION
void sparse12_32_item(int) const;
#endif
template<int NEIGHFLAG>
KOKKOS_INLINE_FUNCTION
void sparse33_item(int) const;
@ -111,6 +126,16 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
KOKKOS_INLINE_FUNCTION
void operator() (TagSparseMatvec2Vector, const membertype2vec &team) const;
typedef typename Kokkos::TeamPolicy <DeviceType, TagSparseMatvec2Fused> ::member_type membertype2fused;
KOKKOS_INLINE_FUNCTION
void operator() (TagSparseMatvec2Fused, const membertype2fused &team) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
typedef typename Kokkos::TeamPolicy <DeviceType, TagSparseMatvec2FusedVector> ::member_type membertype2fusedvec;
KOKKOS_INLINE_FUNCTION
void operator() (TagSparseMatvec2FusedVector, const membertype2fusedvec &team) const;
#endif
typedef typename Kokkos::TeamPolicy <DeviceType, TagSparseMatvec3> ::member_type membertype3;
KOKKOS_INLINE_FUNCTION
void operator() (TagSparseMatvec3, const membertype3 &team) const;
@ -119,33 +144,74 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
KOKKOS_INLINE_FUNCTION
void operator() (TagSparseMatvec3Vector, const membertype3vec &team) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
typedef typename Kokkos::TeamPolicy <DeviceType, TagSparseMatvec13> ::member_type membertype13;
KOKKOS_INLINE_FUNCTION
void operator() (TagSparseMatvec13, const membertype13 &team) const;
typedef typename Kokkos::TeamPolicy <DeviceType, TagSparseMatvec13Vector> ::member_type membertype13vec;
KOKKOS_INLINE_FUNCTION
void operator() (TagSparseMatvec13Vector, const membertype13vec &team) const;
#endif
KOKKOS_INLINE_FUNCTION
void operator()(TagZeroQGhosts, const int&) const;
KOKKOS_INLINE_FUNCTION
void vecsum2_item(int) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
KOKKOS_INLINE_FUNCTION
void vecsum2_fused_item(int) const;
#endif
KOKKOS_INLINE_FUNCTION
double norm1_item(int) const;
KOKKOS_INLINE_FUNCTION
double norm2_item(int) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused operator
KOKKOS_INLINE_FUNCTION
void norm12_item(int, F_FLOAT2&) const;
#endif
KOKKOS_INLINE_FUNCTION
double dot1_item(int) const;
KOKKOS_INLINE_FUNCTION
double dot2_item(int) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused operators
KOKKOS_INLINE_FUNCTION
void dot11_item(int, F_FLOAT2&) const;
KOKKOS_INLINE_FUNCTION
void dot22_item(int, F_FLOAT2&) const;
#endif
KOKKOS_INLINE_FUNCTION
void precon1_item(int) const;
KOKKOS_INLINE_FUNCTION
void precon2_item(int) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused operator
KOKKOS_INLINE_FUNCTION
void precon12_item(int) const;
#endif
KOKKOS_INLINE_FUNCTION
double precon_item(int) const;
#ifdef HIP_OPT_CG_SOLVE_FUSED
KOKKOS_INLINE_FUNCTION
void precon_fused_item(int, F_FLOAT2&) const;
#endif
KOKKOS_INLINE_FUNCTION
double vecacc1_item(int) const;
@ -224,6 +290,15 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
DAT::tdual_ffloat_1d k_o, k_d;
typename AT::t_ffloat_1d d_p, d_o, d_r, d_d;
HAT::t_ffloat_1d h_o, h_d;
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused arrays
DAT::tdual_ffloat2_1d k_o_fused, k_d_fused;
typename AT::t_ffloat2_1d d_p_fused, d_o_fused, d_r_fused, d_d_fused;
HAT::t_ffloat2_1d h_o_fused, h_d_fused;
int converged = 0;
#endif
typename AT::t_ffloat_1d_randomread r_p, r_o, r_r, r_d;
DAT::tdual_ffloat_2d k_shield, k_s_hist, k_t_hist;
@ -243,14 +318,23 @@ class FixQEqReaxFFKokkos : public FixQEqReaxFF, public KokkosBase {
void init_hist();
void allocate_matrix();
void allocate_array();
int cg_solve1();
int cg_solve2();
void cg_solve1();
void cg_solve2();
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused solve of both systems; returns the loop counter at exit, which
// pre_force() stores as the matvec count
int cg_solve_fused();
#endif
void calculate_q();
int neighflag, pack_flag;
int nlocal,nall,nmax,newton_pair;
int count, isuccess;
double alpha, beta, delta, cutsq;
#ifdef HIP_OPT_CG_SOLVE_FUSED
F_FLOAT alpha[2];
F_FLOAT beta[2];
#else
double alpha, beta;
#endif
double delta, cutsq;
void grow_arrays(int);
void copy_arrays(int, int, int);
@ -385,6 +469,21 @@ struct FixQEqReaxFFKokkosSparse22Functor {
}
};
#ifdef HIP_OPT_CG_SOLVE_FUSED
template <class DeviceType>
struct FixQEqReaxKokkosSparse22FusedFunctor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
FixQEqReaxKokkosSparse22FusedFunctor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii) const {
c.sparse22_fused_item(ii);
}
};
#endif
template <class DeviceType,int NEIGHFLAG>
struct FixQEqReaxFFKokkosSparse23Functor {
typedef DeviceType device_type ;
@ -411,6 +510,21 @@ struct FixQEqReaxFFKokkosSparse32Functor {
}
};
#ifdef HIP_OPT_CG_SOLVE_FUSED
template <class DeviceType>
struct FixQEqReaxKokkosSparse12_32Functor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
FixQEqReaxKokkosSparse12_32Functor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii) const {
c.sparse12_32_item(ii);
}
};
#endif
template <class DeviceType,int NEIGHFLAG>
struct FixQEqReaxFFKokkosSparse33Functor {
typedef DeviceType device_type ;
@ -437,6 +551,21 @@ struct FixQEqReaxFFKokkosVecSum2Functor {
}
};
#ifdef HIP_OPT_CG_SOLVE_FUSED
template <class DeviceType>
struct FixQEqReaxKokkosVecSum2FusedFunctor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
FixQEqReaxKokkosVecSum2FusedFunctor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii) const {
c.vecsum2_fused_item(ii);
}
};
#endif
template <class DeviceType>
struct FixQEqReaxFFKokkosNorm1Functor {
typedef DeviceType device_type ;
@ -465,6 +594,23 @@ struct FixQEqReaxFFKokkosNorm2Functor {
}
};
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused operator
template <class DeviceType>
struct FixQEqReaxKokkosNorm12Functor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
typedef F_FLOAT2 value_type;
FixQEqReaxKokkosNorm12Functor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii, value_type &tmp) const {
c.norm12_item(ii, tmp);
}
};
#endif
template <class DeviceType>
struct FixQEqReaxFFKokkosDot1Functor {
typedef DeviceType device_type ;
@ -493,6 +639,37 @@ struct FixQEqReaxFFKokkosDot2Functor {
}
};
#ifdef HIP_OPT_CG_SOLVE_FUSED
// fused operators
template <class DeviceType>
struct FixQEqReaxKokkosDot11Functor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
typedef F_FLOAT2 value_type;
FixQEqReaxKokkosDot11Functor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii, value_type &tmp) const {
c.dot11_item(ii, tmp);
}
};
template <class DeviceType>
struct FixQEqReaxKokkosDot22Functor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
typedef F_FLOAT2 value_type;
FixQEqReaxKokkosDot22Functor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii, value_type &tmp) const {
c.dot22_item(ii, tmp);
}
};
#endif
template <class DeviceType>
struct FixQEqReaxFFKokkosPrecon1Functor {
typedef DeviceType device_type ;
@ -519,6 +696,21 @@ struct FixQEqReaxFFKokkosPrecon2Functor {
}
};
#ifdef HIP_OPT_CG_SOLVE_FUSED
template <class DeviceType>
struct FixQEqReaxKokkosPrecon12Functor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
FixQEqReaxKokkosPrecon12Functor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii) const {
c.precon12_item(ii);
}
};
#endif
template <class DeviceType>
struct FixQEqReaxFFKokkosPreconFunctor {
typedef DeviceType device_type ;
@ -533,6 +725,22 @@ struct FixQEqReaxFFKokkosPreconFunctor {
}
};
#ifdef HIP_OPT_CG_SOLVE_FUSED
template <class DeviceType>
struct FixQEqReaxKokkosPreconFusedFunctor {
typedef DeviceType device_type ;
FixQEqReaxKokkos<DeviceType> c;
typedef F_FLOAT2 value_type;
FixQEqReaxKokkosPreconFusedFunctor(FixQEqReaxKokkos<DeviceType>* c_ptr):c(*c_ptr) {
c.cleanup_copy();
};
KOKKOS_INLINE_FUNCTION
void operator()(const int ii, value_type &tmp) const {
c.precon_fused_item(ii, tmp);
}
};
#endif
template <class DeviceType>
struct FixQEqReaxFFKokkosVecAcc1Functor {
typedef DeviceType device_type ;
@ -573,8 +781,18 @@ struct FixQEqReaxFFKokkosCalculateQFunctor {
c.calculate_q_item(ii);
}
};
}
#ifdef HIP_OPT_CG_SOLVE_FUSED
namespace Kokkos { // the reduction identity has to be specialized inside the Kokkos namespace
  // Identity element for sum-reductions over F_FLOAT2.
  template<>
  struct reduction_identity< F_FLOAT2 > {
    KOKKOS_FORCEINLINE_FUNCTION static F_FLOAT2 sum() {
      F_FLOAT2 zero;  // default construction runs init(), which zeroes both components
      return zero;
    }
  };
}
#endif
#endif
#endif

View File

@ -501,6 +501,41 @@ struct s_FEV_FLOAT {
};
typedef struct s_FEV_FLOAT<6,3> FEV_FLOAT;
// Pair of F_FLOAT values with component-wise accumulation, used as the value
// type of the fused two-system reductions.
struct s_FLOAT2 {
  F_FLOAT v[2];

  // Starts out as (0, 0).
  KOKKOS_INLINE_FUNCTION
  s_FLOAT2() {
    init();
  }

  // Copies both components.
  KOKKOS_INLINE_FUNCTION
  s_FLOAT2(const s_FLOAT2 &other) {
    v[0] = other.v[0];
    v[1] = other.v[1];
  }

  // Component-wise accumulate.
  KOKKOS_INLINE_FUNCTION
  void operator+=(const s_FLOAT2 &other) {
    for (int k = 0; k < 2; k++) {
      v[k] += other.v[k];
    }
  }

  // Same accumulation for volatile-qualified operands.
  KOKKOS_INLINE_FUNCTION
  void operator+=(const volatile s_FLOAT2 &other) volatile {
    for (int k = 0; k < 2; k++) {
      v[k] += other.v[k];
    }
  }

  // Resets both components to zero.
  KOKKOS_INLINE_FUNCTION
  void init() {
    v[1] = 0;
    v[0] = 0;
  }
};
typedef struct s_FLOAT2 F_FLOAT2;
#ifndef PREC_POS
#define PREC_POS PRECISION
#endif
@ -732,6 +767,14 @@ typedef tdual_ffloat_1d::t_dev_um t_ffloat_1d_um;
typedef tdual_ffloat_1d::t_dev_const_um t_ffloat_1d_const_um;
typedef tdual_ffloat_1d::t_dev_const_randomread t_ffloat_1d_randomread;
// 1d array of F_FLOAT pairs: n rows with a compile-time second extent of 2
// (F_FLOAT*[2]), LayoutRight; the aliases below are the device-side (t_dev)
// view types of this dual view
typedef Kokkos::DualView<F_FLOAT*[2], Kokkos::LayoutRight, LMPDeviceType> tdual_ffloat2_1d;
typedef tdual_ffloat2_1d::t_dev t_ffloat2_1d;
typedef tdual_ffloat2_1d::t_dev_const t_ffloat2_1d_const;
typedef tdual_ffloat2_1d::t_dev_um t_ffloat2_1d_um;
typedef tdual_ffloat2_1d::t_dev_const_um t_ffloat2_1d_const_um;
typedef tdual_ffloat2_1d::t_dev_const_randomread t_ffloat2_1d_randomread;
//2d F_FLOAT array n*m
typedef Kokkos::DualView<F_FLOAT**, Kokkos::LayoutRight, LMPDeviceType> tdual_ffloat_2d;
@ -1002,6 +1045,14 @@ typedef tdual_ffloat_1d::t_host_um t_ffloat_1d_um;
typedef tdual_ffloat_1d::t_host_const_um t_ffloat_1d_const_um;
typedef tdual_ffloat_1d::t_host_const_randomread t_ffloat_1d_randomread;
// 1d array of F_FLOAT pairs: n rows with a compile-time second extent of 2
// (F_FLOAT*[2]), LayoutRight; the aliases below are the host-side (t_host)
// view types of this dual view
typedef Kokkos::DualView<F_FLOAT*[2], Kokkos::LayoutRight, LMPDeviceType> tdual_ffloat2_1d;
typedef tdual_ffloat2_1d::t_host t_ffloat2_1d;
typedef tdual_ffloat2_1d::t_host_const t_ffloat2_1d_const;
typedef tdual_ffloat2_1d::t_host_um t_ffloat2_1d_um;
typedef tdual_ffloat2_1d::t_host_const_um t_ffloat2_1d_const_um;
typedef tdual_ffloat2_1d::t_host_const_randomread t_ffloat2_1d_randomread;
//2d F_FLOAT array n*m
typedef Kokkos::DualView<F_FLOAT**, Kokkos::LayoutRight, LMPDeviceType> tdual_ffloat_2d;
typedef tdual_ffloat_2d::t_host t_ffloat_2d;