diff --git a/src/KOKKOS/sna_kokkos_impl.h b/src/KOKKOS/sna_kokkos_impl.h index 546c07ee11..c9c1bfbb99 100644 --- a/src/KOKKOS/sna_kokkos_impl.h +++ b/src/KOKKOS/sna_kokkos_impl.h @@ -2234,29 +2234,30 @@ template KOKKOS_INLINE_FUNCTION real_type SNAKokkos::compute_sfac(real_type r, real_type rcut, real_type sinner, real_type dinner) { + real_type sfac_outer; constexpr real_type one = static_cast(1.0); constexpr real_type zero = static_cast(0.0); constexpr real_type onehalf = static_cast(0.5); - if (switch_flag == 0) return one; + if (switch_flag == 0) sfac_outer = one; if (switch_flag == 1) { - if (r <= rmin0) return one; + if (r <= rmin0) sfac_outer = one; else if (r > rcut) return zero; else { real_type rcutfac = static_cast(MY_PI) / (rcut - rmin0); - if (switch_inner_flag == 0) - return onehalf * (cos((r - rmin0) * rcutfac) + one); - if (switch_inner_flag == 1) { - if (r >= sinner + dinner) - return onehalf * (cos((r - rmin0) * rcutfac) + one); - else if (r > sinner - dinner) { - real_type rcutfacinner = static_cast(MY_PI2) / dinner; - return onehalf * (cos((r - rmin0) * rcutfac) + one) * - onehalf * (one - cos(static_cast(MY_PI2) + (r - sinner) * rcutfacinner)); - } else return zero; - } - return zero; // dummy return + sfac_outer = onehalf * (cos((r - rmin0) * rcutfac) + one); } } + + if (switch_inner_flag == 0) return sfac_outer; + if (switch_inner_flag == 1) { + if (r >= sinner + dinner) + return sfac_outer; + else if (r > sinner - dinner) { + real_type rcutfac = static_cast(MY_PI2) / dinner; + return sfac_outer * + onehalf * (one - cos(static_cast(MY_PI2) + (r - sinner) * rcutfac)); + } else return zero; + } return zero; // dummy return } @@ -2266,41 +2267,86 @@ template KOKKOS_INLINE_FUNCTION real_type SNAKokkos::compute_dsfac(real_type r, real_type rcut, real_type sinner, real_type dinner) { + real_type sfac_outer, dsfac_outer, sfac_inner, dsfac_inner; constexpr real_type one = static_cast(1.0); constexpr real_type zero = static_cast(0.0); constexpr real_type onehalf = static_cast(0.5); - if (switch_flag == 0) return zero; + if (switch_flag == 0) dsfac_outer = zero; if (switch_flag == 1) { - if (r <= rmin0) return zero; + if (r <= rmin0) dsfac_outer = zero; else if (r > rcut) return zero; else { real_type rcutfac = static_cast(MY_PI) / (rcut - rmin0); - return -onehalf * sin((r - rmin0) * rcutfac) * rcutfac; + dsfac_outer = -onehalf * sin((r - rmin0) * rcutfac) * rcutfac; } } + + if (switch_inner_flag == 0) return dsfac_outer; + if (switch_inner_flag == 1) { + if (r >= sinner + dinner) + return dsfac_outer; + else if (r > sinner - dinner) { + + // calculate sfac_outer + + if (switch_flag == 0) sfac_outer = one; + if (switch_flag == 1) { + if (r <= rmin0) sfac_outer = one; + else if (r > rcut) sfac_outer = zero; + else { + real_type rcutfac = static_cast(MY_PI) / (rcut - rmin0); + sfac_outer = onehalf * (cos((r - rmin0) * rcutfac) + one); + } + } + + // calculate sfac_inner + + real_type rcutfac = static_cast(MY_PI2) / dinner; + sfac_inner = onehalf * (one - cos(static_cast(MY_PI2) + (r - sinner) * rcutfac)); + dsfac_inner = onehalf * rcutfac * sin(static_cast(MY_PI2) + (r - sinner) * rcutfac); + return dsfac_outer * sfac_inner + sfac_outer * dsfac_inner; + + } else return zero; + } return zero; // dummy return } template KOKKOS_INLINE_FUNCTION void SNAKokkos::compute_s_dsfac(const real_type r, const real_type rcut, const real_type sinner, const real_type dinner, real_type& sfac, real_type& dsfac) { + + real_type sfac_outer, dsfac_outer, sfac_inner, dsfac_inner; constexpr real_type one = static_cast(1.0); constexpr real_type zero = static_cast(0.0); constexpr real_type onehalf = static_cast(0.5); - if (switch_flag == 0) { sfac = zero; dsfac = zero; } + + if (switch_flag == 0) { sfac_outer = zero; dsfac_outer = zero; } else if (switch_flag == 1) { - if (r <= rmin0) { sfac = one; dsfac = zero; } - else if (r > rcut) { sfac = zero; dsfac = zero; } + if (r <= rmin0) { sfac_outer = one; dsfac_outer = zero; } + else if (r > rcut) { sfac = zero; dsfac = zero; return; } else { const real_type rcutfac = static_cast(MY_PI) / (rcut - rmin0); const real_type theta0 = (r - rmin0) * rcutfac; const real_type sn = sin(theta0); const real_type cs = cos(theta0); - sfac = onehalf * (cs + one); - dsfac = -onehalf * sn * rcutfac; - + sfac_outer = onehalf * (cs + one); + dsfac_outer = -onehalf * sn * rcutfac; } - } else { sfac = zero; dsfac = zero; } + } else { sfac = zero; dsfac = zero; return; } // dummy return + + if (switch_inner_flag == 0) { sfac = sfac_outer; dsfac = dsfac_outer; return; } + else if (switch_inner_flag == 1) { + if (r >= sinner + dinner) { sfac = sfac_outer; dsfac = dsfac_outer; return; } + else if (r > sinner - dinner) { + real_type rcutfac = static_cast(MY_PI2) / dinner; + sfac_inner = onehalf * (one - cos(static_cast(MY_PI2) + (r - sinner) * rcutfac)); + dsfac_inner = onehalf * rcutfac * sin(static_cast(MY_PI2) + (r - sinner) * rcutfac); + sfac = sfac_outer * sfac_inner; + dsfac = dsfac_outer * sfac_inner + sfac_outer * dsfac_inner; + return; + } else { sfac = zero; dsfac = zero; return; } + } else { sfac = zero; dsfac = zero; return; } // dummy return + } /* ---------------------------------------------------------------------- diff --git a/src/ML-SNAP/sna.cpp b/src/ML-SNAP/sna.cpp index 0bd66d0bce..2b86774d5b 100644 --- a/src/ML-SNAP/sna.cpp +++ b/src/ML-SNAP/sna.cpp @@ -1533,6 +1533,9 @@ void SNA::compute_ncoeff() double SNA::compute_sfac(double r, double rcut, double sinner, double dinner) { double sfac; + + // calculate sfac = sfac_outer + if (switch_flag == 0) sfac = 1.0; else if (r <= rmin0) sfac = 1.0; else if (r > rcut) sfac = 0.0; @@ -1541,6 +1544,8 @@ double SNA::compute_sfac(double r, double rcut, double sinner, double dinner) sfac = 0.5 * (cos((r - rmin0) * rcutfac) + 1.0); } + // calculate sfac *= sfac_inner, rarely visited + if (switch_inner_flag == 1 && r < sinner + dinner) { if (r > sinner - dinner) { double rcutfac = MY_PI2 / dinner; @@ -1555,33 +1560,38 @@ double SNA::compute_sfac(double r, double rcut, double sinner, double dinner) double SNA::compute_dsfac(double r, double rcut, double sinner, double dinner) { - double sfac, dsfac, sfac_inner, dsfac_inner; - if (switch_flag == 0) dsfac = 0.0; - else if (r <= rmin0) dsfac = 0.0; - else if (r > rcut) dsfac = 0.0; + double dsfac, sfac_outer, dsfac_outer, sfac_inner, dsfac_inner; + if (switch_flag == 0) dsfac_outer = 0.0; + else if (r <= rmin0) dsfac_outer = 0.0; + else if (r > rcut) dsfac_outer = 0.0; else { double rcutfac = MY_PI / (rcut - rmin0); - dsfac = -0.5 * sin((r - rmin0) * rcutfac) * rcutfac; + dsfac_outer = -0.5 * sin((r - rmin0) * rcutfac) * rcutfac; } - // duplicated computation, but rarely visited + // some duplicated computation, but rarely visited if (switch_inner_flag == 1 && r < sinner + dinner) { if (r > sinner - dinner) { - if (switch_flag == 0) sfac = 1.0; - else if (r <= rmin0) sfac = 1.0; - else if (r > rcut) sfac = 0.0; + + // calculate sfac_outer + + if (switch_flag == 0) sfac_outer = 1.0; + else if (r <= rmin0) sfac_outer = 1.0; + else if (r > rcut) sfac_outer = 0.0; else { double rcutfac = MY_PI / (rcut - rmin0); - sfac = 0.5 * (cos((r - rmin0) * rcutfac) + 1.0); + sfac_outer = 0.5 * (cos((r - rmin0) * rcutfac) + 1.0); } + // calculate sfac_inner + double rcutfac = MY_PI2 / dinner; sfac_inner = 0.5 * (1.0 - cos(MY_PI2 + (r - sinner) * rcutfac)); dsfac_inner = 0.5 * rcutfac * sin(MY_PI2 + (r - sinner) * rcutfac); - dsfac = dsfac*sfac_inner + sfac*dsfac_inner; + dsfac = dsfac_outer*sfac_inner + sfac_outer*dsfac_inner; } else dsfac = 0.0; - } + } else dsfac = dsfac_outer; return dsfac; }