diff --git a/src/imbalance_neigh.cpp b/src/imbalance_neigh.cpp index 3a523114ee..0a1c2a87cc 100644 --- a/src/imbalance_neigh.cpp +++ b/src/imbalance_neigh.cpp @@ -1,4 +1,3 @@ -// clang-format off /* ---------------------------------------------------------------------- LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator https://www.lammps.org/, Sandia National Laboratories @@ -14,6 +13,7 @@ #include "imbalance_neigh.h" +#include "accelerator_kokkos.h" #include "atom.h" #include "comm.h" #include "error.h" @@ -36,9 +36,9 @@ ImbalanceNeigh::ImbalanceNeigh(LAMMPS *lmp) : Imbalance(lmp) int ImbalanceNeigh::options(int narg, char **arg) { - if (narg < 1) error->all(FLERR,"Illegal balance weight command"); - factor = utils::numeric(FLERR,arg[0],false,lmp); - if (factor <= 0.0) error->all(FLERR,"Illegal balance weight command"); + if (narg < 1) error->all(FLERR, "Illegal balance weight command"); + factor = utils::numeric(FLERR, arg[0], false, lmp); + if (factor <= 0.0) error->all(FLERR, "Illegal balance weight command"); return 1; } @@ -50,19 +50,30 @@ void ImbalanceNeigh::compute(double *weight) if (factor == 0.0) return; + // cannot use neighbor list weight with KOKKOS using GPUs + + if (lmp->kokkos && lmp->kokkos->kokkos_exists) { + if (lmp->kokkos->ngpus > 0) { + if (comm->me == 0 && !did_warn) + error->warning(FLERR, "Balance weight neigh skipped with KOKKOS using GPUs"); + did_warn = 1; + return; + } + } + // find suitable neighbor list // can only use certain conventional neighbor lists // NOTE: why not full list, if half does not exist? for (req = 0; req < neighbor->old_nrequest; ++req) { - if (neighbor->old_requests[req]->half && - neighbor->old_requests[req]->skip == 0 && - neighbor->lists[req] && neighbor->lists[req]->numneigh) break; + if (neighbor->old_requests[req]->half && neighbor->old_requests[req]->skip == 0 && + neighbor->lists[req] && neighbor->lists[req]->numneigh) + break; } if (req >= neighbor->old_nrequest || neighbor->ago < 0) { if (comm->me == 0 && !did_warn) - error->warning(FLERR,"Balance weight neigh skipped b/c no list found"); + error->warning(FLERR, "Balance weight neigh skipped b/c no list found"); did_warn = 1; return; } @@ -72,16 +83,16 @@ void ImbalanceNeigh::compute(double *weight) NeighList *list = neighbor->lists[req]; const int inum = list->inum; - const int * const ilist = list->ilist; - const int * const numneigh = list->numneigh; + const int *const ilist = list->ilist; + const int *const numneigh = list->numneigh; int nlocal = atom->nlocal; bigint neighsum = 0; for (int i = 0; i < inum; ++i) neighsum += numneigh[ilist[i]]; double localwt = 0.0; - if (nlocal) localwt = 1.0*neighsum/nlocal; + if (nlocal) localwt = 1.0 * neighsum / nlocal; - if (nlocal && localwt <= 0.0) error->one(FLERR,"Balance weight <= 0.0"); + if (nlocal && localwt <= 0.0) error->one(FLERR, "Balance weight <= 0.0"); // apply factor if specified != 1.0 // wtlo,wthi = lo/hi values excluding 0.0 due to no atoms on this proc @@ -90,15 +101,15 @@ void ImbalanceNeigh::compute(double *weight) // expand/contract all localwt values from lo->hi to lo->newhi if (factor != 1.0) { - double wtlo,wthi; + double wtlo, wthi; if (localwt == 0.0) localwt = BIG; - MPI_Allreduce(&localwt,&wtlo,1,MPI_DOUBLE,MPI_MIN,world); + MPI_Allreduce(&localwt, &wtlo, 1, MPI_DOUBLE, MPI_MIN, world); if (localwt == BIG) localwt = 0.0; - MPI_Allreduce(&localwt,&wthi,1,MPI_DOUBLE,MPI_MAX,world); + MPI_Allreduce(&localwt, &wthi, 1, MPI_DOUBLE, MPI_MAX, world); if (wtlo == wthi) return; - double newhi = wthi*factor; - localwt = wtlo + ((localwt-wtlo)/(wthi-wtlo)) * (newhi-wtlo); + double newhi = wthi * factor; + localwt = wtlo + ((localwt - wtlo) / (wthi - wtlo)) * (newhi - wtlo); } for (int i = 0; i < nlocal; i++) weight[i] *= localwt; @@ -108,5 +119,5 @@ void ImbalanceNeigh::compute(double *weight) std::string ImbalanceNeigh::info() { - return fmt::format(" neighbor weight factor: {}\n",factor); + return fmt::format(" neighbor weight factor: {}\n", factor); }