diff --git a/src/KOKKOS/modify_kokkos.cpp b/src/KOKKOS/modify_kokkos.cpp index 581f6598f1..d0f57df8dc 100644 --- a/src/KOKKOS/modify_kokkos.cpp +++ b/src/KOKKOS/modify_kokkos.cpp @@ -394,7 +394,7 @@ void ModifyKokkos::final_integrate() /* ---------------------------------------------------------------------- - 2nd half of integrate call, only for relevant fixes + fused initial and final integrate call, only for relevant fixes ------------------------------------------------------------------------- */ void ModifyKokkos::fused_integrate() @@ -900,3 +900,22 @@ int ModifyKokkos::min_reset_ref() } return itmpall; } + +/* ---------------------------------------------------------------------- + check if initial and final integrate can be fused +------------------------------------------------------------------------- */ + +int ModifyKokkos::check_fuse_integrate() +{ + int fuse_integrate_flag = 1; + + for (int i = 0; i < n_initial_integrate; i++) + if (!fix[list_initial_integrate[i]]->fuse_integrate_flag) + fuse_integrate_flag = 0; + + for (int i = 0; i < n_final_integrate; i++) + if (!fix[list_final_integrate[i]]->fuse_integrate_flag) + fuse_integrate_flag = 0; + + return fuse_integrate_flag; +} diff --git a/src/KOKKOS/modify_kokkos.h b/src/KOKKOS/modify_kokkos.h index e440693fbb..8dd9e6d9df 100644 --- a/src/KOKKOS/modify_kokkos.h +++ b/src/KOKKOS/modify_kokkos.h @@ -70,6 +70,8 @@ class ModifyKokkos : public Modify { int min_dof() override; int min_reset_ref() override; + int check_fuse_integrate(); + protected: }; diff --git a/src/KOKKOS/verlet_kokkos.cpp b/src/KOKKOS/verlet_kokkos.cpp index 0339df663b..2b70f3db61 100644 --- a/src/KOKKOS/verlet_kokkos.cpp +++ b/src/KOKKOS/verlet_kokkos.cpp @@ -27,7 +27,7 @@ #include "kspace.h" #include "output.h" #include "update.h" -#include "modify.h" +#include "modify_kokkos.h" #include "timer.h" #include "memory_kokkos.h" #include "kokkos.h" @@ -629,5 +629,5 @@ void VerletKokkos::fuse_check(int i, int n) if (i == 0 || i == n-1) fuse_integrate = 0; if (update->ntimestep == output->next) fuse_integrate = 0; if (timer->has_timeout()) fuse_integrate = 0; - if (!modify->check_fuse_integrate()) fuse_integrate = 0; + if (!((ModifyKokkos*)modify)->check_fuse_integrate()) fuse_integrate = 0; } diff --git a/src/modify.cpp b/src/modify.cpp index 71f7cb8889..d0656d3895 100644 --- a/src/modify.cpp +++ b/src/modify.cpp @@ -475,17 +475,6 @@ void Modify::final_integrate() for (int i = 0; i < n_final_integrate; i++) fix[list_final_integrate[i]]->final_integrate(); } - -/* ---------------------------------------------------------------------- - 2nd half of integrate call, only for relevant fixes -------------------------------------------------------------------------- */ - -void Modify::fused_integrate() -{ - for (int i = 0; i < n_final_integrate; i++) - fix[list_final_integrate[i]]->fused_integrate(); -} - /* ---------------------------------------------------------------------- end-of-timestep call, only for relevant fixes only call fix->end_of_step() on timesteps that are multiples of nevery @@ -1810,22 +1799,3 @@ double Modify::memory_usage() for (int i = 0; i < ncompute; i++) bytes += compute[i]->memory_usage(); return bytes; } - -/* ---------------------------------------------------------------------- - check if initial and final integrate can be fused -------------------------------------------------------------------------- */ - -int Modify::check_fuse_integrate() -{ - int fuse_integrate_flag = 1; - - for (int i = 0; i < n_initial_integrate; i++) - if (!fix[list_initial_integrate[i]]->fuse_integrate_flag) - fuse_integrate_flag = 0; - - for (int i = 0; i < n_final_integrate; i++) - if (!fix[list_final_integrate[i]]->fuse_integrate_flag) - fuse_integrate_flag = 0; - - return fuse_integrate_flag; -} diff --git a/src/modify.h b/src/modify.h index da3cab55d2..9686a9a7ec 100644 --- a/src/modify.h +++ b/src/modify.h @@ -61,7 +61,7 @@ class Modify : protected Pointers { virtual void setup_pre_force(int); virtual void setup_pre_reverse(int, int); virtual void initial_integrate(int); - virtual void fused_integrate(); + virtual void fused_integrate() {} virtual void post_integrate(); virtual void pre_exchange(); virtual void pre_neighbor(); @@ -151,8 +151,6 @@ class Modify : protected Pointers { double memory_usage(); - int check_fuse_integrate(); - protected: // internal fix counts