streamline Lepton variable update process with ptr-vectors

This commit is contained in:
Shern Tee
2025-01-14 22:31:42 +10:00
parent 1f7533029b
commit 276b8d9c93

View File

@ -181,28 +181,29 @@ void FixEfieldLepton::post_force(int vflag)
auto dphi_z = parsed.differentiate("z").createCompiledExpression();
std::array<Lepton::CompiledExpression*, 3> dphis = {&dphi_x, &dphi_y, &dphi_z};
// check if reference to x, y, z exist
// array of vectors of ptrs to Lepton variable references
std::array<std::vector<double *>, 3> var_ref_ptrs{};
// fill ptr-vectors with Lepton refs as needed
const char* DIM_NAMES[] = {"x", "y", "z"};
std::array<bool, 3> phi_has_ref{}; // zero-init
if (atom->q_flag){
phi = parsed.createCompiledExpression();
for (size_t i = 0; i < 3; i++) {
for (size_t d = 0; d < 3; d++) {
try {
phi.getVariableReference(DIM_NAMES[i]);
phi_has_ref[i] = true;
}
catch (Lepton::Exception &) {
double *ptr = &(phi.getVariableReference(DIM_NAMES[d]));
var_ref_ptrs[d].push_back(ptr);
} catch (Lepton::Exception &) {
// do nothing
}
}
}
std::array<std::array<bool, 3>, 3> dphis_has_ref{};
bool e_uniform = true;
for (size_t j = 0; j < 3; j++)
for (size_t i = 0; i < 3; i++) {
for (size_t d = 0; d < 3; d++) {
try {
(*dphis[j]).getVariableReference(DIM_NAMES[i]);
dphis_has_ref[j][i] = true;
double *ptr = &((*dphis[j]).getVariableReference(DIM_NAMES[d]));
var_ref_ptrs[d].push_back(ptr);
e_uniform = false;
}
catch (Lepton::Exception &) {
@ -226,7 +227,7 @@ void FixEfieldLepton::post_force(int vflag)
double ex, ey, ez;
double fx, fy, fz;
double v[6], unwrap[3];
double v[6], unwrap[3], dstep[3];
double xf, yf, zf, xb, yb, zb;
double exf, eyf, ezf, exb, eyb, ezb;
double mu_norm, h_mu;
@ -241,12 +242,14 @@ void FixEfieldLepton::post_force(int vflag)
fx = fy = fz = 0.0;
domain->unmap(x[i], image[i], unwrap);
// evaluate e-field, used by q and mu
for (size_t j = 0; j < 3; j++) {
if (dphis_has_ref[j][0]) (*dphis[j]).getVariableReference("x") = unwrap[0];
if (dphis_has_ref[j][1]) (*dphis[j]).getVariableReference("y") = unwrap[1];
if (dphis_has_ref[j][2]) (*dphis[j]).getVariableReference("z") = unwrap[2];
// put unwrapped coords into Lepton variable refs
for (size_t d = 0; d < 3; d++) {
for (auto & var_ref_ptr : var_ref_ptrs[d]) {
*var_ref_ptr = unwrap[d];
}
}
// evaluate e-field, used by q and mu
ex = -dphi_x.evaluate();
ey = -dphi_y.evaluate();
ez = -dphi_z.evaluate();
@ -258,10 +261,6 @@ void FixEfieldLepton::post_force(int vflag)
fy = qe2f * q[i] * ey;
fz = qe2f * q[i] * ez;
// potential energy = q phi
if (phi_has_ref[0]) phi.getVariableReference("x") = unwrap[0];
if (phi_has_ref[1]) phi.getVariableReference("y") = unwrap[1];
if (phi_has_ref[2]) phi.getVariableReference("z") = unwrap[2];
fsum[0] += qe2f * q[i] * phi.evaluate();
}
@ -280,27 +279,27 @@ void FixEfieldLepton::post_force(int vflag)
// using central difference method
if (!e_uniform) {
h_mu = h / mu_norm;
xf = unwrap[0] + h_mu * mu[i][0];
yf = unwrap[1] + h_mu * mu[i][1];
zf = unwrap[2] + h_mu * mu[i][2];
for (size_t j = 0; j < 3; j++) {
if (dphis_has_ref[j][0]) (*dphis[j]).getVariableReference("x") = xf;
if (dphis_has_ref[j][1]) (*dphis[j]).getVariableReference("y") = yf;
if (dphis_has_ref[j][2]) (*dphis[j]).getVariableReference("z") = zf;
dstep[0] = h_mu * mu[i][0];
dstep[1] = h_mu * mu[i][1];
dstep[2] = h_mu * mu[i][2];
// one step forwards, two steps back ;)
for (size_t d = 0; d < 3; d++) {
for (auto & var_ref_ptr : var_ref_ptrs[d]) {
*var_ref_ptr += dstep[d];
}
}
exf = -dphi_x.evaluate();
exf = -dphi_x.evaluate();
eyf = -dphi_y.evaluate();
ezf = -dphi_z.evaluate();
xb = unwrap[0] - h_mu * mu[i][0];
yb = unwrap[1] - h_mu * mu[i][1];
zb = unwrap[2] - h_mu * mu[i][2];
for (size_t j = 0; j < 3; j++) {
if (dphis_has_ref[j][0]) (*dphis[j]).getVariableReference("x") = xb;
if (dphis_has_ref[j][1]) (*dphis[j]).getVariableReference("y") = yb;
if (dphis_has_ref[j][2]) (*dphis[j]).getVariableReference("z") = zb;
for (size_t d = 0; d < 3; d++) {
for (auto & var_ref_ptr : var_ref_ptrs[d]) {
*var_ref_ptr -= 2*dstep[d];
}
}
exb = -dphi_x.evaluate();
eyb = -dphi_y.evaluate();
ezb = -dphi_z.evaluate();