Try to speed up the compute kernel by dispatching to a templated eval() with compile-time EVFLAG/EFLAG/NEWTON_PAIR flags

This commit is contained in:
Axel Kohlmeyer
2022-12-21 19:24:28 -05:00
parent 2cf1793a93
commit 6c5a698be4
2 changed files with 66 additions and 45 deletions

View File

@ -58,86 +58,104 @@ PairLepton::~PairLepton()
// Entry point for the pair force computation.
//
// eflag and vflag are forwarded unchanged to ev_init(). The actual pair loop
// lives in the templated eval() kernel; this function only selects which
// instantiation to run, so that the energy/virial and Newton's-third-law
// branches inside the loop are decided at compile time:
//   EVFLAG      = 1 when evflag is set (evflag is not assigned in this
//                 function; presumably ev_init() sets it - confirm)
//   EFLAG       = 1 when eflag is set (only consulted when evflag is set)
//   NEWTON_PAIR = 1 when force->newton_pair is set
void PairLepton::compute(int eflag, int vflag)
{
  ev_init(eflag, vflag);

  if (evflag) {
    if (eflag) {
      if (force->newton_pair)
        eval<1, 1, 1>();
      else
        eval<1, 1, 0>();
    } else {
      if (force->newton_pair)
        eval<1, 0, 1>();
      else
        eval<1, 0, 0>();
    }
  } else {
    // no tallying requested: the energy expressions are never needed
    if (force->newton_pair)
      eval<0, 0, 1>();
    else
      eval<0, 0, 0>();
  }

  // called once here, after the kernel has finished, rather than inside eval()
  if (vflag_fdotr) virial_fdotr_compute();
}
// NOTE(review): these six declarations sit between the closing brace of
// compute() and the start of eval(), i.e. outside any function body.
// eval() declares its own const-qualified x, f, type, nlocal and special_lj,
// and compute() reads force->newton_pair directly when choosing the eval()
// instantiation. Presumably these are leftover pre-refactor lines - confirm
// before removing.
double **x = atom->x;
double **f = atom->f;
int *type = atom->type;
int nlocal = atom->nlocal;
double *special_lj = force->special_lj;
int newton_pair = force->newton_pair;
/* ---------------------------------------------------------------------- */
template <int EVFLAG, int EFLAG, int NEWTON_PAIR> void PairLepton::eval()
{
const double *const *const x = atom->x;
double *const *const f = atom->f;
const int *const type = atom->type;
const int nlocal = atom->nlocal;
const double *const special_lj = force->special_lj;
inum = list->inum;
ilist = list->ilist;
numneigh = list->numneigh;
firstneigh = list->firstneigh;
const int inum = list->inum;
const int *const ilist = list->ilist;
const int *const numneigh = list->numneigh;
const int *const *const firstneigh = list->firstneigh;
std::vector<LMP_Lepton::CompiledExpression> force;
std::vector<LMP_Lepton::CompiledExpression> epot;
for (const auto &expr : expressions) {
force.emplace_back(
LMP_Lepton::Parser::parse(expr).differentiate("r").createCompiledExpression());
if (eflag) epot.emplace_back(LMP_Lepton::Parser::parse(expr).createCompiledExpression());
if (EFLAG) epot.emplace_back(LMP_Lepton::Parser::parse(expr).createCompiledExpression());
}
// loop over neighbors of my atoms
for (ii = 0; ii < inum; ii++) {
i = ilist[ii];
xtmp = x[i][0];
ytmp = x[i][1];
ztmp = x[i][2];
itype = type[i];
jlist = firstneigh[i];
jnum = numneigh[i];
for (int ii = 0; ii < inum; ii++) {
const int i = ilist[ii];
const double xtmp = x[i][0];
const double ytmp = x[i][1];
const double ztmp = x[i][2];
const int itype = type[i];
const int *jlist = firstneigh[i];
const int jnum = numneigh[i];
double fxtmp, fytmp, fztmp;
fxtmp = fytmp = fztmp = 0.0;
for (jj = 0; jj < jnum; jj++) {
j = jlist[jj];
for (int jj = 0; jj < jnum; jj++) {
int j = jlist[jj];
const double factor_lj = special_lj[sbmask(j)];
j &= NEIGHMASK;
const int jtype = type[j];
delx = xtmp - x[j][0];
dely = ytmp - x[j][1];
delz = ztmp - x[j][2];
rsq = delx * delx + dely * dely + delz * delz;
jtype = type[j];
const double delx = xtmp - x[j][0];
const double dely = ytmp - x[j][1];
const double delz = ztmp - x[j][2];
const double rsq = delx * delx + dely * dely + delz * delz;
if (rsq < cutsq[itype][jtype]) {
const double r = sqrt(rsq);
const int idx = type2expression[itype][jtype];
double &r_for = force[idx].getVariableReference("r");
r_for = r;
fpair = -force[idx].evaluate() / r;
fpair *= factor_lj;
const double fpair = -force[idx].evaluate() / r * factor_lj;
f[i][0] += delx * fpair;
f[i][1] += dely * fpair;
f[i][2] += delz * fpair;
if (newton_pair || j < nlocal) {
fxtmp += delx * fpair;
fytmp += dely * fpair;
fztmp += delz * fpair;
if (NEWTON_PAIR || (j < nlocal)) {
f[j][0] -= delx * fpair;
f[j][1] -= dely * fpair;
f[j][2] -= delz * fpair;
}
if (eflag) {
double evdwl = 0.0;
if (EFLAG) {
double &r_pot = epot[idx].getVariableReference("r");
r_pot = r;
evdwl = factor_lj * epot[idx].evaluate();
} else
evdwl = 0.0;
}
if (evflag) ev_tally(i, j, nlocal, newton_pair, evdwl, 0.0, fpair, delx, dely, delz);
if (EVFLAG) ev_tally(i, j, nlocal, NEWTON_PAIR, evdwl, 0.0, fpair, delx, dely, delz);
}
}
f[i][0] += fxtmp;
f[i][1] += fytmp;
f[i][2] += fztmp;
}
if (vflag_fdotr) virial_fdotr_compute();
}
/* ----------------------------------------------------------------------

View File

@ -43,7 +43,7 @@ class PairLepton : public Pair {
void coeff(int, char **) override;
double init_one(int, int) override;
void write_data(FILE *) override;
void write_data_all(FILE *) override;
void write_data_all(FILE *) override;
double single(int, int, int, int, double, double, double, double &) override;
protected:
@ -52,6 +52,9 @@ class PairLepton : public Pair {
int **type2expression;
double cut_global;
private:
template <int EVFLAG, int EFLAG, int NEWTON_PAIR> void eval();
virtual void allocate();
};