Switched to using Kokkos device data for ODE scratch data.

- Finished porting all scratch arrays to use the StridedArrayType
  template.
- Created a single, large Kokkos device array and used it for all
  scratch data passed into the StridedArrayType objects.
This commit is contained in:
Christopher Stone
2017-02-12 22:48:02 -05:00
parent 4ac7a5d1f2
commit 2f32c1a9af
2 changed files with 67 additions and 52 deletions

View File

@ -203,14 +203,14 @@ void FixRxKokkos<DeviceType>::rk4(const double t_stop, double *y, double *rwork,
/* ---------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- */
template <typename DeviceType> template <typename DeviceType>
template <typename UserDataType> template <typename VectorType, typename UserDataType>
void FixRxKokkos<DeviceType>::k_rk4(const double t_stop, double *y, double *rwork, UserDataType& userData) const void FixRxKokkos<DeviceType>::k_rk4(const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData) const
{ {
double *k1 = rwork; VectorType k1( rwork );
double *k2 = k1 + nspecies; VectorType k2( &k1[nspecies] );
double *k3 = k2 + nspecies; VectorType k3( &k2[nspecies] );
double *k4 = k3 + nspecies; VectorType k4( &k3[nspecies] );
double *yp = k4 + nspecies; VectorType yp( &k4[nspecies] );
const int numSteps = minSteps; const int numSteps = minSteps;
@ -262,8 +262,8 @@ void FixRxKokkos<DeviceType>::k_rk4(const double t_stop, double *y, double *rwor
// x = x + a1*f1 + a3*f3 + a4*f4 + a5*f5 // x = x + a1*f1 + a3*f3 + a4*f4 + a5*f5
template <typename DeviceType> template <typename DeviceType>
template <typename UserDataType> template <typename VectorType, typename UserDataType>
void FixRxKokkos<DeviceType>::k_rkf45_step (const int neq, const double h, double y[], double y_out[], double rwk[], UserDataType& userData) const void FixRxKokkos<DeviceType>::k_rkf45_step (const int neq, const double h, VectorType& y, VectorType& y_out, VectorType& rwk, UserDataType& userData) const
{ {
const double c21=0.25; const double c21=0.25;
const double c31=0.09375; const double c31=0.09375;
@ -291,16 +291,15 @@ void FixRxKokkos<DeviceType>::k_rkf45_step (const int neq, const double h, doubl
const double b6=0.036363636363636; const double b6=0.036363636363636;
// local dependent variables (5 total) // local dependent variables (5 total)
double* f1 = &rwk[ 0]; VectorType& f1 = rwk;
double* f2 = &rwk[ neq]; VectorType f2( &rwk[ neq] );
double* f3 = &rwk[2*neq]; VectorType f3( &rwk[2*neq] );
double* f4 = &rwk[3*neq]; VectorType f4( &rwk[3*neq] );
double* f5 = &rwk[4*neq]; VectorType f5( &rwk[4*neq] );
double* f6 = &rwk[5*neq]; VectorType f6( &rwk[5*neq] );
// scratch for the intermediate solution. // scratch for the intermediate solution.
//double* ytmp = &rwk[6*neq]; VectorType& ytmp = y_out;
double* ytmp = y_out;
// 1) // 1)
k_rhs (0.0, y, f1, userData); k_rhs (0.0, y, f1, userData);
@ -368,11 +367,11 @@ void FixRxKokkos<DeviceType>::k_rkf45_step (const int neq, const double h, doubl
} }
template <typename DeviceType> template <typename DeviceType>
template <typename UserDataType> template <typename VectorType, typename UserDataType>
int FixRxKokkos<DeviceType>::k_rkf45_h0 int FixRxKokkos<DeviceType>::k_rkf45_h0
(const int neq, const double t, const double t_stop, (const int neq, const double t, const double t_stop,
const double hmin, const double hmax, const double hmin, const double hmax,
double& h0, double y[], double rwk[], UserDataType& userData) const double& h0, VectorType& y, VectorType& rwk, UserDataType& userData) const
{ {
// Set lower and upper bounds on h0, and take geometric mean as first trial value. // Set lower and upper bounds on h0, and take geometric mean as first trial value.
// Exit with this value if the bounds cross each other. // Exit with this value if the bounds cross each other.
@ -388,9 +387,9 @@ int FixRxKokkos<DeviceType>::k_rkf45_h0
// Start iteration to find solution to ... {WRMS norm of (h0^2 y'' / 2)} = 1 // Start iteration to find solution to ... {WRMS norm of (h0^2 y'' / 2)} = 1
double *ydot = rwk; VectorType& ydot = rwk;
double *y1 = ydot + neq; VectorType y1 ( &ydot[ neq] );
double *ydot1 = y1 + neq; VectorType ydot1 ( &ydot[2*neq] );
const int max_iters = 10; const int max_iters = 10;
bool hnew_is_ok = false; bool hnew_is_ok = false;
@ -463,8 +462,8 @@ int FixRxKokkos<DeviceType>::k_rkf45_h0
} }
template <typename DeviceType> template <typename DeviceType>
template <typename UserDataType> template <typename VectorType, typename UserDataType>
void FixRxKokkos<DeviceType>::k_rkf45(const int neq, const double t_stop, double *y, double *rwork, UserDataType& userData, CounterType& counter) const void FixRxKokkos<DeviceType>::k_rkf45(const int neq, const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData, CounterType& counter) const
{ {
// Rounding coefficient. // Rounding coefficient.
const double uround = DBL_EPSILON; const double uround = DBL_EPSILON;
@ -501,9 +500,10 @@ void FixRxKokkos<DeviceType>::k_rkf45(const int neq, const double t_stop, double
//printf("t= %e t_stop= %e h= %e\n", t, t_stop, h); //printf("t= %e t_stop= %e h= %e\n", t, t_stop, h);
// Integrate until we reach the end time. // Integrate until we reach the end time.
while (fabs(t - t_stop) > tround){ while (fabs(t - t_stop) > tround)
double *yout = rwork; {
double *eout = yout + neq; VectorType& yout = rwork;
VectorType eout ( &yout[neq] );
// Take a trial step. // Take a trial step.
k_rkf45_step (neq, h, y, yout, eout, userData); k_rkf45_step (neq, h, y, yout, eout, userData);
@ -1035,8 +1035,6 @@ template <typename DeviceType>
template <typename VectorType, typename UserDataType> template <typename VectorType, typename UserDataType>
int FixRxKokkos<DeviceType>::k_rhs(double t, const VectorType& y, VectorType& dydt, UserDataType& userData) const int FixRxKokkos<DeviceType>::k_rhs(double t, const VectorType& y, VectorType& dydt, UserDataType& userData) const
{ {
//StridedArrayType<double,1> _y( const_cast<double *>( y ) ), _dydt( dydt );
// Use the sparse format instead. // Use the sparse format instead.
if (useSparseKinetics) if (useSparseKinetics)
return this->k_rhs_sparse( t, y, dydt, userData); return this->k_rhs_sparse( t, y, dydt, userData);
@ -1409,20 +1407,36 @@ void FixRxKokkos<DeviceType>::solve_reactions(const int vflag, const bool isPreF
); );
} }
// Create scratch array space.
const size_t scratchSpaceSize = (8*nspecies + 2*nreactions);
//double *scratchSpace = new double[ scratchSpaceSize * nlocal ];
typename ArrayTypes<DeviceType>::t_double_1d d_scratchSpace("d_scratchSpace", scratchSpaceSize * nlocal);
Kokkos::parallel_reduce( nlocal, LAMMPS_LAMBDA(int i, CounterType &counter) Kokkos::parallel_reduce( nlocal, LAMMPS_LAMBDA(int i, CounterType &counter)
{ {
if (d_mask(i) & groupbit) if (d_mask(i) & groupbit)
{ {
double *y = new double[8*nspecies]; //double *y = new double[8*nspecies];
double *rwork = y + nspecies; //double *rwork = y + nspecies;
UserRHSData userData; //StridedArrayType<double,1> _y( y );
userData.kFor = new double[nreactions]; //StridedArrayType<double,1> _rwork( rwork );
userData.rxnRateLaw = new double[nreactions];
UserRHSDataKokkos<1> userDataKokkos; StridedArrayType<double,1> y( d_scratchSpace.ptr_on_device() + scratchSpaceSize * i );
userDataKokkos.kFor.m_data = userData.kFor; StridedArrayType<double,1> rwork( &y[nspecies] );
userDataKokkos.rxnRateLaw.m_data = userData.rxnRateLaw;
//UserRHSData userData;
//userData.kFor = new double[nreactions];
//userData.rxnRateLaw = new double[nreactions];
//UserRHSDataKokkos<1> userDataKokkos;
//userDataKokkos.kFor.m_data = userData.kFor;
//userDataKokkos.rxnRateLaw.m_data = userData.rxnRateLaw;
UserRHSDataKokkos<1> userData;
userData.kFor.m_data = &( rwork[7*nspecies] );
userData.rxnRateLaw.m_data = &( userData.kFor[ nreactions ] );
CounterType counter_i; CounterType counter_i;
@ -1452,12 +1466,11 @@ void FixRxKokkos<DeviceType>::solve_reactions(const int vflag, const bool isPreF
// Solver the ODE system. // Solver the ODE system.
if (odeIntegrationFlag == ODE_LAMMPS_RK4) if (odeIntegrationFlag == ODE_LAMMPS_RK4)
{ {
//rk4(t_stop, y, rwork, &userData); k_rk4(t_stop, y, rwork, userData);
k_rk4(t_stop, y, rwork, userDataKokkos);
} }
else if (odeIntegrationFlag == ODE_LAMMPS_RKF45) else if (odeIntegrationFlag == ODE_LAMMPS_RKF45)
{ {
rkf45(nspecies, t_stop, y, rwork, &userData, counter_i); k_rkf45(nspecies, t_stop, y, rwork, userData, counter_i);
if (diagnosticFrequency == 1) if (diagnosticFrequency == 1)
{ {
@ -1477,9 +1490,9 @@ void FixRxKokkos<DeviceType>::solve_reactions(const int vflag, const bool isPreF
d_dvector(ispecies,i) = y[ispecies]; d_dvector(ispecies,i) = y[ispecies];
} }
delete [] y; //delete [] y;
delete [] userData.kFor; //delete [] userData.kFor;
delete [] userData.rxnRateLaw; //delete [] userData.rxnRateLaw;
// Update the iteration statistics counter. Is this unique for each iteration? // Update the iteration statistics counter. Is this unique for each iteration?
counter += counter_i; counter += counter_i;
@ -1490,6 +1503,8 @@ void FixRxKokkos<DeviceType>::solve_reactions(const int vflag, const bool isPreF
, TotalCounters // reduction value for all iterations. , TotalCounters // reduction value for all iterations.
); );
//delete [] scratchSpace;
TimerType timer_ODE = getTimeStamp(); TimerType timer_ODE = getTimeStamp();
// Signal that dvector has been modified on this execution space. // Signal that dvector has been modified on this execution space.

View File

@ -129,23 +129,23 @@ class FixRxKokkos : public FixRX {
double& h0, double y[], double rwk[], void *v_params) const; double& h0, double y[], double rwk[], void *v_params) const;
//!< Classic Runge-Kutta 4th-order stepper. //!< Classic Runge-Kutta 4th-order stepper.
template <typename UserDataType> template <typename VectorType, typename UserDataType>
void k_rk4(const double t_stop, double *y, double *rwork, UserDataType& userData) const; void k_rk4(const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData) const;
//!< Runge-Kutta-Fehlberg ODE Solver. //!< Runge-Kutta-Fehlberg ODE Solver.
template <typename UserDataType> template <typename VectorType, typename UserDataType>
void k_rkf45(const int neq, const double t_stop, double *y, double *rwork, UserDataType& userData, CounterType& counter) const; void k_rkf45(const int neq, const double t_stop, VectorType& y, VectorType& rwork, UserDataType& userData, CounterType& counter) const;
//!< Runge-Kutta-Fehlberg ODE stepper function. //!< Runge-Kutta-Fehlberg ODE stepper function.
template <typename UserDataType> template <typename VectorType, typename UserDataType>
void k_rkf45_step (const int neq, const double h, double y[], double y_out[], void k_rkf45_step (const int neq, const double h, VectorType& y, VectorType& y_out,
double rwk[], UserDataType& userData) const; VectorType& rwk, UserDataType& userData) const;
//!< Initial step size estimation for the Runge-Kutta-Fehlberg ODE solver. //!< Initial step size estimation for the Runge-Kutta-Fehlberg ODE solver.
template <typename UserDataType> template <typename VectorType, typename UserDataType>
int k_rkf45_h0 (const int neq, const double t, const double t_stop, int k_rkf45_h0 (const int neq, const double t, const double t_stop,
const double hmin, const double hmax, const double hmin, const double hmax,
double& h0, double y[], double rwk[], UserDataType& userData) const; double& h0, VectorType& y, VectorType& rwk, UserDataType& userData) const;
//!< ODE Solver diagnostics. //!< ODE Solver diagnostics.
void odeDiagnostics(void); void odeDiagnostics(void);