diff --git a/src/balance.cpp b/src/balance.cpp index feb5eae708..c8d21f1ac6 100644 --- a/src/balance.cpp +++ b/src/balance.cpp @@ -1029,12 +1029,12 @@ void Balance::tally(int dim, int n, double *split) if (wtflag) { weight = fixstore->vstore; for (int i = 0; i < nlocal; i++) { - index = binary(x[i][dim],n,split); + index = utils::binary_search(x[i][dim],n,split); onecost[index] += weight[i]; } } else { for (int i = 0; i < nlocal; i++) { - index = binary(x[i][dim],n,split); + index = utils::binary_search(x[i][dim],n,split); onecost[index] += 1.0; } } @@ -1131,16 +1131,16 @@ double Balance::imbalance_splits() if (wtflag) { weight = fixstore->vstore; for (int i = 0; i < nlocal; i++) { - ix = binary(x[i][0],nx,xsplit); - iy = binary(x[i][1],ny,ysplit); - iz = binary(x[i][2],nz,zsplit); + ix = utils::binary_search(x[i][0],nx,xsplit); + iy = utils::binary_search(x[i][1],ny,ysplit); + iz = utils::binary_search(x[i][2],nz,zsplit); proccost[iz*nx*ny + iy*nx + ix] += weight[i]; } } else { for (int i = 0; i < nlocal; i++) { - ix = binary(x[i][0],nx,xsplit); - iy = binary(x[i][1],ny,ysplit); - iz = binary(x[i][2],nz,zsplit); + ix = utils::binary_search(x[i][0],nx,xsplit); + iy = utils::binary_search(x[i][1],ny,ysplit); + iz = utils::binary_search(x[i][2],nz,zsplit); proccost[iz*nx*ny + iy*nx + ix] += 1.0; } } @@ -1161,40 +1161,6 @@ double Balance::imbalance_splits() return imbalance; } -/* ---------------------------------------------------------------------- - binary search for where value falls in N-length vec - note that vec actually has N+1 values, but ignore last one - values in vec are monotonically increasing, but adjacent values can be ties - value may be outside range of vec limits - always return index from 0 to N-1 inclusive - return 0 if value < vec[0] - reutrn N-1 if value >= vec[N-1] - return index = 1 to N-2 inclusive if vec[index] <= value < vec[index+1] - note that for adjacent tie values, index of lower tie is not returned - since never satisfies 2nd condition that value < vec[index+1] -------------------------------------------------------------------------- */ - -int Balance::binary(double value, int n, double *vec) -{ - int lo = 0; - int hi = n-1; - - if (value < vec[lo]) return lo; - if (value >= vec[hi]) return hi; - - // insure vec[lo] <= value < vec[hi] at every iteration - // done when lo,hi are adjacent - - int index = (lo+hi)/2; - while (lo < hi-1) { - if (value < vec[index]) hi = index; - else if (value >= vec[index]) lo = index; - index = (lo+hi)/2; - } - - return index; -} - /* ---------------------------------------------------------------------- write dump snapshot of line segments in Pizza.py mdump mesh format write xy lines around each proc's sub-domain for 2d diff --git a/src/balance.h b/src/balance.h index 0f2df8c5ee..e427c8bd9b 100644 --- a/src/balance.h +++ b/src/balance.h @@ -84,7 +84,6 @@ class Balance : public Command { void shift_setup_static(char *); void tally(int, int, double *); int adjust(int, double *); - int binary(double, int, double *); #ifdef BALANCE_DEBUG void debug_shift_output(int, int, int, double *); #endif diff --git a/src/comm.cpp b/src/comm.cpp index 17175f9517..f79c48ace1 100644 --- a/src/comm.cpp +++ b/src/comm.cpp @@ -782,13 +782,13 @@ int Comm::coord2proc(double *x, int &igx, int &igy, int &igz) } else if (layout == Comm::LAYOUT_NONUNIFORM) { if (triclinic == 0) { - igx = binary((x[0]-boxlo[0])/prd[0],procgrid[0],xsplit); - igy = binary((x[1]-boxlo[1])/prd[1],procgrid[1],ysplit); - igz = binary((x[2]-boxlo[2])/prd[2],procgrid[2],zsplit); + igx = utils::binary_search((x[0]-boxlo[0])/prd[0],procgrid[0],xsplit); + igy = utils::binary_search((x[1]-boxlo[1])/prd[1],procgrid[1],ysplit); + igz = utils::binary_search((x[2]-boxlo[2])/prd[2],procgrid[2],zsplit); } else { - igx = binary(x[0],procgrid[0],xsplit); - igy = binary(x[1],procgrid[1],ysplit); - igz = binary(x[2],procgrid[2],zsplit); + igx = utils::binary_search(x[0],procgrid[0],xsplit); + igy = utils::binary_search(x[1],procgrid[1],ysplit); + igz = utils::binary_search(x[2],procgrid[2],zsplit); } } @@ -802,36 +802,6 @@ int Comm::coord2proc(double *x, int &igx, int &igy, int &igz) return grid2proc[igx][igy][igz]; } -/* ---------------------------------------------------------------------- - binary search for value in N-length ascending vec - value may be outside range of vec limits - always return index from 0 to N-1 inclusive - return 0 if value < vec[0] - reutrn N-1 if value >= vec[N-1] - return index = 1 to N-2 if vec[index] <= value < vec[index+1] -------------------------------------------------------------------------- */ - -int Comm::binary(double value, int n, double *vec) -{ - int lo = 0; - int hi = n-1; - - if (value < vec[lo]) return lo; - if (value >= vec[hi]) return hi; - - // insure vec[lo] <= value < vec[hi] at every iteration - // done when lo,hi are adjacent - - int index = (lo+hi)/2; - while (lo < hi-1) { - if (value < vec[index]) hi = index; - else if (value >= vec[index]) lo = index; - index = (lo+hi)/2; - } - - return index; -} - /* ---------------------------------------------------------------------- partition a global regular grid into one brick-shaped sub-grid per proc if grid point is inside my sub-domain I own it, diff --git a/src/comm.h b/src/comm.h index 817d1230a2..43a30675f3 100644 --- a/src/comm.h +++ b/src/comm.h @@ -98,7 +98,6 @@ class Comm : protected Pointers { virtual void forward_comm_array(int, double **) = 0; virtual int exchange_variable(int, double *, double *&) = 0; - int binary(double, int, double *); // map a point to a processor, based on current decomposition diff --git a/src/utils.cpp b/src/utils.cpp index 528908807c..6ac3131f29 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -1284,6 +1284,28 @@ int utils::date2num(const std::string &date) return num; } +/* binary search in vector of ascending doubles */ +int utils::binary_search(const double needle, const int n, const double *haystack) +{ + int lo = 0; + int hi = n-1; + + if (needle < haystack[lo]) return lo; + if (needle >= haystack[hi]) return hi; + + // insure haystack[lo] <= needle < haystack[hi] at every iteration + // done when lo,hi are adjacent + + int index = (lo+hi)/2; + while (lo < hi-1) { + if (needle < haystack[index]) hi = index; + else if (needle >= haystack[index]) lo = index; + index = (lo+hi)/2; + } + + return index; +} + /* ---------------------------------------------------------------------- * Merge sort part 1: Loop over sublists doubling in size with each iteration. * Pre-sort small sublists with insertion sort for better overall performance. diff --git a/src/utils.h b/src/utils.h index b761534d66..0ec1da379a 100644 --- a/src/utils.h +++ b/src/utils.h @@ -541,6 +541,23 @@ namespace utils { int date2num(const std::string &date); + /*! Binary search in a vector of ascending doubles of length N + * + * If the value is smaller than the smallest value in the vector, 0 is returned. + * If the value is larger or equal than the largest value in the vector, N-1 is returned. + * Otherwise the index that satisfies the condition + * + * haystack[index] <= value < haystack[index+1] + * + * is returned, i.e. a value from 1 to N-2. Note that if there are tied values in the + * haystack, always the larger index is returned as only that satisfied the condition. + * + * \param needle search value for which are are looking for the closest index + * \param n size of the haystack array + * \param haystack array with data in ascending order. + * \return index of value in the haystack array smaller or equal to needle */ + int binary_search(const double needle, const int n, const double *haystack); + /*! Custom merge sort implementation * * This function provides a custom upward hybrid merge sort diff --git a/unittest/utils/test_utils.cpp b/unittest/utils/test_utils.cpp index 8c2845e817..6db1c5c7ee 100644 --- a/unittest/utils/test_utils.cpp +++ b/unittest/utils/test_utils.cpp @@ -875,6 +875,24 @@ TEST(Utils, date2num) ASSERT_EQ(utils::date2num("31December100"), 1001231); } +TEST(Utils, binary_search) +{ + double data[] = {-2.0, -1.8, -1.0, -1.0, -1.0, -0.5, -0.2, 0.0, 0.1, 0.1, + 0.2, 0.3, 0.5, 0.5, 0.6, 0.7, 1.0, 1.2, 1.5, 2.0}; + const int n = sizeof(data) / sizeof(double); + ASSERT_EQ(utils::binary_search(-5.0, n, data), 0); + ASSERT_EQ(utils::binary_search(-2.0, n, data), 0); + ASSERT_EQ(utils::binary_search(-1.9, n, data), 0); + ASSERT_EQ(utils::binary_search(-1.0, n, data), 4); + ASSERT_EQ(utils::binary_search(0.0, n, data), 7); + ASSERT_EQ(utils::binary_search(0.1, n, data), 9); + ASSERT_EQ(utils::binary_search(0.4, n, data), 11); + ASSERT_EQ(utils::binary_search(1.1, n, data), 16); + ASSERT_EQ(utils::binary_search(1.5, n, data), 18); + ASSERT_EQ(utils::binary_search(2.0, n, data), 19); + ASSERT_EQ(utils::binary_search(2.5, n, data), 19); +} + static int compare(int a, int b, void *) { if (a < b)