Updated FixRxKokkos to use kokkos-managed data objects.

- Switched to use kokkos dvector, mask, and dpdTheta views
  from atomKK.
This commit is contained in:
Christopher Stone
2017-01-28 10:41:16 -05:00
parent 70fa9189a8
commit 2ea900df00

View File

@ -760,9 +760,9 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
TimerType timer_start = getTimeStamp();
int nlocal = atom->nlocal;
int nghost = atom->nghost;
int newton_pair = force->newton_pair;
const int nlocal = atom->nlocal;
const int nghost = atom->nghost;
const int newton_pair = force->newton_pair;
const bool setToZero = false; // don't set the forward rates to zero.
@ -776,12 +776,23 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
TimerType timer_localTemperature = getTimeStamp();
// Total counters from the ODE solvers.
CounterType Counters;
CounterType TotalCounters;
// Set data needed in the operators.
int *mask = atom->mask;
double *dpdTheta = atom->dpdTheta;
// ...
//int *mask = atom->mask;
//double *dpdTheta = atom->dpdTheta;
// Local references to the atomKK objects.
typename ArrayTypes<DeviceType>::t_efloat_1d d_dpdTheta = atomKK->k_dpdTheta.view<DeviceType>();
typename ArrayTypes<DeviceType>::t_float_2d d_dvector = atomKK->k_dvector.view<DeviceType>();
typename ArrayTypes<DeviceType>::t_int_1d d_mask = atomKK->k_mask.view<DeviceType>();
// Get up-to-date data.
atomKK->sync( execution_space, MASK_MASK | DVECTOR_MASK | DPDTHETA_MASK );
// Set some constants outside of the parallel_for
const double boltz = force->boltz;
const double t_stop = update->dt; // DPD time-step and integration length.
@ -796,7 +807,7 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
Kokkos::parallel_reduce( nlocal, LAMMPS_LAMBDA(int i, CounterType &counter)
{
if (mask[i] & groupbit)
if (d_mask(i) & groupbit)
{
double *y = new double[8*nspecies];
double *rwork = y + nspecies;
@ -807,7 +818,7 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
CounterType counter_i;
const double theta = (localTempFlag) ? dpdThetaLocal[i] : dpdTheta[i];
const double theta = (localTempFlag) ? dpdThetaLocal[i] : d_dpdTheta(i);
//Compute the reaction rate constants
for (int irxn = 0; irxn < nreactions; irxn++)
@ -819,14 +830,13 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
userData.kFor[irxn] = d_kineticsData.Arr(irxn) *
pow(theta, d_kineticsData.nArr(irxn)) *
exp(-d_kineticsData.Ea(irxn) / boltz / theta);
//userData.kFor[irxn] = Arr[irxn]*pow(theta,nArr[irxn])*exp(-Ea[irxn]/boltz/theta);
}
}
// Update ConcOld and initialize the ODE solution vector y[].
for (int ispecies = 0; ispecies < nspecies; ispecies++){
const double tmp = atom->dvector[ispecies][i];
atom->dvector[ispecies+nspecies][i] = tmp;
const double tmp = d_dvector(ispecies, i);
d_dvector(ispecies+nspecies, i) = tmp;
y[ispecies] = tmp;
}
@ -845,7 +855,7 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
error->one(FLERR,"Computed concentration in RK4 solver is < -10*DBL_EPSILON");
else if(y[ispecies] < MY_EPSILON)
y[ispecies] = 0.0;
atom->dvector[ispecies][i] = y[ispecies];
d_dvector(ispecies,i) = y[ispecies];
}
}
else if (odeIntegrationFlag == ODE_LAMMPS_RKF45)
@ -858,7 +868,7 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
error->one(FLERR,"Computed concentration in RKF45 solver is < -1.0e-10");
else if(y[ispecies] < MY_EPSILON)
y[ispecies] = 0.0;
atom->dvector[ispecies][i] = y[ispecies];
d_dvector(ispecies,i) = y[ispecies];
}
//if (diagnosticFrequency == 1 && diagnosticCounterPerODE[StepSum] != NULL)
@ -877,13 +887,21 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
} // if
} // parallel_for lambda-body
, Counters // reduction value
, TotalCounters // reduction value
);
TimerType timer_ODE = getTimeStamp();
// Communicate the updated momenta and velocities to all nodes
// Signal that dvector has been modified on this execution space.
atomKK->modified( execution_space, DVECTOR_MASK );
// Communicate the updated species data to all nodes
atomKK->sync ( Host, DVECTOR_MASK );
comm->forward_comm_fix(this);
atomKK->modified ( Host, DVECTOR_MASK );
if(localTempFlag) delete [] dpdThetaLocal;
TimerType timer_stop = getTimeStamp();
@ -894,12 +912,12 @@ void FixRxKokkos<DeviceType>::pre_force(int vflag)
getElapsedTime(timer_start, timer_stop),
getElapsedTime(timer_start, timer_localTemperature),
getElapsedTime(timer_localTemperature, timer_ODE),
getElapsedTime(timer_ODE, timer_stop), nlocal, Counters.nFuncs, Counters.nSteps);
getElapsedTime(timer_ODE, timer_stop), nlocal, TotalCounters.nFuncs, TotalCounters.nSteps);
// Warn the user if a failure was detected in the ODE solver.
if (Counters.nFails > 0){
if (TotalCounters.nFails > 0){
char sbuf[128];
sprintf(sbuf,"in FixRX::pre_force, ODE solver failed for %d atoms.", Counters.nFails);
sprintf(sbuf,"in FixRX::pre_force, ODE solver failed for %d atoms.", TotalCounters.nFails);
error->warning(FLERR, sbuf);
}