diff --git a/src/KOKKOS/pair_exp6_rx_kokkos.cpp b/src/KOKKOS/pair_exp6_rx_kokkos.cpp index 5c74cba8c7..312f1c6076 100644 --- a/src/KOKKOS/pair_exp6_rx_kokkos.cpp +++ b/src/KOKKOS/pair_exp6_rx_kokkos.cpp @@ -187,22 +187,25 @@ void PairExp6rxKokkos::compute(int eflag_in, int vflag_in) { const int np_total = nlocal + atom->nghost; - PairExp6ParamData.epsilon1 = typename AT::t_float_1d("PairExp6ParamData.epsilon1" ,np_total); - PairExp6ParamData.alpha1 = typename AT::t_float_1d("PairExp6ParamData.alpha1" ,np_total); - PairExp6ParamData.rm1 = typename AT::t_float_1d("PairExp6ParamData.rm1" ,np_total); - PairExp6ParamData.mixWtSite1 = typename AT::t_float_1d("PairExp6ParamData.mixWtSite1" ,np_total); - PairExp6ParamData.epsilon2 = typename AT::t_float_1d("PairExp6ParamData.epsilon2" ,np_total); - PairExp6ParamData.alpha2 = typename AT::t_float_1d("PairExp6ParamData.alpha2" ,np_total); - PairExp6ParamData.rm2 = typename AT::t_float_1d("PairExp6ParamData.rm2" ,np_total); - PairExp6ParamData.mixWtSite2 = typename AT::t_float_1d("PairExp6ParamData.mixWtSite2" ,np_total); - PairExp6ParamData.epsilonOld1 = typename AT::t_float_1d("PairExp6ParamData.epsilonOld1" ,np_total); - PairExp6ParamData.alphaOld1 = typename AT::t_float_1d("PairExp6ParamData.alphaOld1" ,np_total); - PairExp6ParamData.rmOld1 = typename AT::t_float_1d("PairExp6ParamData.rmOld1" ,np_total); - PairExp6ParamData.mixWtSite1old = typename AT::t_float_1d("PairExp6ParamData.mixWtSite1old",np_total); - PairExp6ParamData.epsilonOld2 = typename AT::t_float_1d("PairExp6ParamData.epsilonOld2" ,np_total); - PairExp6ParamData.alphaOld2 = typename AT::t_float_1d("PairExp6ParamData.alphaOld2" ,np_total); - PairExp6ParamData.rmOld2 = typename AT::t_float_1d("PairExp6ParamData.rmOld2" ,np_total); - PairExp6ParamData.mixWtSite2old = typename AT::t_float_1d("PairExp6ParamData.mixWtSite2old",np_total); + if (np_total > PairExp6ParamData.epsilon1.dimension_0()) { + PairExp6ParamData.epsilon1 = typename AT::t_float_1d("PairExp6ParamData.epsilon1" ,np_total); + PairExp6ParamData.alpha1 = typename AT::t_float_1d("PairExp6ParamData.alpha1" ,np_total); + PairExp6ParamData.rm1 = typename AT::t_float_1d("PairExp6ParamData.rm1" ,np_total); + PairExp6ParamData.mixWtSite1 = typename AT::t_float_1d("PairExp6ParamData.mixWtSite1" ,np_total); + PairExp6ParamData.epsilon2 = typename AT::t_float_1d("PairExp6ParamData.epsilon2" ,np_total); + PairExp6ParamData.alpha2 = typename AT::t_float_1d("PairExp6ParamData.alpha2" ,np_total); + PairExp6ParamData.rm2 = typename AT::t_float_1d("PairExp6ParamData.rm2" ,np_total); + PairExp6ParamData.mixWtSite2 = typename AT::t_float_1d("PairExp6ParamData.mixWtSite2" ,np_total); + PairExp6ParamData.epsilonOld1 = typename AT::t_float_1d("PairExp6ParamData.epsilonOld1" ,np_total); + PairExp6ParamData.alphaOld1 = typename AT::t_float_1d("PairExp6ParamData.alphaOld1" ,np_total); + PairExp6ParamData.rmOld1 = typename AT::t_float_1d("PairExp6ParamData.rmOld1" ,np_total); + PairExp6ParamData.mixWtSite1old = typename AT::t_float_1d("PairExp6ParamData.mixWtSite1old",np_total); + PairExp6ParamData.epsilonOld2 = typename AT::t_float_1d("PairExp6ParamData.epsilonOld2" ,np_total); + PairExp6ParamData.alphaOld2 = typename AT::t_float_1d("PairExp6ParamData.alphaOld2" ,np_total); + PairExp6ParamData.rmOld2 = typename AT::t_float_1d("PairExp6ParamData.rmOld2" ,np_total); + PairExp6ParamData.mixWtSite2old = typename AT::t_float_1d("PairExp6ParamData.mixWtSite2old",np_total); + } else + Kokkos::parallel_for(Kokkos::RangePolicy(0,np_total),*this); #ifdef KOKKOS_HAVE_CUDA Kokkos::parallel_for(Kokkos::RangePolicy(0,np_total),*this); @@ -352,6 +355,27 @@ void PairExp6rxKokkos::compute(int eflag_in, int vflag_in) //printf("PairExp6rxKokkos::compute %f %f\n", getElapsedTime(t_start, t_stop), getElapsedTime(t_mix_start, t_mix_stop)); } +template +KOKKOS_INLINE_FUNCTION +void PairExp6rxKokkos::operator()(TagPairExp6rxZeroMixingWeights, const int &i) const { + PairExp6ParamData.epsilon1[i] = 0.0; + PairExp6ParamData.alpha1[i] = 0.0; + PairExp6ParamData.rm1[i] = 0.0; + PairExp6ParamData.mixWtSite1[i] = 0.0; + PairExp6ParamData.epsilon2[i] = 0.0; + PairExp6ParamData.alpha2[i] = 0.0; + PairExp6ParamData.rm2[i] = 0.0; + PairExp6ParamData.mixWtSite2[i] = 0.0; + PairExp6ParamData.epsilonOld1[i] = 0.0; + PairExp6ParamData.alphaOld1[i] = 0.0; + PairExp6ParamData.rmOld1[i] = 0.0; + PairExp6ParamData.mixWtSite1old[i] = 0.0; + PairExp6ParamData.epsilonOld2[i] = 0.0; + PairExp6ParamData.alphaOld2[i] = 0.0; + PairExp6ParamData.rmOld2[i] = 0.0; + PairExp6ParamData.mixWtSite2old[i] = 0.0; +} + template KOKKOS_INLINE_FUNCTION void PairExp6rxKokkos::operator()(TagPairExp6rxgetMixingWeights, const int &i) const { diff --git a/src/KOKKOS/pair_exp6_rx_kokkos.h b/src/KOKKOS/pair_exp6_rx_kokkos.h index 9f38732c32..5e9fb4e3e3 100644 --- a/src/KOKKOS/pair_exp6_rx_kokkos.h +++ b/src/KOKKOS/pair_exp6_rx_kokkos.h @@ -52,6 +52,7 @@ struct PairExp6ParamDataTypeKokkos {} }; +struct TagPairExp6rxZeroMixingWeights{}; struct TagPairExp6rxgetMixingWeights{}; template @@ -76,6 +77,9 @@ class PairExp6rxKokkos : public PairExp6rx { void coeff(int, char **); void init_style(); + KOKKOS_INLINE_FUNCTION + void operator()(TagPairExp6rxZeroMixingWeights, const int&) const; + KOKKOS_INLINE_FUNCTION void operator()(TagPairExp6rxgetMixingWeights, const int&) const;