Only transfer data arrays that are needed in each kernel

This commit is contained in:
Trung Nguyen
2021-10-02 00:56:15 -05:00
parent f4d3d3a2b5
commit 5a6426bf96
2 changed files with 55 additions and 100 deletions

View File

@@ -350,8 +350,7 @@ int** BaseAmoebaT::precompute(const int ago, const int inum_full, const int nall
const bool eflag_in, const bool vflag_in,
const bool eatom, const bool vatom, int &host_start,
int **&ilist, int **&jnum, const double cpu_time,
-bool &success, double *host_q, double *boxlo,
-double *prd) {
+bool &success, double *host_q, double *boxlo, double *prd) {
acc_timers();
if (eatom) _eflag=2;
else if (eflag_in) _eflag=1;
@@ -509,7 +508,7 @@ int** BaseAmoebaT::compute_udirect2b(const int ago, const int inum_full,
int** firstneigh = nullptr;
cast_extra_data(host_amtype, host_amgroup, host_rpole, host_uind, host_uinp, host_pval);
atom->add_extra_data();
atom->add_extra_data();
// ------------------- Resize _fieldp array ------------------------
@@ -647,30 +646,34 @@ void BaseAmoebaT::cast_extra_data(int* amtype, int* amgroup, double** rpole,
int n = 0;
int nstride = 4;
for (int i = 0; i < _nall; i++) {
int idx = n+i*nstride;
pextra[idx] = rpole[i][0];
pextra[idx+1] = rpole[i][1];
pextra[idx+2] = rpole[i][2];
pextra[idx+3] = rpole[i][3];
}
if (rpole) {
for (int i = 0; i < _nall; i++) {
int idx = n+i*nstride;
pextra[idx] = rpole[i][0];
pextra[idx+1] = rpole[i][1];
pextra[idx+2] = rpole[i][2];
pextra[idx+3] = rpole[i][3];
}
n += nstride*_nall;
for (int i = 0; i < _nall; i++) {
int idx = n+i*nstride;
pextra[idx] = rpole[i][4];
pextra[idx+1] = rpole[i][5];
pextra[idx+2] = rpole[i][6];
pextra[idx+3] = rpole[i][8];
}
n += nstride*_nall;
for (int i = 0; i < _nall; i++) {
int idx = n+i*nstride;
pextra[idx] = rpole[i][4];
pextra[idx+1] = rpole[i][5];
pextra[idx+2] = rpole[i][6];
pextra[idx+3] = rpole[i][8];
}
n += nstride*_nall;
for (int i = 0; i < _nall; i++) {
int idx = n+i*nstride;
pextra[idx] = rpole[i][9];
pextra[idx+1] = rpole[i][12];
pextra[idx+2] = (numtyp)amtype[i];
pextra[idx+3] = (numtyp)amgroup[i];
n += nstride*_nall;
for (int i = 0; i < _nall; i++) {
int idx = n+i*nstride;
pextra[idx] = rpole[i][9];
pextra[idx+1] = rpole[i][12];
pextra[idx+2] = (numtyp)amtype[i];
pextra[idx+3] = (numtyp)amgroup[i];
}
} else {
n += 2*nstride*_nall;
}
n += nstride*_nall;