From 00cbb633bf4e8a9f9066e41e0e27da7e455a80e6 Mon Sep 17 00:00:00 2001 From: Nick Curtis Date: Wed, 15 Sep 2021 14:39:34 -0400 Subject: [PATCH] Implement host MPI for fused QEQ Change-Id: I3278a72878fb7cdb64a059aaf025c039dc0d71e5 --- src/KOKKOS/fix_qeq_reaxff_kokkos.cpp | 36 +++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/src/KOKKOS/fix_qeq_reaxff_kokkos.cpp b/src/KOKKOS/fix_qeq_reaxff_kokkos.cpp index 3ea8c6725f..2a0b5389c1 100644 --- a/src/KOKKOS/fix_qeq_reaxff_kokkos.cpp +++ b/src/KOKKOS/fix_qeq_reaxff_kokkos.cpp @@ -2031,8 +2031,18 @@ int FixQEqReaxFFKokkos::pack_forward_comm(int n, int *list, double * int m; if (pack_flag == 1) { + #ifndef HIP_OPT_CG_SOLVE_FUSED k_d.sync_host(); for (m = 0; m < n; m++) buf[m] = h_d[list[m]]; + #else + k_d_fused.sync_host(); + for (m = 0; m < n; m++) { + if (!(converged & 1)) + buf[m*2] = h_d_fused(list[m],0); + if (!(converged & 2)) + buf[m*2+1] = h_d_fused(list[m],1); + } + #endif } else if (pack_flag == 2) { k_s.sync_host(); for (m = 0; m < n; m++) buf[m] = h_s[list[m]]; @@ -2044,7 +2054,16 @@ int FixQEqReaxFFKokkos::pack_forward_comm(int n, int *list, double * for (m = 0; m < n; m++) buf[m] = atom->q[list[m]]; } + #ifdef HIP_OPT_CG_SOLVE_FUSED + if (pack_flag == 1) { + // sending 2x the data + return 2*n; + } else { + return n; + } + #else return n; + #endif } /* ---------------------------------------------------------------------- */ @@ -2055,9 +2074,20 @@ void FixQEqReaxFFKokkos::unpack_forward_comm(int n, int first, doubl int i, m; if (pack_flag == 1) { - k_d.sync_host(); - for (m = 0, i = first; m < n; m++, i++) h_d[i] = buf[m]; - k_d.modify_host(); + #ifndef HIP_OPT_CG_SOLVE_FUSED + k_d.sync_host(); + for (m = 0, i = first; m < n; m++, i++) h_d[i] = buf[m]; + k_d.modify_host(); + #else + k_d_fused.sync_host(); + for (m = 0, i = first; m < n; m++, i++) { + if (!(converged & 1)) + h_d_fused(i,0) = buf[m*2]; + if (!(converged & 2)) + h_d_fused(i,1) = buf[m*2+1]; + } + k_d_fused.modify_host(); + #endif } else if (pack_flag == 2) { k_s.sync_host(); for (m = 0, i = first; m < n; m++, i++) h_s[i] = buf[m];