diff --git a/src/KOKKOS/fix_rx_kokkos.cpp b/src/KOKKOS/fix_rx_kokkos.cpp index a6da0306bb..09a122a108 100644 --- a/src/KOKKOS/fix_rx_kokkos.cpp +++ b/src/KOKKOS/fix_rx_kokkos.cpp @@ -203,14 +203,14 @@ void FixRxKokkos::rk4(const double t_stop, double *y, double *rwork, /* ---------------------------------------------------------------------- */ template - template -void FixRxKokkos::k_rk4(const double t_stop, double *y, double *rwork, UserDataType& userData) const + template +void FixRxKokkos::k_rk4(const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData) const { - double *k1 = rwork; - double *k2 = k1 + nspecies; - double *k3 = k2 + nspecies; - double *k4 = k3 + nspecies; - double *yp = k4 + nspecies; + VectorType k1( rwork ); + VectorType k2( &k1[nspecies] ); + VectorType k3( &k2[nspecies] ); + VectorType k4( &k3[nspecies] ); + VectorType yp( &k4[nspecies] ); const int numSteps = minSteps; @@ -262,8 +262,8 @@ void FixRxKokkos::k_rk4(const double t_stop, double *y, double *rwor // x = x + a1*f1 + a3*f3 + a4*f4 + a5*f5 template - template -void FixRxKokkos::k_rkf45_step (const int neq, const double h, double y[], double y_out[], double rwk[], UserDataType& userData) const + template +void FixRxKokkos::k_rkf45_step (const int neq, const double h, VectorType& y, VectorType& y_out, VectorType& rwk, UserDataType& userData) const { const double c21=0.25; const double c31=0.09375; @@ -291,16 +291,15 @@ void FixRxKokkos::k_rkf45_step (const int neq, const double h, doubl const double b6=0.036363636363636; // local dependent variables (5 total) - double* f1 = &rwk[ 0]; - double* f2 = &rwk[ neq]; - double* f3 = &rwk[2*neq]; - double* f4 = &rwk[3*neq]; - double* f5 = &rwk[4*neq]; - double* f6 = &rwk[5*neq]; + VectorType& f1 = rwk; + VectorType f2( &rwk[ neq] ); + VectorType f3( &rwk[2*neq] ); + VectorType f4( &rwk[3*neq] ); + VectorType f5( &rwk[4*neq] ); + VectorType f6( &rwk[5*neq] ); // scratch for the intermediate solution. - //double* ytmp = &rwk[6*neq]; - double* ytmp = y_out; + VectorType& ytmp = y_out; // 1) k_rhs (0.0, y, f1, userData); @@ -368,11 +367,11 @@ void FixRxKokkos::k_rkf45_step (const int neq, const double h, doubl } template - template + template int FixRxKokkos::k_rkf45_h0 (const int neq, const double t, const double t_stop, const double hmin, const double hmax, - double& h0, double y[], double rwk[], UserDataType& userData) const + double& h0, VectorType& y, VectorType& rwk, UserDataType& userData) const { // Set lower and upper bounds on h0, and take geometric mean as first trial value. // Exit with this value if the bounds cross each other. @@ -388,9 +387,9 @@ int FixRxKokkos::k_rkf45_h0 // Start iteration to find solution to ... {WRMS norm of (h0^2 y'' / 2)} = 1 - double *ydot = rwk; - double *y1 = ydot + neq; - double *ydot1 = y1 + neq; + VectorType& ydot = rwk; + VectorType y1 ( &ydot[ neq] ); + VectorType ydot1 ( &ydot[2*neq] ); const int max_iters = 10; bool hnew_is_ok = false; @@ -463,8 +462,8 @@ int FixRxKokkos::k_rkf45_h0 } template - template -void FixRxKokkos::k_rkf45(const int neq, const double t_stop, double *y, double *rwork, UserDataType& userData, CounterType& counter) const + template +void FixRxKokkos::k_rkf45(const int neq, const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData, CounterType& counter) const { // Rounding coefficient. const double uround = DBL_EPSILON; @@ -501,9 +500,10 @@ void FixRxKokkos::k_rkf45(const int neq, const double t_stop, double //printf("t= %e t_stop= %e h= %e\n", t, t_stop, h); // Integrate until we reach the end time. - while (fabs(t - t_stop) > tround){ - double *yout = rwork; - double *eout = yout + neq; + while (fabs(t - t_stop) > tround) + { + VectorType& yout = rwork; + VectorType eout ( &yout[neq] ); // Take a trial step. k_rkf45_step (neq, h, y, yout, eout, userData); @@ -1035,8 +1035,6 @@ template template int FixRxKokkos::k_rhs(double t, const VectorType& y, VectorType& dydt, UserDataType& userData) const { - //StridedArrayType _y( const_cast( y ) ), _dydt( dydt ); - // Use the sparse format instead. if (useSparseKinetics) return this->k_rhs_sparse( t, y, dydt, userData); @@ -1409,20 +1407,36 @@ void FixRxKokkos::solve_reactions(const int vflag, const bool isPreF ); } + // Create scratch array space. + const size_t scratchSpaceSize = (8*nspecies + 2*nreactions); + //double *scratchSpace = new double[ scratchSpaceSize * nlocal ]; + + typename ArrayTypes::t_double_1d d_scratchSpace("d_scratchSpace", scratchSpaceSize * nlocal); + Kokkos::parallel_reduce( nlocal, LAMMPS_LAMBDA(int i, CounterType &counter) { if (d_mask(i) & groupbit) { - double *y = new double[8*nspecies]; - double *rwork = y + nspecies; + //double *y = new double[8*nspecies]; + //double *rwork = y + nspecies; - UserRHSData userData; - userData.kFor = new double[nreactions]; - userData.rxnRateLaw = new double[nreactions]; + //StridedArrayType _y( y ); + //StridedArrayType _rwork( rwork ); - UserRHSDataKokkos<1> userDataKokkos; - userDataKokkos.kFor.m_data = userData.kFor; - userDataKokkos.rxnRateLaw.m_data = userData.rxnRateLaw; + StridedArrayType y( d_scratchSpace.ptr_on_device() + scratchSpaceSize * i ); + StridedArrayType rwork( &y[nspecies] ); + + //UserRHSData userData; + //userData.kFor = new double[nreactions]; + //userData.rxnRateLaw = new double[nreactions]; + + //UserRHSDataKokkos<1> userDataKokkos; + //userDataKokkos.kFor.m_data = userData.kFor; + //userDataKokkos.rxnRateLaw.m_data = userData.rxnRateLaw; + + UserRHSDataKokkos<1> userData; + userData.kFor.m_data = &( rwork[7*nspecies] ); + userData.rxnRateLaw.m_data = &( userData.kFor[ nreactions ] ); CounterType counter_i; @@ -1452,12 +1466,11 @@ void FixRxKokkos::solve_reactions(const int vflag, const bool isPreF // Solver the ODE system. if (odeIntegrationFlag == ODE_LAMMPS_RK4) { - //rk4(t_stop, y, rwork, &userData); - k_rk4(t_stop, y, rwork, userDataKokkos); + k_rk4(t_stop, y, rwork, userData); } else if (odeIntegrationFlag == ODE_LAMMPS_RKF45) { - rkf45(nspecies, t_stop, y, rwork, &userData, counter_i); + k_rkf45(nspecies, t_stop, y, rwork, userData, counter_i); if (diagnosticFrequency == 1) { @@ -1477,9 +1490,9 @@ void FixRxKokkos::solve_reactions(const int vflag, const bool isPreF d_dvector(ispecies,i) = y[ispecies]; } - delete [] y; - delete [] userData.kFor; - delete [] userData.rxnRateLaw; + //delete [] y; + //delete [] userData.kFor; + //delete [] userData.rxnRateLaw; // Update the iteration statistics counter. Is this unique for each iteration? counter += counter_i; @@ -1490,6 +1503,8 @@ void FixRxKokkos::solve_reactions(const int vflag, const bool isPreF , TotalCounters // reduction value for all iterations. ); + //delete [] scratchSpace; + TimerType timer_ODE = getTimeStamp(); // Signal that dvector has been modified on this execution space. diff --git a/src/KOKKOS/fix_rx_kokkos.h b/src/KOKKOS/fix_rx_kokkos.h index e36d606525..9ac944c6a5 100644 --- a/src/KOKKOS/fix_rx_kokkos.h +++ b/src/KOKKOS/fix_rx_kokkos.h @@ -129,23 +129,23 @@ class FixRxKokkos : public FixRX { double& h0, double y[], double rwk[], void *v_params) const; //!< Classic Runge-Kutta 4th-order stepper. - template - void k_rk4(const double t_stop, double *y, double *rwork, UserDataType& userData) const; + template + void k_rk4(const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData) const; //!< Runge-Kutta-Fehlberg ODE Solver. - template - void k_rkf45(const int neq, const double t_stop, double *y, double *rwork, UserDataType& userData, CounterType& counter) const; + template + void k_rkf45(const int neq, const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData, CounterType& counter) const; //!< Runge-Kutta-Fehlberg ODE stepper function. - template - void k_rkf45_step (const int neq, const double h, double y[], double y_out[], - double rwk[], UserDataType& userData) const; + template + void k_rkf45_step (const int neq, const double h, VectorType& y, VectorType& y_out, + VectorType& rwk, UserDataType& userData) const; //!< Initial step size estimation for the Runge-Kutta-Fehlberg ODE solver. - template + template int k_rkf45_h0 (const int neq, const double t, const double t_stop, const double hmin, const double hmax, - double& h0, double y[], double rwk[], UserDataType& userData) const; + double& h0, VectorType& y, VectorType& rwk, UserDataType& userData) const; //!< ODE Solver diagnostics. void odeDiagnostics(void);