complete rewrite of kokkos version

- array of structs set[i] from base class, converted to view- Host, converted to execution_space
- atom->nlocal converted to atomKK->nlocal- domain converted to domainKK- class now templated for DeviceType- SCALE not implemented in kokkos version
... actually by the time i was done it was a complete rewrite of the kokkos version
This commit is contained in:
alphataubio
2024-08-01 18:55:16 -04:00
parent 99a2bd799e
commit f62d1b5d55
2 changed files with 271 additions and 204 deletions

View File

@ -13,7 +13,7 @@
------------------------------------------------------------------------- */
/* ----------------------------------------------------------------------
Contributing author: Pieter in 't Veld (SNL)
Contributing author: Mitch Murphy (alphataubio@gmail.com)
------------------------------------------------------------------------- */
#include "fix_deform_kokkos.h"
@ -26,6 +26,7 @@
#include "irregular.h"
#include "kspace.h"
#include "math_const.h"
#include "memory_kokkos.h"
#include "modify.h"
#include "update.h"
#include "variable.h"
@ -36,18 +37,38 @@ using namespace LAMMPS_NS;
using namespace FixConst;
using namespace MathConst;
enum{NONE=0,FINAL,DELTA,SCALE,VEL,ERATE,TRATE,VOLUME,WIGGLE,VARIABLE};
enum{ONE_FROM_ONE,ONE_FROM_TWO,TWO_FROM_ONE};
/* ---------------------------------------------------------------------- */
FixDeformKokkos::FixDeformKokkos(LAMMPS *lmp, int narg, char **arg) : FixDeform(lmp, narg, arg)
template <class DeviceType>
FixDeformKokkos<DeviceType>::FixDeformKokkos(LAMMPS *lmp, int narg, char **arg) : FixDeform(lmp, narg, arg)
{
kokkosable = 1;
atomKK = (AtomKokkos *) atom;
domainKK = (DomainKokkos *) domain;
execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
datamask_read = EMPTY_MASK;
datamask_modify = EMPTY_MASK;
memoryKK->create_kokkos(k_set,6,"fix_deform:set");
d_set = k_set.template view<DeviceType>();
}
template<class DeviceType>
FixDeformKokkos<DeviceType>::~FixDeformKokkos()
{
if (copymode) return;
memoryKK->destroy_kokkos(k_set,set);
}
template <class DeviceType>
void FixDeformKokkos<DeviceType>::init()
{
FixDeform::init();
memcpy((void *)k_set.h_view.data(), (void *)set, sizeof(Set)*6);
k_set.template modify<LMPHostType>();
k_set.template sync<DeviceType>();
}
/* ----------------------------------------------------------------------
@ -60,131 +81,105 @@ FixDeformKokkos::FixDeformKokkos(LAMMPS *lmp, int narg, char **arg) : FixDeform(
image flags to new values, making eqs in doc of Domain:image_flip incorrect
------------------------------------------------------------------------- */
void FixDeformKokkos::pre_exchange()
template <class DeviceType>
void FixDeformKokkos<DeviceType>::pre_exchange()
{
if (flip == 0) return;
domain->yz = set[3].tilt_target = set[3].tilt_flip;
domain->xz = set[4].tilt_target = set[4].tilt_flip;
domain->xy = set[5].tilt_target = set[5].tilt_flip;
domain->set_global_box();
domain->set_local_box();
domainKK->yz = d_set[3].tilt_target = d_set[3].tilt_flip;
domainKK->xz = d_set[4].tilt_target = d_set[4].tilt_flip;
domainKK->xy = d_set[5].tilt_target = d_set[5].tilt_flip;
domainKK->set_global_box();
domainKK->set_local_box();
domainKK->image_flip(flipxy,flipxz,flipyz);
domainKK->image_flip(flipxy, flipxz, flipyz);
// FIXME: just replace with domainKK->remap_all(), is this correct ?
//double **x = atom->x;
//imageint *image = atom->image;
//int nlocal = atom->nlocal;
//for (int i = 0; i < nlocal; i++) domain->remap(x[i],image[i]);
//domain->x2lamda(atom->nlocal);
//irregular->migrate_atoms();
//domain->lamda2x(atom->nlocal);
domainKK->remap_all();
domainKK->x2lamda(atom->nlocal);
atomKK->sync(Host,ALL_MASK);
irregular->migrate_atoms();
atomKK->modified(Host,ALL_MASK);
domainKK->lamda2x(atom->nlocal);
flip = 0;
}
/* ---------------------------------------------------------------------- */
void FixDeformKokkos::end_of_step()
template <class DeviceType>
void FixDeformKokkos<DeviceType>::end_of_step()
{
int i;
double delta = update->ntimestep - update->beginstep;
if (delta != 0.0) delta /= update->endstep - update->beginstep;
// wrap variable evaluations with clear/add
if (varflag) modify->clearstep_compute();
// set new box size
// set new box size for strain-based dims
apply_strain();
// set new box size for VOLUME dims that are linked to other dims
// NOTE: still need to set h_rate for these dims
apply_volume();
if (varflag) modify->addstep_compute(update->ntimestep + nevery);
update_domain();
// redo KSpace coeffs since box has changed
if (kspace_flag) force->kspace->setup();
}
/* ----------------------------------------------------------------------
apply strain controls
------------------------------------------------------------------------- */
template <class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixDeformKokkos<DeviceType>::apply_strain()
{
// for NONE, target is current box size
// for TRATE, set target directly based on current time, also set h_rate
// for WIGGLE, set target directly based on current time, also set h_rate
// for VARIABLE, set target directly via variable eval, also set h_rate
// for others except VOLUME, target is linear value between start and stop
for (i = 0; i < 3; i++) {
if (set[i].style == NONE) {
set[i].lo_target = domain->boxlo[i];
set[i].hi_target = domain->boxhi[i];
} else if (set[i].style == TRATE) {
double delta = update->ntimestep - update->beginstep;
if (delta != 0.0) delta /= update->endstep - update->beginstep;
for (int i = 0; i < 3; i++) {
if (d_set[i].style == NONE) {
d_set[i].lo_target = domainKK->boxlo[i];
d_set[i].hi_target = domainKK->boxhi[i];
} else if (d_set[i].style == TRATE) {
double delt = (update->ntimestep - update->beginstep) * update->dt;
set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) -
0.5*((set[i].hi_start-set[i].lo_start) * exp(set[i].rate*delt));
set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) +
0.5*((set[i].hi_start-set[i].lo_start) * exp(set[i].rate*delt));
h_rate[i] = set[i].rate * domain->h[i];
h_ratelo[i] = -0.5*h_rate[i];
} else if (set[i].style == WIGGLE) {
double shift = 0.5 * ((d_set[i].hi_start - d_set[i].lo_start) * exp(d_set[i].rate * delt));
d_set[i].lo_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) - shift;
d_set[i].hi_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) + shift;
h_rate[i] = d_set[i].rate * domainKK->h[i];
h_ratelo[i] = -0.5 * h_rate[i];
} else if (d_set[i].style == WIGGLE) {
double delt = (update->ntimestep - update->beginstep) * update->dt;
set[i].lo_target = set[i].lo_start -
0.5*set[i].amplitude * sin(MY_2PI*delt/set[i].tperiod);
set[i].hi_target = set[i].hi_start +
0.5*set[i].amplitude * sin(MY_2PI*delt/set[i].tperiod);
h_rate[i] = MY_2PI/set[i].tperiod * set[i].amplitude *
cos(MY_2PI*delt/set[i].tperiod);
h_ratelo[i] = -0.5*h_rate[i];
} else if (set[i].style == VARIABLE) {
double del = input->variable->compute_equal(set[i].hvar);
set[i].lo_target = set[i].lo_start - 0.5*del;
set[i].hi_target = set[i].hi_start + 0.5*del;
h_rate[i] = input->variable->compute_equal(set[i].hratevar);
h_ratelo[i] = -0.5*h_rate[i];
} else if (set[i].style != VOLUME) {
set[i].lo_target = set[i].lo_start +
delta*(set[i].lo_stop - set[i].lo_start);
set[i].hi_target = set[i].hi_start +
delta*(set[i].hi_stop - set[i].hi_start);
}
}
// set new box size for VOLUME dims that are linked to other dims
// NOTE: still need to set h_rate for these dims
for (i = 0; i < 3; i++) {
if (set[i].style != VOLUME) continue;
if (set[i].substyle == ONE_FROM_ONE) {
set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) -
0.5*(set[i].vol_start /
(set[set[i].dynamic1].hi_target -
set[set[i].dynamic1].lo_target) /
(set[set[i].fixed].hi_start-set[set[i].fixed].lo_start));
set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) +
0.5*(set[i].vol_start /
(set[set[i].dynamic1].hi_target -
set[set[i].dynamic1].lo_target) /
(set[set[i].fixed].hi_start-set[set[i].fixed].lo_start));
} else if (set[i].substyle == ONE_FROM_TWO) {
set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) -
0.5*(set[i].vol_start /
(set[set[i].dynamic1].hi_target -
set[set[i].dynamic1].lo_target) /
(set[set[i].dynamic2].hi_target -
set[set[i].dynamic2].lo_target));
set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) +
0.5*(set[i].vol_start /
(set[set[i].dynamic1].hi_target -
set[set[i].dynamic1].lo_target) /
(set[set[i].dynamic2].hi_target -
set[set[i].dynamic2].lo_target));
} else if (set[i].substyle == TWO_FROM_ONE) {
set[i].lo_target = 0.5*(set[i].lo_start+set[i].hi_start) -
0.5*sqrt(set[i].vol_start /
(set[set[i].dynamic1].hi_target -
set[set[i].dynamic1].lo_target) /
(set[set[i].fixed].hi_start -
set[set[i].fixed].lo_start) *
(set[i].hi_start - set[i].lo_start));
set[i].hi_target = 0.5*(set[i].lo_start+set[i].hi_start) +
0.5*sqrt(set[i].vol_start /
(set[set[i].dynamic1].hi_target -
set[set[i].dynamic1].lo_target) /
(set[set[i].fixed].hi_start -
set[set[i].fixed].lo_start) *
(set[i].hi_start - set[i].lo_start));
double shift = 0.5 * d_set[i].amplitude * sin(MY_2PI * delt / d_set[i].tperiod);
d_set[i].lo_target = d_set[i].lo_start - shift;
d_set[i].hi_target = d_set[i].hi_start + shift;
h_rate[i] = MY_2PI / d_set[i].tperiod * d_set[i].amplitude *
cos(MY_2PI * delt / d_set[i].tperiod);
h_ratelo[i] = -0.5 * h_rate[i];
} else if (d_set[i].style == VARIABLE) {
double del = input->variable->compute_equal(d_set[i].hvar);
d_set[i].lo_target = d_set[i].lo_start - 0.5 * del;
d_set[i].hi_target = d_set[i].hi_start + 0.5 * del;
h_rate[i] = input->variable->compute_equal(d_set[i].hratevar);
h_ratelo[i] = -0.5 * h_rate[i];
} else if (d_set[i].style == FINAL || d_set[i].style == DELTA || d_set[i].style == SCALE ||
d_set[i].style == VEL || d_set[i].style == ERATE) {
d_set[i].lo_target = d_set[i].lo_start + delta * (d_set[i].lo_stop - d_set[i].lo_start);
d_set[i].hi_target = d_set[i].hi_start + delta * (d_set[i].hi_stop - d_set[i].hi_start);
}
}
@ -196,55 +191,101 @@ void FixDeformKokkos::end_of_step()
// for other styles, target is linear value between start and stop values
if (triclinic) {
double *h = domain->h;
for (i = 3; i < 6; i++) {
if (set[i].style == NONE) {
if (i == 5) set[i].tilt_target = domain->xy;
else if (i == 4) set[i].tilt_target = domain->xz;
else if (i == 3) set[i].tilt_target = domain->yz;
} else if (set[i].style == TRATE) {
for (int i = 3; i < 6; i++) {
if (d_set[i].style == NONE) {
if (i == 5) d_set[i].tilt_target = domainKK->xy;
else if (i == 4) d_set[i].tilt_target = domainKK->xz;
else if (i == 3) d_set[i].tilt_target = domainKK->yz;
} else if (d_set[i].style == TRATE) {
double delt = (update->ntimestep - update->beginstep) * update->dt;
set[i].tilt_target = set[i].tilt_start * exp(set[i].rate*delt);
h_rate[i] = set[i].rate * domain->h[i];
} else if (set[i].style == WIGGLE) {
d_set[i].tilt_target = d_set[i].tilt_start * exp(d_set[i].rate * delt);
h_rate[i] = d_set[i].rate * domainKK->h[i];
} else if (d_set[i].style == WIGGLE) {
double delt = (update->ntimestep - update->beginstep) * update->dt;
set[i].tilt_target = set[i].tilt_start +
set[i].amplitude * sin(MY_2PI*delt/set[i].tperiod);
h_rate[i] = MY_2PI/set[i].tperiod * set[i].amplitude *
cos(MY_2PI*delt/set[i].tperiod);
} else if (set[i].style == VARIABLE) {
double delta_tilt = input->variable->compute_equal(set[i].hvar);
set[i].tilt_target = set[i].tilt_start + delta_tilt;
h_rate[i] = input->variable->compute_equal(set[i].hratevar);
d_set[i].tilt_target = d_set[i].tilt_start +
d_set[i].amplitude * sin(MY_2PI * delt / d_set[i].tperiod);
h_rate[i] = MY_2PI / d_set[i].tperiod * d_set[i].amplitude *
cos(MY_2PI * delt / d_set[i].tperiod);
} else if (d_set[i].style == VARIABLE) {
double delta_tilt = input->variable->compute_equal(d_set[i].hvar);
d_set[i].tilt_target = d_set[i].tilt_start + delta_tilt;
h_rate[i] = input->variable->compute_equal(d_set[i].hratevar);
} else {
set[i].tilt_target = set[i].tilt_start +
delta*(set[i].tilt_stop - set[i].tilt_start);
d_set[i].tilt_target = d_set[i].tilt_start + delta * (d_set[i].tilt_stop - d_set[i].tilt_start);
}
}
}
}
// tilt_target can be large positive or large negative value
// add/subtract box lengths until tilt_target is closest to current value
/* ----------------------------------------------------------------------
apply volume controls
------------------------------------------------------------------------- */
template <class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixDeformKokkos<DeviceType>::apply_volume()
{
for (int i = 0; i < 3; i++) {
if (d_set[i].style != VOLUME) continue;
int dynamic1 = d_set[i].dynamic1;
int dynamic2 = d_set[i].dynamic2;
int fixed = d_set[i].fixed;
double v0 = d_set[i].vol_start;
double shift = 0.0;
if (d_set[i].substyle == ONE_FROM_ONE) {
shift = 0.5 * (v0 / (d_set[dynamic1].hi_target - d_set[dynamic1].lo_target) /
(d_set[fixed].hi_start - d_set[fixed].lo_start));
} else if (d_set[i].substyle == ONE_FROM_TWO) {
shift = 0.5 * (v0 / (d_set[dynamic1].hi_target - d_set[dynamic1].lo_target) /
(d_set[dynamic2].hi_target - d_set[dynamic2].lo_target));
} else if (d_set[i].substyle == TWO_FROM_ONE) {
shift = 0.5 * sqrt(v0 * (d_set[i].hi_start - d_set[i].lo_start) /
(d_set[dynamic1].hi_target - d_set[dynamic1].lo_target) /
(d_set[fixed].hi_start - d_set[fixed].lo_start));
}
h_rate[i] = (2.0 * shift / (domainKK->boxhi[i] - domainKK->boxlo[i]) - 1.0) / update->dt;
h_ratelo[i] = -0.5 * h_rate[i];
d_set[i].lo_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) - shift;
d_set[i].hi_target = 0.5 * (d_set[i].lo_start + d_set[i].hi_start) + shift;
}
}
/* ----------------------------------------------------------------------
Update box domain
------------------------------------------------------------------------- */
template <class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixDeformKokkos<DeviceType>::update_domain()
{
// tilt_target can be large positive or large negative value
// add/subtract box lengths until tilt_target is closest to current value
if (triclinic) {
double *h = domainKK->h;
for (int i = 3; i < 6; i++) {
int idenom = 0;
if (i == 5) idenom = 0;
else if (i == 4) idenom = 0;
else if (i == 3) idenom = 1;
double denom = set[idenom].hi_target - set[idenom].lo_target;
double denom = d_set[idenom].hi_target - d_set[idenom].lo_target;
double current = h[i]/h[idenom];
double current = h[i] / h[idenom];
while (set[i].tilt_target/denom - current > 0.0)
set[i].tilt_target -= denom;
while (set[i].tilt_target/denom - current < 0.0)
set[i].tilt_target += denom;
if (fabs(set[i].tilt_target/denom - 1.0 - current) <
fabs(set[i].tilt_target/denom - current))
set[i].tilt_target -= denom;
while (d_set[i].tilt_target / denom - current > 0.0)
d_set[i].tilt_target -= denom;
while (d_set[i].tilt_target / denom - current < 0.0)
d_set[i].tilt_target += denom;
if (fabs(d_set[i].tilt_target / denom - 1.0 - current) <
fabs(d_set[i].tilt_target / denom - current))
d_set[i].tilt_target -= denom;
}
}
if (varflag) modify->addstep_compute(update->ntimestep + nevery);
// if any tilt ratios exceed 0.5, set flip = 1 and compute new tilt values
// do not flip in x or y if non-periodic (can tilt but not flip)
// this is b/c the box length would be changed (dramatically) by flip
@ -255,48 +296,48 @@ void FixDeformKokkos::end_of_step()
// flip is performed on next timestep, before reneighboring in pre-exchange()
if (triclinic && flipflag) {
double xprd = set[0].hi_target - set[0].lo_target;
double yprd = set[1].hi_target - set[1].lo_target;
double xprd = d_set[0].hi_target - d_set[0].lo_target;
double yprd = d_set[1].hi_target - d_set[1].lo_target;
double xprdinv = 1.0 / xprd;
double yprdinv = 1.0 / yprd;
if (set[3].tilt_target*yprdinv < -0.5 ||
set[3].tilt_target*yprdinv > 0.5 ||
set[4].tilt_target*xprdinv < -0.5 ||
set[4].tilt_target*xprdinv > 0.5 ||
set[5].tilt_target*xprdinv < -0.5 ||
set[5].tilt_target*xprdinv > 0.5) {
set[3].tilt_flip = set[3].tilt_target;
set[4].tilt_flip = set[4].tilt_target;
set[5].tilt_flip = set[5].tilt_target;
if (d_set[3].tilt_target * yprdinv < -0.5 ||
d_set[3].tilt_target * yprdinv > 0.5 ||
d_set[4].tilt_target * xprdinv < -0.5 ||
d_set[4].tilt_target * xprdinv > 0.5 ||
d_set[5].tilt_target * xprdinv < -0.5 ||
d_set[5].tilt_target * xprdinv > 0.5) {
d_set[3].tilt_flip = d_set[3].tilt_target;
d_set[4].tilt_flip = d_set[4].tilt_target;
d_set[5].tilt_flip = d_set[5].tilt_target;
flipxy = flipxz = flipyz = 0;
if (domain->yperiodic) {
if (set[3].tilt_flip*yprdinv < -0.5) {
set[3].tilt_flip += yprd;
set[4].tilt_flip += set[5].tilt_flip;
if (domainKK->yperiodic) {
if (d_set[3].tilt_flip * yprdinv < -0.5) {
d_set[3].tilt_flip += yprd;
d_set[4].tilt_flip += d_set[5].tilt_flip;
flipyz = 1;
} else if (set[3].tilt_flip*yprdinv > 0.5) {
set[3].tilt_flip -= yprd;
set[4].tilt_flip -= set[5].tilt_flip;
} else if (d_set[3].tilt_flip * yprdinv > 0.5) {
d_set[3].tilt_flip -= yprd;
d_set[4].tilt_flip -= d_set[5].tilt_flip;
flipyz = -1;
}
}
if (domain->xperiodic) {
if (set[4].tilt_flip*xprdinv < -0.5) {
set[4].tilt_flip += xprd;
if (domainKK->xperiodic) {
if (d_set[4].tilt_flip * xprdinv < -0.5) {
d_set[4].tilt_flip += xprd;
flipxz = 1;
}
if (set[4].tilt_flip*xprdinv > 0.5) {
set[4].tilt_flip -= xprd;
if (d_set[4].tilt_flip * xprdinv > 0.5) {
d_set[4].tilt_flip -= xprd;
flipxz = -1;
}
if (set[5].tilt_flip*xprdinv < -0.5) {
set[5].tilt_flip += xprd;
if (d_set[5].tilt_flip * xprdinv < -0.5) {
d_set[5].tilt_flip += xprd;
flipxy = 1;
}
if (set[5].tilt_flip*xprdinv > 0.5) {
set[5].tilt_flip -= xprd;
if (d_set[5].tilt_flip * xprdinv > 0.5) {
d_set[5].tilt_flip -= xprd;
flipxy = -1;
}
}
@ -310,60 +351,65 @@ void FixDeformKokkos::end_of_step()
// convert atoms and rigid bodies to lamda coords
if (remapflag == Domain::X_REMAP) {
int nlocal = atom->nlocal;
atomKK->sync(execution_space, X_MASK | MASK_MASK );
d_x = atomKK->k_x.template view<DeviceType>();
d_mask = atomKK->k_mask.template view<DeviceType>();
int nlocal = atomKK->nlocal;
domainKK->x2lamda(nlocal);
for (int i = 0; i < nlocal; i++)
if (d_mask(i) & groupbit)
domainKK->x2lamda(&d_x(i,0), &d_x(i,0));
if (rfix.size() > 0) {
atomKK->sync(Host,ALL_MASK);
for (auto &ifix : rfix)
ifix->deform(0);
atomKK->modified(Host,ALL_MASK);
}
for (auto &ifix : rfix)
ifix->deform(0);
}
// reset global and local box to new size/shape
// only if deform fix is controlling the dimension
if (set[0].style) {
domain->boxlo[0] = set[0].lo_target;
domain->boxhi[0] = set[0].hi_target;
if (dimflag[0]) {
domainKK->boxlo[0] = d_set[0].lo_target;
domainKK->boxhi[0] = d_set[0].hi_target;
}
if (set[1].style) {
domain->boxlo[1] = set[1].lo_target;
domain->boxhi[1] = set[1].hi_target;
if (dimflag[1]) {
domainKK->boxlo[1] = d_set[1].lo_target;
domainKK->boxhi[1] = d_set[1].hi_target;
}
if (set[2].style) {
domain->boxlo[2] = set[2].lo_target;
domain->boxhi[2] = set[2].hi_target;
if (dimflag[2]) {
domainKK->boxlo[2] = d_set[2].lo_target;
domainKK->boxhi[2] = d_set[2].hi_target;
}
if (triclinic) {
if (set[3].style) domain->yz = set[3].tilt_target;
if (set[4].style) domain->xz = set[4].tilt_target;
if (set[5].style) domain->xy = set[5].tilt_target;
if (dimflag[3]) domainKK->yz = d_set[3].tilt_target;
if (dimflag[4]) domainKK->xz = d_set[4].tilt_target;
if (dimflag[5]) domainKK->xy = d_set[5].tilt_target;
}
domain->set_global_box();
domain->set_local_box();
domainKK->set_global_box();
domainKK->set_local_box();
// convert atoms and rigid bodies back to box coords
if (remapflag == Domain::X_REMAP) {
int nlocal = atom->nlocal;
atomKK->sync(execution_space, X_MASK | MASK_MASK );
d_x = atomKK->k_x.template view<DeviceType>();
d_mask = atomKK->k_mask.template view<DeviceType>();
int nlocal = atomKK->nlocal;
domainKK->lamda2x(nlocal);
for (int i = 0; i < nlocal; i++)
if (d_mask(i) & groupbit)
domainKK->lamda2x(&d_x(i,0), &d_x(i,0));
if (rfix.size() > 0) {
atomKK->sync(Host,ALL_MASK);
for (auto &ifix : rfix)
ifix->deform(1);
atomKK->modified(Host,ALL_MASK);
}
for (auto &ifix : rfix)
ifix->deform(1);
}
// redo KSpace coeffs since box has changed
if (kspace_flag) force->kspace->setup();
}
namespace LAMMPS_NS {
template class FixDeformKokkos<LMPDeviceType>;
#ifdef LMP_KOKKOS_GPU
template class FixDeformKokkos<LMPHostType>;
#endif
}

View File

@ -13,9 +13,9 @@
#ifdef FIX_CLASS
// clang-format off
FixStyle(deform/kk,FixDeformKokkos);
FixStyle(deform/kk/device,FixDeformKokkos);
FixStyle(deform/kk/host,FixDeformKokkos);
FixStyle(deform/kk,FixDeformKokkos<LMPDeviceType>);
FixStyle(deform/kk/device,FixDeformKokkos<LMPDeviceType>);
FixStyle(deform/kk/host,FixDeformKokkos<LMPHostType>);
// clang-format on
#else
@ -24,19 +24,40 @@ FixStyle(deform/kk/host,FixDeformKokkos);
#define LMP_FIX_DEFORM_KOKKOS_H
#include "fix_deform.h"
#include "kokkos_type.h"
namespace LAMMPS_NS {
template <class DeviceType>
class FixDeformKokkos : public FixDeform {
public:
FixDeformKokkos(class LAMMPS *, int, char **);
typedef DeviceType device_type;
typedef ArrayTypes<DeviceType> AT;
FixDeformKokkos(class LAMMPS *, int, char **);
~FixDeformKokkos();
void init() override;
void pre_exchange() override;
void end_of_step() override;
private:
class DomainKokkos *domainKK;
typename AT::t_x_array d_x;
typename AT::t_int_1d d_mask;
Kokkos::DualView<Set*, Kokkos::LayoutRight, LMPDeviceType> k_set;
Kokkos::View<Set*,Kokkos::LayoutRight,DeviceType> d_set;
KOKKOS_INLINE_FUNCTION
void virtual apply_volume();
KOKKOS_INLINE_FUNCTION
void apply_strain();
KOKKOS_INLINE_FUNCTION
void update_domain();
};
}