diff --git a/src/KOKKOS/atom_kokkos.h b/src/KOKKOS/atom_kokkos.h index d2629b4441..bf2d74f511 100644 --- a/src/KOKKOS/atom_kokkos.h +++ b/src/KOKKOS/atom_kokkos.h @@ -14,7 +14,6 @@ #include "atom.h" // IWYU pragma: export #include "kokkos_type.h" -#include #ifndef LMP_ATOM_KOKKOS_H #define LMP_ATOM_KOKKOS_H @@ -71,25 +70,22 @@ class AtomKokkos : public Atom { ~AtomKokkos(); void map_init(int check = 1); - void map_clear(); void map_set(); void map_delete(); DAT::tdual_int_1d k_sametag; DAT::tdual_int_1d k_map_array; DAT::tdual_int_scalar k_error_flag; - - typedef Kokkos::UnorderedMap hash_type; - typedef Kokkos::DualView dual_hash_type; - typedef dual_hash_type::t_host::data_type host_hash_type; dual_hash_type k_map_hash; + hash_type d_map_hash; + host_hash_type h_map_hash; template KOKKOS_INLINE_FUNCTION static int map_find_hash_kokkos(tagint global, dual_hash_type k_map_hash) { int local = -1; - auto d_map_hash = k_map_hash.view()(); + auto d_map_hash = k_map_hash.view(); auto index = d_map_hash.find(global); if (d_map_hash.valid_at(index)) local = d_map_hash.value_at(index); diff --git a/src/KOKKOS/atom_map_kokkos.cpp b/src/KOKKOS/atom_map_kokkos.cpp index 4c3d5cd1e9..486594e182 100644 --- a/src/KOKKOS/atom_map_kokkos.cpp +++ b/src/KOKKOS/atom_map_kokkos.cpp @@ -13,8 +13,11 @@ ------------------------------------------------------------------------- */ #include "atom_kokkos.h" +#include "neighbor_kokkos.h" #include "comm.h" #include "error.h" +#include "modify.h" +#include "fix.h" #include "memory_kokkos.h" #include "atom_masks.h" @@ -61,8 +64,6 @@ void AtomKokkos::map_init(int check) map_free = 0; for (int i = 0; i < map_nhash; i++) map_hash[i].next = i+1; if (map_nhash > 0) map_hash[map_nhash-1].next = -1; - - k_map_hash.h_view().clear(); } // recreating: delete old map and create new one for array or hash @@ -106,34 +107,9 @@ void AtomKokkos::map_init(int check) for (int i = 0; i < map_nhash; i++) map_hash[i].next = i+1; map_hash[map_nhash-1].next = -1; - k_map_hash = dual_hash_type("atom:map_hash"); - k_map_hash.h_view() = host_hash_type(map_nhash); + h_map_hash = host_hash_type(map_nhash); } } - - if (map_style == Atom::MAP_ARRAY) - k_map_array.modify_host(); - else if (map_style == Atom::MAP_HASH) - k_map_hash.modify_host(); -} - -/* ---------------------------------------------------------------------- - clear global -> local map for all of my own and ghost atoms - for hash table option: - global ID may not be in table if image atom was already cleared -------------------------------------------------------------------------- */ - -void AtomKokkos::map_clear() -{ - Atom::map_clear(); - - if (map_style == MAP_ARRAY) { - k_map_array.modify_host(); - } else { - k_map_hash.h_view().clear(); - k_map_hash.modify_host(); - } - k_sametag.modify_host(); } /* ---------------------------------------------------------------------- @@ -225,8 +201,7 @@ void AtomKokkos::map_set() // Copy to Kokkos hash - k_map_hash.h_view().clear(); - auto h_map_hash = k_map_hash.h_view(); + h_map_hash.clear(); for (int i = nall-1; i >= 0 ; i--) { @@ -252,13 +227,24 @@ void AtomKokkos::map_set() } } - k_sametag.modify_host(); - if (map_style == Atom::MAP_ARRAY) - k_map_array.modify_host(); - else if (map_style == Atom::MAP_HASH) - k_map_hash.modify_host(); -} + // check if fix shake or neigh bond needs a device hash + int device_hash_flag = 0; + + auto neighborKK = (NeighborKokkos*) neighbor; + if (neighborKK->device_flag) device_hash_flag = 1; + + for (int n = 0; n < modify->nfix; n++) + if (utils::strmatch(modify->fix[n]->style,"^shake")) + if (modify->fix[n]->execution_space == Device) + device_hash_flag = 1; + + if (device_hash_flag) + Kokkos::deep_copy(d_map_hash,h_map_hash); + + k_map_hash.h_view = h_map_hash; + k_map_hash.d_view = d_map_hash; +} /* ---------------------------------------------------------------------- free the array or hash table for global to local mapping @@ -268,20 +254,14 @@ void AtomKokkos::map_delete() { memoryKK->destroy_kokkos(k_sametag,sametag); sametag = nullptr; - max_same = 0; if (map_style == MAP_ARRAY) { memoryKK->destroy_kokkos(k_map_array,map_array); map_array = nullptr; } else { - if (map_nhash) { - delete [] map_bucket; - delete [] map_hash; - map_bucket = nullptr; - map_hash = nullptr; - k_map_hash = dual_hash_type(); - } - map_nhash = map_nbucket = 0; + h_map_hash = host_hash_type(); + d_map_hash = hash_type(); } -} + Atom::map_delete(); +} diff --git a/src/KOKKOS/fix_shake_kokkos.cpp b/src/KOKKOS/fix_shake_kokkos.cpp index 594c9c60db..a50ec573b4 100644 --- a/src/KOKKOS/fix_shake_kokkos.cpp +++ b/src/KOKKOS/fix_shake_kokkos.cpp @@ -207,7 +207,6 @@ void FixShakeKokkos::pre_neighbor() k_map_array.template sync(); } else if (map_style == Atom::MAP_HASH) { k_map_hash = atomKK->k_map_hash; - k_map_hash.template sync(); } k_shake_flag.sync(); @@ -231,7 +230,6 @@ void FixShakeKokkos::pre_neighbor() k_map_array.template sync(); } else if (map_style == Atom::MAP_HASH) { k_map_hash = atomKK->k_map_hash; - k_map_hash.template sync(); } // build list of SHAKE clusters I compute @@ -320,7 +318,6 @@ void FixShakeKokkos::post_force(int vflag) k_map_array.template sync(); } else if (map_style == Atom::MAP_HASH) { k_map_hash = atomKK->k_map_hash; - k_map_hash.template sync(); } if (d_rmass.data()) diff --git a/src/KOKKOS/fix_shake_kokkos.h b/src/KOKKOS/fix_shake_kokkos.h index fbf6e1de88..698298c65d 100644 --- a/src/KOKKOS/fix_shake_kokkos.h +++ b/src/KOKKOS/fix_shake_kokkos.h @@ -187,8 +187,6 @@ class FixShakeKokkos : public FixShake, public KokkosBase { DAT::tdual_int_1d k_map_array; - typedef Kokkos::UnorderedMap hash_type; - typedef Kokkos::DualView dual_hash_type; dual_hash_type k_map_hash; KOKKOS_INLINE_FUNCTION diff --git a/src/KOKKOS/kokkos_type.h b/src/KOKKOS/kokkos_type.h index c9e5736410..7a27e3f20f 100644 --- a/src/KOKKOS/kokkos_type.h +++ b/src/KOKKOS/kokkos_type.h @@ -23,6 +23,7 @@ #include #include #include +#include enum{FULL=1u,HALFTHREAD=2u,HALF=4u}; @@ -551,6 +552,20 @@ typedef int T_INT; // LAMMPS types +typedef Kokkos::UnorderedMap hash_type; +typedef hash_type::HostMirror host_hash_type; + +struct dual_hash_type { + hash_type d_view; + host_hash_type h_view; + + template + std::enable_if_t::value,hash_type> view() {return d_view;} + template + std::enable_if_t::value,host_hash_type> view() {return h_view;} + +}; + template struct ArrayTypes; @@ -837,6 +852,8 @@ typedef tdual_neighbors_2d::t_dev_um t_neighbors_2d_um; typedef tdual_neighbors_2d::t_dev_const_um t_neighbors_2d_const_um; typedef tdual_neighbors_2d::t_dev_const_randomread t_neighbors_2d_randomread; +typedef hash_type t_hash; + }; #ifdef LMP_KOKKOS_GPU @@ -1098,6 +1115,8 @@ typedef tdual_neighbors_2d::t_host_um t_neighbors_2d_um; typedef tdual_neighbors_2d::t_host_const_um t_neighbors_2d_const_um; typedef tdual_neighbors_2d::t_host_const_randomread t_neighbors_2d_randomread; +typedef host_hash_type t_hash; + }; #endif //default LAMMPS Types diff --git a/src/KOKKOS/neigh_bond_kokkos.cpp b/src/KOKKOS/neigh_bond_kokkos.cpp index c427d18216..228d521fdd 100644 --- a/src/KOKKOS/neigh_bond_kokkos.cpp +++ b/src/KOKKOS/neigh_bond_kokkos.cpp @@ -1241,7 +1241,6 @@ void NeighBondKokkos::update_class_variables() k_map_array.template sync(); } else if (map_style == Atom::MAP_HASH) { k_map_hash = atomKK->k_map_hash; - k_map_hash.template sync(); } } diff --git a/src/KOKKOS/neigh_bond_kokkos.h b/src/KOKKOS/neigh_bond_kokkos.h index 9bf508ec7a..d60ad2f1aa 100644 --- a/src/KOKKOS/neigh_bond_kokkos.h +++ b/src/KOKKOS/neigh_bond_kokkos.h @@ -89,8 +89,6 @@ class NeighBondKokkos : protected Pointers { DAT::tdual_int_1d k_map_array; - typedef Kokkos::UnorderedMap hash_type; - typedef Kokkos::DualView dual_hash_type; dual_hash_type k_map_hash; KOKKOS_INLINE_FUNCTION diff --git a/src/KOKKOS/neighbor_kokkos.h b/src/KOKKOS/neighbor_kokkos.h index 1a560e2129..ef885535ed 100644 --- a/src/KOKKOS/neighbor_kokkos.h +++ b/src/KOKKOS/neighbor_kokkos.h @@ -64,13 +64,14 @@ class NeighborKokkos : public Neighbor { DAT::tdual_int_2d k_dihedrallist; DAT::tdual_int_2d k_improperlist; + int device_flag; + private: DAT::tdual_x_array x; DAT::tdual_x_array xhold; X_FLOAT deltasq; - int device_flag; void init_cutneighsq_kokkos(int); void create_kokkos_list(int);