diff --git a/src/KOKKOS/kokkos.cpp b/src/KOKKOS/kokkos.cpp index 91ea6d37ac..84a8f59dd0 100644 --- a/src/KOKKOS/kokkos.cpp +++ b/src/KOKKOS/kokkos.cpp @@ -137,13 +137,13 @@ KokkosLMP::KokkosLMP(LAMMPS *lmp, int narg, char **arg) : Pointers(lmp) int set_flag = 0; char *str; - if ((str = getenv("SLURM_LOCALID"))) { + if (str = getenv("SLURM_LOCALID")) { int local_rank = atoi(str); device = local_rank % ngpus; if (device >= skip_gpu) device++; set_flag = 1; } - if ((str = getenv("MPT_LRANK"))) { + if (str = getenv("FLUX_TASK_LOCAL_ID")) { if (ngpus > 0) { int local_rank = atoi(str); device = local_rank % ngpus; @@ -151,7 +151,7 @@ KokkosLMP::KokkosLMP(LAMMPS *lmp, int narg, char **arg) : Pointers(lmp) set_flag = 1; } } - if ((str = getenv("MV2_COMM_WORLD_LOCAL_RANK"))) { + if (str = getenv("MPT_LRANK")) { if (ngpus > 0) { int local_rank = atoi(str); device = local_rank % ngpus; @@ -159,7 +159,7 @@ KokkosLMP::KokkosLMP(LAMMPS *lmp, int narg, char **arg) : Pointers(lmp) set_flag = 1; } } - if ((str = getenv("OMPI_COMM_WORLD_LOCAL_RANK"))) { + if (str = getenv("MV2_COMM_WORLD_LOCAL_RANK")) { if (ngpus > 0) { int local_rank = atoi(str); device = local_rank % ngpus; @@ -167,7 +167,15 @@ KokkosLMP::KokkosLMP(LAMMPS *lmp, int narg, char **arg) : Pointers(lmp) set_flag = 1; } } - if ((str = getenv("PMI_LOCAL_RANK"))) { + if (str = getenv("OMPI_COMM_WORLD_LOCAL_RANK")) { + if (ngpus > 0) { + int local_rank = atoi(str); + device = local_rank % ngpus; + if (device >= skip_gpu) device++; + set_flag = 1; + } + } + if (str = getenv("PMI_LOCAL_RANK")) { if (ngpus > 0) { int local_rank = atoi(str); device = local_rank % ngpus; diff --git a/src/KOKKOS/pair_kokkos.h b/src/KOKKOS/pair_kokkos.h index 2c2a622791..da6bd11006 100644 --- a/src/KOKKOS/pair_kokkos.h +++ b/src/KOKKOS/pair_kokkos.h @@ -138,9 +138,9 @@ struct PairComputeFunctor { F_FLOAT fztmp = 0.0; if (NEIGHFLAG == FULL) { - f(i,0) = 0.0; - f(i,1) = 0.0; - f(i,2) = 0.0; + a_f(i,0) -= f(i,0); + a_f(i,1) -= f(i,1); + a_f(i,2) -= f(i,2); } for (int jj = 0; jj < jnum; jj++) { @@ -212,9 +212,9 @@ struct PairComputeFunctor { F_FLOAT fztmp = 0.0; if (NEIGHFLAG == FULL) { - f(i,0) = 0.0; - f(i,1) = 0.0; - f(i,2) = 0.0; + a_f(i,0) -= f(i,0); + a_f(i,1) -= f(i,1); + a_f(i,2) -= f(i,2); } for (int jj = 0; jj < jnum; jj++) {