Ported to KOKKOS, untested

This commit is contained in:
Aidan Thompson
2022-04-22 16:43:19 -06:00
parent 4de9ab85ce
commit 2c71d0eea2
2 changed files with 92 additions and 36 deletions

View File

@ -2234,29 +2234,30 @@ template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
real_type SNAKokkos<DeviceType, real_type, vector_length>::compute_sfac(real_type r, real_type rcut, real_type sinner, real_type dinner)
{
real_type sfac_outer;
constexpr real_type one = static_cast<real_type>(1.0);
constexpr real_type zero = static_cast<real_type>(0.0);
constexpr real_type onehalf = static_cast<real_type>(0.5);
if (switch_flag == 0) return one;
if (switch_flag == 0) sfac_outer = one;
if (switch_flag == 1) {
if (r <= rmin0) return one;
if (r <= rmin0) sfac_outer = one;
else if (r > rcut) return zero;
else {
real_type rcutfac = static_cast<real_type>(MY_PI) / (rcut - rmin0);
if (switch_inner_flag == 0)
return onehalf * (cos((r - rmin0) * rcutfac) + one);
if (switch_inner_flag == 1) {
if (r >= sinner + dinner)
return onehalf * (cos((r - rmin0) * rcutfac) + one);
else if (r > sinner - dinner) {
real_type rcutfacinner = static_cast<real_type>(MY_PI2) / dinner;
return onehalf * (cos((r - rmin0) * rcutfac) + one) *
onehalf * (one - cos(static_cast<real_type>(MY_PI2) + (r - sinner) * rcutfacinner));
} else return zero;
}
return zero; // dummy return
sfac_outer = onehalf * (cos((r - rmin0) * rcutfac) + one);
}
}
if (switch_inner_flag == 0) return sfac_outer;
if (switch_inner_flag == 1) {
if (r >= sinner + dinner)
return sfac_outer;
else if (r > sinner - dinner) {
real_type rcutfac = static_cast<real_type>(MY_PI2) / dinner;
return sfac_outer *
onehalf * (one - cos(static_cast<real_type>(MY_PI2) + (r - sinner) * rcutfac));
} else return zero;
}
return zero; // dummy return
}
@ -2266,41 +2267,86 @@ template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
real_type SNAKokkos<DeviceType, real_type, vector_length>::compute_dsfac(real_type r, real_type rcut, real_type sinner, real_type dinner)
{
real_type sfac_outer, dsfac_outer, sfac_inner, dsfac_inner;
constexpr real_type one = static_cast<real_type>(1.0);
constexpr real_type zero = static_cast<real_type>(0.0);
constexpr real_type onehalf = static_cast<real_type>(0.5);
if (switch_flag == 0) return zero;
if (switch_flag == 0) dsfac_outer = zero;
if (switch_flag == 1) {
if (r <= rmin0) return zero;
if (r <= rmin0) dsfac_outer = zero;
else if (r > rcut) return zero;
else {
real_type rcutfac = static_cast<real_type>(MY_PI) / (rcut - rmin0);
return -onehalf * sin((r - rmin0) * rcutfac) * rcutfac;
dsfac_outer = -onehalf * sin((r - rmin0) * rcutfac) * rcutfac;
}
}
if (switch_inner_flag == 0) return dsfac_outer;
if (switch_inner_flag == 1) {
if (r >= sinner + dinner)
return dsfac_outer;
else if (r > sinner - dinner) {
// calculate sfac_outer
if (switch_flag == 0) sfac_outer = one;
if (switch_flag == 1) {
if (r <= rmin0) sfac_outer = one;
else if (r > rcut) sfac_outer = zero;
else {
real_type rcutfac = static_cast<real_type>(MY_PI) / (rcut - rmin0);
sfac_outer = onehalf * (cos((r - rmin0) * rcutfac) + one);
}
}
// calculate sfac_inner
real_type rcutfac = static_cast<real_type>(MY_PI2) / dinner;
sfac_inner = onehalf * (one - cos(static_cast<real_type>(MY_PI2) + (r - sinner) * rcutfac));
dsfac_inner = onehalf * rcutfac * sin(static_cast<real_type>(MY_PI2) + (r - sinner) * rcutfac);
return dsfac_outer * sfac_inner + sfac_outer * dsfac_inner;
} else return zero;
}
return zero; // dummy return
}
template<class DeviceType, typename real_type, int vector_length>
KOKKOS_INLINE_FUNCTION
void SNAKokkos<DeviceType, real_type, vector_length>::compute_s_dsfac(const real_type r, const real_type rcut, const real_type sinner, const real_type dinner, real_type& sfac, real_type& dsfac) {
real_type sfac_outer, dsfac_outer, sfac_inner, dsfac_inner;
constexpr real_type one = static_cast<real_type>(1.0);
constexpr real_type zero = static_cast<real_type>(0.0);
constexpr real_type onehalf = static_cast<real_type>(0.5);
if (switch_flag == 0) { sfac = zero; dsfac = zero; }
if (switch_flag == 0) { sfac_outer = zero; dsfac_outer = zero; }
else if (switch_flag == 1) {
if (r <= rmin0) { sfac = one; dsfac = zero; }
else if (r > rcut) { sfac = zero; dsfac = zero; }
if (r <= rmin0) { sfac_outer = one; dsfac_outer = zero; }
else if (r > rcut) { sfac = zero; dsfac = zero; return; }
else {
const real_type rcutfac = static_cast<real_type>(MY_PI) / (rcut - rmin0);
const real_type theta0 = (r - rmin0) * rcutfac;
const real_type sn = sin(theta0);
const real_type cs = cos(theta0);
sfac = onehalf * (cs + one);
dsfac = -onehalf * sn * rcutfac;
sfac_outer = onehalf * (cs + one);
dsfac_outer = -onehalf * sn * rcutfac;
}
} else { sfac = zero; dsfac = zero; }
} else { sfac = zero; dsfac = zero; return; } // dummy return
if (switch_inner_flag == 0) { sfac = sfac_outer; dsfac = dsfac_outer; return; }
else if (switch_inner_flag == 1) {
if (r >= sinner + dinner) { sfac = sfac_outer; dsfac = dsfac_outer; return; }
else if (r > sinner - dinner) {
real_type rcutfac = static_cast<real_type>(MY_PI2) / dinner;
sfac_inner = onehalf * (one - cos(static_cast<real_type>(MY_PI2) + (r - sinner) * rcutfac));
dsfac_inner = onehalf * rcutfac * sin(static_cast<real_type>(MY_PI2) + (r - sinner) * rcutfac);
sfac = sfac_outer * sfac_inner;
dsfac = dsfac_outer * sfac_inner + sfac_outer * dsfac_inner;
return;
} else { sfac = zero; dsfac = zero; return; }
} else { sfac = zero; dsfac = zero; return; } // dummy return
}
/* ----------------------------------------------------------------------

View File

@ -1533,6 +1533,9 @@ void SNA::compute_ncoeff()
double SNA::compute_sfac(double r, double rcut, double sinner, double dinner)
{
double sfac;
// calculate sfac = sfac_outer
if (switch_flag == 0) sfac = 1.0;
else if (r <= rmin0) sfac = 1.0;
else if (r > rcut) sfac = 0.0;
@ -1541,6 +1544,8 @@ double SNA::compute_sfac(double r, double rcut, double sinner, double dinner)
sfac = 0.5 * (cos((r - rmin0) * rcutfac) + 1.0);
}
// calculate sfac *= sfac_inner, rarely visited
if (switch_inner_flag == 1 && r < sinner + dinner) {
if (r > sinner - dinner) {
double rcutfac = MY_PI2 / dinner;
@ -1555,33 +1560,38 @@ double SNA::compute_sfac(double r, double rcut, double sinner, double dinner)
double SNA::compute_dsfac(double r, double rcut, double sinner, double dinner)
{
double sfac, dsfac, sfac_inner, dsfac_inner;
if (switch_flag == 0) dsfac = 0.0;
else if (r <= rmin0) dsfac = 0.0;
else if (r > rcut) dsfac = 0.0;
double dsfac, sfac_outer, dsfac_outer, sfac_inner, dsfac_inner;
if (switch_flag == 0) dsfac_outer = 0.0;
else if (r <= rmin0) dsfac_outer = 0.0;
else if (r > rcut) dsfac_outer = 0.0;
else {
double rcutfac = MY_PI / (rcut - rmin0);
dsfac = -0.5 * sin((r - rmin0) * rcutfac) * rcutfac;
dsfac_outer = -0.5 * sin((r - rmin0) * rcutfac) * rcutfac;
}
// duplicated computation, but rarely visited
// some duplicated computation, but rarely visited
if (switch_inner_flag == 1 && r < sinner + dinner) {
if (r > sinner - dinner) {
if (switch_flag == 0) sfac = 1.0;
else if (r <= rmin0) sfac = 1.0;
else if (r > rcut) sfac = 0.0;
// calculate sfac_outer
if (switch_flag == 0) sfac_outer = 1.0;
else if (r <= rmin0) sfac_outer = 1.0;
else if (r > rcut) sfac_outer = 0.0;
else {
double rcutfac = MY_PI / (rcut - rmin0);
sfac = 0.5 * (cos((r - rmin0) * rcutfac) + 1.0);
sfac_outer = 0.5 * (cos((r - rmin0) * rcutfac) + 1.0);
}
// calculate sfac_inner
double rcutfac = MY_PI2 / dinner;
sfac_inner = 0.5 * (1.0 - cos(MY_PI2 + (r - sinner) * rcutfac));
dsfac_inner = 0.5 * rcutfac * sin(MY_PI2 + (r - sinner) * rcutfac);
dsfac = dsfac*sfac_inner + sfac*dsfac_inner;
dsfac = dsfac_outer*sfac_inner + sfac_outer*dsfac_inner;
} else dsfac = 0.0;
}
} else dsfac = dsfac_outer;
return dsfac;
}