From f62d1b5d55692be8df076fa78453fce7576d7590 Mon Sep 17 00:00:00 2001 From: alphataubio Date: Thu, 1 Aug 2024 18:55:16 -0400 Subject: [PATCH] complete rewrite of kokkos version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - array of structs set[i] from base class, converted to view- Host, converted to execution_space - atom->nlocalĀ converted to atomKK->nlocal- domain converted to domainKK- class now templated for DeviceType- SCALE not implemented in kokkos version ... actually by the time i was done it was a complete rewrite of the kokkos version --- src/KOKKOS/fix_deform_kokkos.cpp | 446 +++++++++++++++++-------------- src/KOKKOS/fix_deform_kokkos.h | 29 +- 2 files changed, 271 insertions(+), 204 deletions(-) diff --git a/src/KOKKOS/fix_deform_kokkos.cpp b/src/KOKKOS/fix_deform_kokkos.cpp index 90c4380da9..eb080eefe5 100644 --- a/src/KOKKOS/fix_deform_kokkos.cpp +++ b/src/KOKKOS/fix_deform_kokkos.cpp @@ -13,7 +13,7 @@ ------------------------------------------------------------------------- */ /* ---------------------------------------------------------------------- - Contributing author: Pieter in 't Veld (SNL) + Contributing author: Mitch Murphy (alphataubio@gmail.com) ------------------------------------------------------------------------- */ #include "fix_deform_kokkos.h" @@ -26,6 +26,7 @@ #include "irregular.h" #include "kspace.h" #include "math_const.h" +#include "memory_kokkos.h" #include "modify.h" #include "update.h" #include "variable.h" @@ -36,18 +37,38 @@ using namespace LAMMPS_NS; using namespace FixConst; using namespace MathConst; -enum{NONE=0,FINAL,DELTA,SCALE,VEL,ERATE,TRATE,VOLUME,WIGGLE,VARIABLE}; -enum{ONE_FROM_ONE,ONE_FROM_TWO,TWO_FROM_ONE}; - /* ---------------------------------------------------------------------- */ -FixDeformKokkos::FixDeformKokkos(LAMMPS *lmp, int narg, char **arg) : FixDeform(lmp, narg, arg) +template +FixDeformKokkos::FixDeformKokkos(LAMMPS *lmp, int narg, char **arg) : FixDeform(lmp, narg, arg) { kokkosable = 1; + atomKK = (AtomKokkos *) atom; domainKK = (DomainKokkos *) domain; + execution_space = ExecutionSpaceFromDevice::space; datamask_read = EMPTY_MASK; datamask_modify = EMPTY_MASK; + + memoryKK->create_kokkos(k_set,6,"fix_deform:set"); + d_set = k_set.template view(); +} + +template +FixDeformKokkos::~FixDeformKokkos() +{ + if (copymode) return; + memoryKK->destroy_kokkos(k_set,set); +} + + +template +void FixDeformKokkos::init() +{ + FixDeform::init(); + memcpy((void *)k_set.h_view.data(), (void *)set, sizeof(Set)*6); + k_set.template modify(); + k_set.template sync(); } /* ---------------------------------------------------------------------- @@ -60,131 +81,105 @@ FixDeformKokkos::FixDeformKokkos(LAMMPS *lmp, int narg, char **arg) : FixDeform( image flags to new values, making eqs in doc of Domain:image_flip incorrect ------------------------------------------------------------------------- */ -void FixDeformKokkos::pre_exchange() +template +void FixDeformKokkos::pre_exchange() { if (flip == 0) return; - domain->yz = set[3].tilt_target = set[3].tilt_flip; - domain->xz = set[4].tilt_target = set[4].tilt_flip; - domain->xy = set[5].tilt_target = set[5].tilt_flip; - domain->set_global_box(); - domain->set_local_box(); + domainKK->yz = d_set[3].tilt_target = d_set[3].tilt_flip; + domainKK->xz = d_set[4].tilt_target = d_set[4].tilt_flip; + domainKK->xy = d_set[5].tilt_target = d_set[5].tilt_flip; + domainKK->set_global_box(); + domainKK->set_local_box(); - domainKK->image_flip(flipxy,flipxz,flipyz); + domainKK->image_flip(flipxy, flipxz, flipyz); + // FIXME: just replace with domainKK->remap_all(), is this correct ? + //double **x = atom->x; + //imageint *image = atom->image; + //int nlocal = atom->nlocal; + //for (int i = 0; i < nlocal; i++) domain->remap(x[i],image[i]); + //domain->x2lamda(atom->nlocal); + //irregular->migrate_atoms(); + //domain->lamda2x(atom->nlocal); domainKK->remap_all(); - domainKK->x2lamda(atom->nlocal); - atomKK->sync(Host,ALL_MASK); - irregular->migrate_atoms(); - atomKK->modified(Host,ALL_MASK); - domainKK->lamda2x(atom->nlocal); - flip = 0; } /* ---------------------------------------------------------------------- */ -void FixDeformKokkos::end_of_step() +template +void FixDeformKokkos::end_of_step() { - int i; - - double delta = update->ntimestep - update->beginstep; - if (delta != 0.0) delta /= update->endstep - update->beginstep; - // wrap variable evaluations with clear/add if (varflag) modify->clearstep_compute(); - // set new box size + // set new box size for strain-based dims + + apply_strain(); + + // set new box size for VOLUME dims that are linked to other dims + // NOTE: still need to set h_rate for these dims + + apply_volume(); + + if (varflag) modify->addstep_compute(update->ntimestep + nevery); + + update_domain(); + + // redo KSpace coeffs since box has changed + + if (kspace_flag) force->kspace->setup(); +} + +/* ---------------------------------------------------------------------- + apply strain controls +------------------------------------------------------------------------- */ + +template +KOKKOS_INLINE_FUNCTION +void FixDeformKokkos::apply_strain() +{ // for NONE, target is current box size // for TRATE, set target directly based on current time, also set h_rate // for WIGGLE, set target directly based on current time, also set h_rate // for VARIABLE, set target directly via variable eval, also set h_rate // for others except VOLUME, target is linear value between start and stop - for (i = 0; i < 3; i++) { - if (set[i].style == NONE) { - set[i].lo_target = domain->boxlo[i]; - set[i].hi_target = domain->boxhi[i]; - } else if (set[i].style == TRATE) { + double delta = update->ntimestep - update->beginstep; + if (delta != 0.0) delta /= update->endstep - update->beginstep; + + for (int i = 0; i < 3; i++) { + if (d_set[i].style == NONE) { + d_set[i].lo_target = domainKK->boxlo[i]; + d_set[i].hi_target = domainKK->boxhi[i]; + } else if (d_set[i].style == TRATE) { double delt = (update->ntimestep - update->beginstep) * update->dt; - set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) - - 0.5*((set[i].hi_start-set[i].lo_start) * exp(set[i].rate*delt)); - set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) + - 0.5*((set[i].hi_start-set[i].lo_start) * exp(set[i].rate*delt)); - h_rate[i] = set[i].rate * domain->h[i]; - h_ratelo[i] = -0.5*h_rate[i]; - } else if (set[i].style == WIGGLE) { + double shift = 0.5 * ((d_set[i].hi_start - d_set[i].lo_start) * exp(d_set[i].rate * delt)); + d_set[i].lo_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) - shift; + d_set[i].hi_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) + shift; + h_rate[i] = d_set[i].rate * domainKK->h[i]; + h_ratelo[i] = -0.5 * h_rate[i]; + } else if (d_set[i].style == WIGGLE) { double delt = (update->ntimestep - update->beginstep) * update->dt; - set[i].lo_target = set[i].lo_start - - 0.5*set[i].amplitude * sin(MY_2PI*delt/set[i].tperiod); - set[i].hi_target = set[i].hi_start + - 0.5*set[i].amplitude * sin(MY_2PI*delt/set[i].tperiod); - h_rate[i] = MY_2PI/set[i].tperiod * set[i].amplitude * - cos(MY_2PI*delt/set[i].tperiod); - h_ratelo[i] = -0.5*h_rate[i]; - } else if (set[i].style == VARIABLE) { - double del = input->variable->compute_equal(set[i].hvar); - set[i].lo_target = set[i].lo_start - 0.5*del; - set[i].hi_target = set[i].hi_start + 0.5*del; - h_rate[i] = input->variable->compute_equal(set[i].hratevar); - h_ratelo[i] = -0.5*h_rate[i]; - } else if (set[i].style != VOLUME) { - set[i].lo_target = set[i].lo_start + - delta*(set[i].lo_stop - set[i].lo_start); - set[i].hi_target = set[i].hi_start + - delta*(set[i].hi_stop - set[i].hi_start); - } - } - - // set new box size for VOLUME dims that are linked to other dims - // NOTE: still need to set h_rate for these dims - - for (i = 0; i < 3; i++) { - if (set[i].style != VOLUME) continue; - - if (set[i].substyle == ONE_FROM_ONE) { - set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) - - 0.5*(set[i].vol_start / - (set[set[i].dynamic1].hi_target - - set[set[i].dynamic1].lo_target) / - (set[set[i].fixed].hi_start-set[set[i].fixed].lo_start)); - set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) + - 0.5*(set[i].vol_start / - (set[set[i].dynamic1].hi_target - - set[set[i].dynamic1].lo_target) / - (set[set[i].fixed].hi_start-set[set[i].fixed].lo_start)); - - } else if (set[i].substyle == ONE_FROM_TWO) { - set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) - - 0.5*(set[i].vol_start / - (set[set[i].dynamic1].hi_target - - set[set[i].dynamic1].lo_target) / - (set[set[i].dynamic2].hi_target - - set[set[i].dynamic2].lo_target)); - set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) + - 0.5*(set[i].vol_start / - (set[set[i].dynamic1].hi_target - - set[set[i].dynamic1].lo_target) / - (set[set[i].dynamic2].hi_target - - set[set[i].dynamic2].lo_target)); - - } else if (set[i].substyle == TWO_FROM_ONE) { - set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) - - 0.5*sqrt(set[i].vol_start / - (set[set[i].dynamic1].hi_target - - set[set[i].dynamic1].lo_target) / - (set[set[i].fixed].hi_start - - set[set[i].fixed].lo_start) * - (set[i].hi_start - set[i].lo_start)); - set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) + - 0.5*sqrt(set[i].vol_start / - (set[set[i].dynamic1].hi_target - - set[set[i].dynamic1].lo_target) / - (set[set[i].fixed].hi_start - - set[set[i].fixed].lo_start) * - (set[i].hi_start - set[i].lo_start)); + double shift = 0.5 * d_set[i].amplitude * sin(MY_2PI * delt / d_set[i].tperiod); + d_set[i].lo_target = d_set[i].lo_start - shift; + d_set[i].hi_target = d_set[i].hi_start + shift; + h_rate[i] = MY_2PI / d_set[i].tperiod * d_set[i].amplitude * + cos(MY_2PI * delt / d_set[i].tperiod); + h_ratelo[i] = -0.5 * h_rate[i]; + } else if (d_set[i].style == VARIABLE) { + double del = input->variable->compute_equal(d_set[i].hvar); + d_set[i].lo_target = d_set[i].lo_start - 0.5 * del; + d_set[i].hi_target = d_set[i].hi_start + 0.5 * del; + h_rate[i] = input->variable->compute_equal(d_set[i].hratevar); + h_ratelo[i] = -0.5 * h_rate[i]; + } else if (d_set[i].style == FINAL || d_set[i].style == DELTA || d_set[i].style == SCALE || + d_set[i].style == VEL || d_set[i].style == ERATE) { + d_set[i].lo_target = d_set[i].lo_start + delta * (d_set[i].lo_stop - d_set[i].lo_start); + d_set[i].hi_target = d_set[i].hi_start + delta * (d_set[i].hi_stop - d_set[i].hi_start); } } @@ -196,55 +191,101 @@ void FixDeformKokkos::end_of_step() // for other styles, target is linear value between start and stop values if (triclinic) { - double *h = domain->h; - - for (i = 3; i < 6; i++) { - if (set[i].style == NONE) { - if (i == 5) set[i].tilt_target = domain->xy; - else if (i == 4) set[i].tilt_target = domain->xz; - else if (i == 3) set[i].tilt_target = domain->yz; - } else if (set[i].style == TRATE) { + for (int i = 3; i < 6; i++) { + if (d_set[i].style == NONE) { + if (i == 5) d_set[i].tilt_target = domainKK->xy; + else if (i == 4) d_set[i].tilt_target = domainKK->xz; + else if (i == 3) d_set[i].tilt_target = domainKK->yz; + } else if (d_set[i].style == TRATE) { double delt = (update->ntimestep - update->beginstep) * update->dt; - set[i].tilt_target = set[i].tilt_start * exp(set[i].rate*delt); - h_rate[i] = set[i].rate * domain->h[i]; - } else if (set[i].style == WIGGLE) { + d_set[i].tilt_target = d_set[i].tilt_start * exp(d_set[i].rate * delt); + h_rate[i] = d_set[i].rate * domainKK->h[i]; + } else if (d_set[i].style == WIGGLE) { double delt = (update->ntimestep - update->beginstep) * update->dt; - set[i].tilt_target = set[i].tilt_start + - set[i].amplitude * sin(MY_2PI*delt/set[i].tperiod); - h_rate[i] = MY_2PI/set[i].tperiod * set[i].amplitude * - cos(MY_2PI*delt/set[i].tperiod); - } else if (set[i].style == VARIABLE) { - double delta_tilt = input->variable->compute_equal(set[i].hvar); - set[i].tilt_target = set[i].tilt_start + delta_tilt; - h_rate[i] = input->variable->compute_equal(set[i].hratevar); + d_set[i].tilt_target = d_set[i].tilt_start + + d_set[i].amplitude * sin(MY_2PI * delt / d_set[i].tperiod); + h_rate[i] = MY_2PI / d_set[i].tperiod * d_set[i].amplitude * + cos(MY_2PI * delt / d_set[i].tperiod); + } else if (d_set[i].style == VARIABLE) { + double delta_tilt = input->variable->compute_equal(d_set[i].hvar); + d_set[i].tilt_target = d_set[i].tilt_start + delta_tilt; + h_rate[i] = input->variable->compute_equal(d_set[i].hratevar); } else { - set[i].tilt_target = set[i].tilt_start + - delta*(set[i].tilt_stop - set[i].tilt_start); + d_set[i].tilt_target = d_set[i].tilt_start + delta * (d_set[i].tilt_stop - d_set[i].tilt_start); } + } + } +} - // tilt_target can be large positive or large negative value - // add/subtract box lengths until tilt_target is closest to current value +/* ---------------------------------------------------------------------- + apply volume controls +------------------------------------------------------------------------- */ +template +KOKKOS_INLINE_FUNCTION +void FixDeformKokkos::apply_volume() +{ + for (int i = 0; i < 3; i++) { + if (d_set[i].style != VOLUME) continue; + + int dynamic1 = d_set[i].dynamic1; + int dynamic2 = d_set[i].dynamic2; + int fixed = d_set[i].fixed; + double v0 = d_set[i].vol_start; + double shift = 0.0; + + if (d_set[i].substyle == ONE_FROM_ONE) { + shift = 0.5 * (v0 / (d_set[dynamic1].hi_target - d_set[dynamic1].lo_target) / + (d_set[fixed].hi_start - d_set[fixed].lo_start)); + } else if (d_set[i].substyle == ONE_FROM_TWO) { + shift = 0.5 * (v0 / (d_set[dynamic1].hi_target - d_set[dynamic1].lo_target) / + (d_set[dynamic2].hi_target - d_set[dynamic2].lo_target)); + } else if (d_set[i].substyle == TWO_FROM_ONE) { + shift = 0.5 * sqrt(v0 * (d_set[i].hi_start - d_set[i].lo_start) / + (d_set[dynamic1].hi_target - d_set[dynamic1].lo_target) / + (d_set[fixed].hi_start - d_set[fixed].lo_start)); + } + + h_rate[i] = (2.0 * shift / (domainKK->boxhi[i] - domainKK->boxlo[i]) - 1.0) / update->dt; + h_ratelo[i] = -0.5 * h_rate[i]; + + d_set[i].lo_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) - shift; + d_set[i].hi_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) + shift; + } +} + +/* ---------------------------------------------------------------------- + Update box domain +------------------------------------------------------------------------- */ + +template +KOKKOS_INLINE_FUNCTION +void FixDeformKokkos::update_domain() +{ + // tilt_target can be large positive or large negative value + // add/subtract box lengths until tilt_target is closest to current value + + if (triclinic) { + double *h = domainKK->h; + for (int i = 3; i < 6; i++) { int idenom = 0; if (i == 5) idenom = 0; else if (i == 4) idenom = 0; else if (i == 3) idenom = 1; - double denom = set[idenom].hi_target - set[idenom].lo_target; + double denom = d_set[idenom].hi_target - d_set[idenom].lo_target; - double current = h[i]/h[idenom]; + double current = h[i] / h[idenom]; - while (set[i].tilt_target/denom - current > 0.0) - set[i].tilt_target -= denom; - while (set[i].tilt_target/denom - current < 0.0) - set[i].tilt_target += denom; - if (fabs(set[i].tilt_target/denom - 1.0 - current) < - fabs(set[i].tilt_target/denom - current)) - set[i].tilt_target -= denom; + while (d_set[i].tilt_target / denom - current > 0.0) + d_set[i].tilt_target -= denom; + while (d_set[i].tilt_target / denom - current < 0.0) + d_set[i].tilt_target += denom; + if (fabs(d_set[i].tilt_target / denom - 1.0 - current) < + fabs(d_set[i].tilt_target / denom - current)) + d_set[i].tilt_target -= denom; } } - if (varflag) modify->addstep_compute(update->ntimestep + nevery); - // if any tilt ratios exceed 0.5, set flip = 1 and compute new tilt values // do not flip in x or y if non-periodic (can tilt but not flip) // this is b/c the box length would be changed (dramatically) by flip @@ -255,48 +296,48 @@ void FixDeformKokkos::end_of_step() // flip is performed on next timestep, before reneighboring in pre-exchange() if (triclinic && flipflag) { - double xprd = set[0].hi_target - set[0].lo_target; - double yprd = set[1].hi_target - set[1].lo_target; + double xprd = d_set[0].hi_target - d_set[0].lo_target; + double yprd = d_set[1].hi_target - d_set[1].lo_target; double xprdinv = 1.0 / xprd; double yprdinv = 1.0 / yprd; - if (set[3].tilt_target*yprdinv < -0.5 || - set[3].tilt_target*yprdinv > 0.5 || - set[4].tilt_target*xprdinv < -0.5 || - set[4].tilt_target*xprdinv > 0.5 || - set[5].tilt_target*xprdinv < -0.5 || - set[5].tilt_target*xprdinv > 0.5) { - set[3].tilt_flip = set[3].tilt_target; - set[4].tilt_flip = set[4].tilt_target; - set[5].tilt_flip = set[5].tilt_target; + if (d_set[3].tilt_target * yprdinv < -0.5 || + d_set[3].tilt_target * yprdinv > 0.5 || + d_set[4].tilt_target * xprdinv < -0.5 || + d_set[4].tilt_target * xprdinv > 0.5 || + d_set[5].tilt_target * xprdinv < -0.5 || + d_set[5].tilt_target * xprdinv > 0.5) { + d_set[3].tilt_flip = d_set[3].tilt_target; + d_set[4].tilt_flip = d_set[4].tilt_target; + d_set[5].tilt_flip = d_set[5].tilt_target; flipxy = flipxz = flipyz = 0; - if (domain->yperiodic) { - if (set[3].tilt_flip*yprdinv < -0.5) { - set[3].tilt_flip += yprd; - set[4].tilt_flip += set[5].tilt_flip; + if (domainKK->yperiodic) { + if (d_set[3].tilt_flip * yprdinv < -0.5) { + d_set[3].tilt_flip += yprd; + d_set[4].tilt_flip += d_set[5].tilt_flip; flipyz = 1; - } else if (set[3].tilt_flip*yprdinv > 0.5) { - set[3].tilt_flip -= yprd; - set[4].tilt_flip -= set[5].tilt_flip; + } else if (d_set[3].tilt_flip * yprdinv > 0.5) { + d_set[3].tilt_flip -= yprd; + d_set[4].tilt_flip -= d_set[5].tilt_flip; flipyz = -1; } } - if (domain->xperiodic) { - if (set[4].tilt_flip*xprdinv < -0.5) { - set[4].tilt_flip += xprd; + if (domainKK->xperiodic) { + if (d_set[4].tilt_flip * xprdinv < -0.5) { + d_set[4].tilt_flip += xprd; flipxz = 1; } - if (set[4].tilt_flip*xprdinv > 0.5) { - set[4].tilt_flip -= xprd; + if (d_set[4].tilt_flip * xprdinv > 0.5) { + d_set[4].tilt_flip -= xprd; flipxz = -1; } - if (set[5].tilt_flip*xprdinv < -0.5) { - set[5].tilt_flip += xprd; + if (d_set[5].tilt_flip * xprdinv < -0.5) { + d_set[5].tilt_flip += xprd; flipxy = 1; } - if (set[5].tilt_flip*xprdinv > 0.5) { - set[5].tilt_flip -= xprd; + if (d_set[5].tilt_flip * xprdinv > 0.5) { + d_set[5].tilt_flip -= xprd; flipxy = -1; } } @@ -310,60 +351,65 @@ void FixDeformKokkos::end_of_step() // convert atoms and rigid bodies to lamda coords if (remapflag == Domain::X_REMAP) { - int nlocal = atom->nlocal; + atomKK->sync(execution_space, X_MASK | MASK_MASK ); + d_x = atomKK->k_x.template view(); + d_mask = atomKK->k_mask.template view(); + int nlocal = atomKK->nlocal; - domainKK->x2lamda(nlocal); + for (int i = 0; i < nlocal; i++) + if (d_mask(i) & groupbit) + domainKK->x2lamda(&d_x(i,0), &d_x(i,0)); - if (rfix.size() > 0) { - atomKK->sync(Host,ALL_MASK); - for (auto &ifix : rfix) - ifix->deform(0); - atomKK->modified(Host,ALL_MASK); - } + for (auto &ifix : rfix) + ifix->deform(0); } // reset global and local box to new size/shape // only if deform fix is controlling the dimension - if (set[0].style) { - domain->boxlo[0] = set[0].lo_target; - domain->boxhi[0] = set[0].hi_target; + if (dimflag[0]) { + domainKK->boxlo[0] = d_set[0].lo_target; + domainKK->boxhi[0] = d_set[0].hi_target; } - if (set[1].style) { - domain->boxlo[1] = set[1].lo_target; - domain->boxhi[1] = set[1].hi_target; + if (dimflag[1]) { + domainKK->boxlo[1] = d_set[1].lo_target; + domainKK->boxhi[1] = d_set[1].hi_target; } - if (set[2].style) { - domain->boxlo[2] = set[2].lo_target; - domain->boxhi[2] = set[2].hi_target; + if (dimflag[2]) { + domainKK->boxlo[2] = d_set[2].lo_target; + domainKK->boxhi[2] = d_set[2].hi_target; } if (triclinic) { - if (set[3].style) domain->yz = set[3].tilt_target; - if (set[4].style) domain->xz = set[4].tilt_target; - if (set[5].style) domain->xy = set[5].tilt_target; + if (dimflag[3]) domainKK->yz = d_set[3].tilt_target; + if (dimflag[4]) domainKK->xz = d_set[4].tilt_target; + if (dimflag[5]) domainKK->xy = d_set[5].tilt_target; } - domain->set_global_box(); - domain->set_local_box(); + domainKK->set_global_box(); + domainKK->set_local_box(); // convert atoms and rigid bodies back to box coords if (remapflag == Domain::X_REMAP) { - int nlocal = atom->nlocal; + atomKK->sync(execution_space, X_MASK | MASK_MASK ); + d_x = atomKK->k_x.template view(); + d_mask = atomKK->k_mask.template view(); + int nlocal = atomKK->nlocal; - domainKK->lamda2x(nlocal); + for (int i = 0; i < nlocal; i++) + if (d_mask(i) & groupbit) + domainKK->lamda2x(&d_x(i,0), &d_x(i,0)); - if (rfix.size() > 0) { - atomKK->sync(Host,ALL_MASK); - for (auto &ifix : rfix) - ifix->deform(1); - atomKK->modified(Host,ALL_MASK); - } + for (auto &ifix : rfix) + ifix->deform(1); } - - // redo KSpace coeffs since box has changed - - if (kspace_flag) force->kspace->setup(); } +namespace LAMMPS_NS { +template class FixDeformKokkos; +#ifdef LMP_KOKKOS_GPU +template class FixDeformKokkos; +#endif +} + diff --git a/src/KOKKOS/fix_deform_kokkos.h b/src/KOKKOS/fix_deform_kokkos.h index 32140c8766..184f1d0b07 100644 --- a/src/KOKKOS/fix_deform_kokkos.h +++ b/src/KOKKOS/fix_deform_kokkos.h @@ -13,9 +13,9 @@ #ifdef FIX_CLASS // clang-format off -FixStyle(deform/kk,FixDeformKokkos); -FixStyle(deform/kk/device,FixDeformKokkos); -FixStyle(deform/kk/host,FixDeformKokkos); +FixStyle(deform/kk,FixDeformKokkos); +FixStyle(deform/kk/device,FixDeformKokkos); +FixStyle(deform/kk/host,FixDeformKokkos); // clang-format on #else @@ -24,19 +24,40 @@ FixStyle(deform/kk/host,FixDeformKokkos); #define LMP_FIX_DEFORM_KOKKOS_H #include "fix_deform.h" +#include "kokkos_type.h" namespace LAMMPS_NS { +template class FixDeformKokkos : public FixDeform { public: - FixDeformKokkos(class LAMMPS *, int, char **); + typedef DeviceType device_type; + typedef ArrayTypes AT; + FixDeformKokkos(class LAMMPS *, int, char **); + ~FixDeformKokkos(); + + void init() override; void pre_exchange() override; void end_of_step() override; private: class DomainKokkos *domainKK; + typename AT::t_x_array d_x; + typename AT::t_int_1d d_mask; + + Kokkos::DualView k_set; + Kokkos::View d_set; + + KOKKOS_INLINE_FUNCTION + void virtual apply_volume(); + + KOKKOS_INLINE_FUNCTION + void apply_strain(); + + KOKKOS_INLINE_FUNCTION + void update_domain(); }; }