diff --git a/src/KOKKOS/kokkos.cpp b/src/KOKKOS/kokkos.cpp index 14a689e72a..53fda7684c 100644 --- a/src/KOKKOS/kokkos.cpp +++ b/src/KOKKOS/kokkos.cpp @@ -189,7 +189,7 @@ KokkosLMP::KokkosLMP(LAMMPS *lmp, int narg, char **arg) : Pointers(lmp) args.num_numa = numa; args.device_id = device; - Kokkos::initialize(args); + initialize(args); // default settings for package kokkos command @@ -302,6 +302,25 @@ KokkosLMP::~KokkosLMP() } +/* ---------------------------------------------------------------------- */ + +void KokkosLMP::initialize(Kokkos::InitArguments args) +{ + if (!Kokkos::is_initialized()) + Kokkos::initialize(args); +} + +/* ---------------------------------------------------------------------- */ + +void KokkosLMP::finalize() +{ + static int is_finalized = 0; + + if (Kokkos::is_initialized() && !is_finalized) + Kokkos::finalize(); + is_finalized = 1; +} + /* ---------------------------------------------------------------------- invoked by package kokkos command ------------------------------------------------------------------------- */ diff --git a/src/KOKKOS/kokkos.h b/src/KOKKOS/kokkos.h index 22060ecc09..7782e4e00c 100644 --- a/src/KOKKOS/kokkos.h +++ b/src/KOKKOS/kokkos.h @@ -51,6 +51,8 @@ class KokkosLMP : protected Pointers { KokkosLMP(class LAMMPS *, int, char **); ~KokkosLMP(); + static void initialize(Kokkos::InitArguments); + static void finalize(); void accelerator(int, char **); int neigh_count(int); diff --git a/src/accelerator_kokkos.h b/src/accelerator_kokkos.h index 2ee15e946a..118ac3b9c8 100644 --- a/src/accelerator_kokkos.h +++ b/src/accelerator_kokkos.h @@ -43,6 +43,10 @@ #include "modify.h" #include "neighbor.h" +namespace Kokkos { + typedef int InitArguements; +}; + #define LAMMPS_INLINE inline namespace LAMMPS_NS { @@ -56,17 +60,13 @@ class KokkosLMP { KokkosLMP(class LAMMPS *, int, char **) { kokkos_exists = 0; } ~KokkosLMP() {} + static void initialize(Kokkos::InitArguements args) {} + static void finalize() {} void accelerator(int, char **) {} int neigh_list_kokkos(int) { return 0; } int neigh_count(int) { return 0; } }; -class Kokkos { - public: - static int is_initialized() {return false;} - static void finalize() {} -}; - class AtomKokkos : public Atom { public: tagint **k_special; diff --git a/src/error.cpp b/src/error.cpp index fd12cbe80e..9811a1d3eb 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -81,7 +81,7 @@ void Error::universe_all(const std::string &file, int line, const std::string &s throw LAMMPSException(mesg); #else - if (Kokkos::is_initialized()) Kokkos::finalize(); + KokkosLMP::finalize(); MPI_Finalize(); exit(1); #endif @@ -107,6 +107,7 @@ void Error::universe_one(const std::string &file, int line, const std::string &s throw LAMMPSAbortException(mesg, universe->uworld); #else + KokkosLMP::finalize(); MPI_Abort(universe->uworld,1); exit(1); // to trick "smart" compilers into believing this does not return #endif @@ -173,8 +174,8 @@ void Error::all(const std::string &file, int line, const std::string &str) if (screen && screen != stdout) fclose(screen); if (logfile) fclose(logfile); + KokkosLMP::finalize(); if (universe->nworlds > 1) MPI_Abort(universe->uworld,1); - if (Kokkos::is_initialized()) Kokkos::finalize(); MPI_Finalize(); exit(1); #endif @@ -213,7 +214,7 @@ void Error::one(const std::string &file, int line, const std::string &str) #else if (screen) fflush(screen); if (logfile) fflush(logfile); - if (Kokkos::is_initialized()) Kokkos::finalize(); + KokkosLMP::finalize(); MPI_Abort(world,1); exit(1); // to trick "smart" compilers into believing this does not return #endif @@ -316,7 +317,7 @@ void Error::done(int status) if (screen && screen != stdout) fclose(screen); if (logfile) fclose(logfile); - if (Kokkos::is_initialized()) Kokkos::finalize(); + KokkosLMP::finalize(); MPI_Finalize(); exit(status); } diff --git a/src/library.cpp b/src/library.cpp index 4fcf382247..178364d777 100644 --- a/src/library.cpp +++ b/src/library.cpp @@ -356,7 +356,7 @@ void lammps_mpi_finalize() void lammps_kokkos_finalize() { - if (Kokkos::is_initialized()) Kokkos::finalize(); + KokkosLMP::finalize(); } // ---------------------------------------------------------------------- diff --git a/src/main.cpp b/src/main.cpp index ff1b3f9a09..76ae4fde09 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -78,15 +78,16 @@ int main(int argc, char **argv) lammps->input->file(); delete lammps; } catch (LAMMPSAbortException &ae) { - if (Kokkos::is_initialized()) Kokkos::finalize(); + KokkosLMP::finalize(); MPI_Abort(ae.universe, 1); } catch (LAMMPSException &e) { - if (Kokkos::is_initialized()) Kokkos::finalize(); + KokkosLMP::finalize(); MPI_Barrier(lammps_comm); MPI_Finalize(); exit(1); } catch (fmt::format_error &fe) { fprintf(stderr, "fmt::format_error: %s\n", fe.what()); + KokkosLMP::finalize(); MPI_Abort(MPI_COMM_WORLD, 1); exit(1); } @@ -97,11 +98,12 @@ int main(int argc, char **argv) delete lammps; } catch (fmt::format_error &fe) { fprintf(stderr, "fmt::format_error: %s\n", fe.what()); + KokkosLMP::finalize(); MPI_Abort(MPI_COMM_WORLD, 1); exit(1); } #endif - if (Kokkos::is_initialized()) Kokkos::finalize(); + KokkosLMP::finalize(); MPI_Barrier(lammps_comm); MPI_Finalize(); }