consolidate binary() member functions of Comm and Balance into utils::binary_search()

This commit is contained in:
Axel Kohlmeyer
2021-09-06 16:58:14 -04:00
parent 801cd647c3
commit 898f8086db
7 changed files with 71 additions and 80 deletions

View File

@ -1029,12 +1029,12 @@ void Balance::tally(int dim, int n, double *split)
if (wtflag) {
weight = fixstore->vstore;
for (int i = 0; i < nlocal; i++) {
index = binary(x[i][dim],n,split);
index = utils::binary_search(x[i][dim],n,split);
onecost[index] += weight[i];
}
} else {
for (int i = 0; i < nlocal; i++) {
index = binary(x[i][dim],n,split);
index = utils::binary_search(x[i][dim],n,split);
onecost[index] += 1.0;
}
}
@ -1131,16 +1131,16 @@ double Balance::imbalance_splits()
if (wtflag) {
weight = fixstore->vstore;
for (int i = 0; i < nlocal; i++) {
ix = binary(x[i][0],nx,xsplit);
iy = binary(x[i][1],ny,ysplit);
iz = binary(x[i][2],nz,zsplit);
ix = utils::binary_search(x[i][0],nx,xsplit);
iy = utils::binary_search(x[i][1],ny,ysplit);
iz = utils::binary_search(x[i][2],nz,zsplit);
proccost[iz*nx*ny + iy*nx + ix] += weight[i];
}
} else {
for (int i = 0; i < nlocal; i++) {
ix = binary(x[i][0],nx,xsplit);
iy = binary(x[i][1],ny,ysplit);
iz = binary(x[i][2],nz,zsplit);
ix = utils::binary_search(x[i][0],nx,xsplit);
iy = utils::binary_search(x[i][1],ny,ysplit);
iz = utils::binary_search(x[i][2],nz,zsplit);
proccost[iz*nx*ny + iy*nx + ix] += 1.0;
}
}
@ -1161,40 +1161,6 @@ double Balance::imbalance_splits()
return imbalance;
}
/* ----------------------------------------------------------------------
binary search for where value falls in N-length vec
note that vec actually has N+1 values, but ignore last one
values in vec are monotonically increasing, but adjacent values can be ties
value may be outside range of vec limits
always return index from 0 to N-1 inclusive
return 0 if value < vec[0]
reutrn N-1 if value >= vec[N-1]
return index = 1 to N-2 inclusive if vec[index] <= value < vec[index+1]
note that for adjacent tie values, index of lower tie is not returned
since never satisfies 2nd condition that value < vec[index+1]
------------------------------------------------------------------------- */
int Balance::binary(double value, int n, double *vec)
{
int lo = 0;
int hi = n-1;
if (value < vec[lo]) return lo;
if (value >= vec[hi]) return hi;
// insure vec[lo] <= value < vec[hi] at every iteration
// done when lo,hi are adjacent
int index = (lo+hi)/2;
while (lo < hi-1) {
if (value < vec[index]) hi = index;
else if (value >= vec[index]) lo = index;
index = (lo+hi)/2;
}
return index;
}
/* ----------------------------------------------------------------------
write dump snapshot of line segments in Pizza.py mdump mesh format
write xy lines around each proc's sub-domain for 2d

View File

@ -84,7 +84,6 @@ class Balance : public Command {
void shift_setup_static(char *);
void tally(int, int, double *);
int adjust(int, double *);
int binary(double, int, double *);
#ifdef BALANCE_DEBUG
void debug_shift_output(int, int, int, double *);
#endif

View File

@ -782,13 +782,13 @@ int Comm::coord2proc(double *x, int &igx, int &igy, int &igz)
} else if (layout == Comm::LAYOUT_NONUNIFORM) {
if (triclinic == 0) {
igx = binary((x[0]-boxlo[0])/prd[0],procgrid[0],xsplit);
igy = binary((x[1]-boxlo[1])/prd[1],procgrid[1],ysplit);
igz = binary((x[2]-boxlo[2])/prd[2],procgrid[2],zsplit);
igx = utils::binary_search((x[0]-boxlo[0])/prd[0],procgrid[0],xsplit);
igy = utils::binary_search((x[1]-boxlo[1])/prd[1],procgrid[1],ysplit);
igz = utils::binary_search((x[2]-boxlo[2])/prd[2],procgrid[2],zsplit);
} else {
igx = binary(x[0],procgrid[0],xsplit);
igy = binary(x[1],procgrid[1],ysplit);
igz = binary(x[2],procgrid[2],zsplit);
igx = utils::binary_search(x[0],procgrid[0],xsplit);
igy = utils::binary_search(x[1],procgrid[1],ysplit);
igz = utils::binary_search(x[2],procgrid[2],zsplit);
}
}
@ -802,36 +802,6 @@ int Comm::coord2proc(double *x, int &igx, int &igy, int &igz)
return grid2proc[igx][igy][igz];
}
/* ----------------------------------------------------------------------
binary search for value in N-length ascending vec
value may be outside range of vec limits
always return index from 0 to N-1 inclusive
return 0 if value < vec[0]
reutrn N-1 if value >= vec[N-1]
return index = 1 to N-2 if vec[index] <= value < vec[index+1]
------------------------------------------------------------------------- */
int Comm::binary(double value, int n, double *vec)
{
int lo = 0;
int hi = n-1;
if (value < vec[lo]) return lo;
if (value >= vec[hi]) return hi;
// insure vec[lo] <= value < vec[hi] at every iteration
// done when lo,hi are adjacent
int index = (lo+hi)/2;
while (lo < hi-1) {
if (value < vec[index]) hi = index;
else if (value >= vec[index]) lo = index;
index = (lo+hi)/2;
}
return index;
}
/* ----------------------------------------------------------------------
partition a global regular grid into one brick-shaped sub-grid per proc
if grid point is inside my sub-domain I own it,

View File

@ -98,7 +98,6 @@ class Comm : protected Pointers {
virtual void forward_comm_array(int, double **) = 0;
virtual int exchange_variable(int, double *, double *&) = 0;
int binary(double, int, double *);
// map a point to a processor, based on current decomposition

View File

@ -1284,6 +1284,28 @@ int utils::date2num(const std::string &date)
return num;
}
/* binary search in vector of ascending doubles */
int utils::binary_search(const double needle, const int n, const double *haystack)
{
int lo = 0;
int hi = n-1;
if (needle < haystack[lo]) return lo;
if (needle >= haystack[hi]) return hi;
// insure haystack[lo] <= needle < haystack[hi] at every iteration
// done when lo,hi are adjacent
int index = (lo+hi)/2;
while (lo < hi-1) {
if (needle < haystack[index]) hi = index;
else if (needle >= haystack[index]) lo = index;
index = (lo+hi)/2;
}
return index;
}
/* ----------------------------------------------------------------------
* Merge sort part 1: Loop over sublists doubling in size with each iteration.
* Pre-sort small sublists with insertion sort for better overall performance.

View File

@ -541,6 +541,23 @@ namespace utils {
int date2num(const std::string &date);
/*! Binary search in a vector of ascending doubles of length N
*
* If the value is smaller than the smallest value in the vector, 0 is returned.
* If the value is larger or equal than the largest value in the vector, N-1 is returned.
* Otherwise the index that satisfies the condition
*
* haystack[index] <= value < haystack[index+1]
*
* is returned, i.e. a value from 1 to N-2. Note that if there are tied values in the
* haystack, always the larger index is returned as only that satisfied the condition.
*
* \param needle search value for which are are looking for the closest index
* \param n size of the haystack array
* \param haystack array with data in ascending order.
* \return index of value in the haystack array smaller or equal to needle */
int binary_search(const double needle, const int n, const double *haystack);
/*! Custom merge sort implementation
*
* This function provides a custom upward hybrid merge sort

View File

@ -875,6 +875,24 @@ TEST(Utils, date2num)
ASSERT_EQ(utils::date2num("31December100"), 1001231);
}
TEST(Utils, binary_search)
{
double data[] = {-2.0, -1.8, -1.0, -1.0, -1.0, -0.5, -0.2, 0.0, 0.1, 0.1,
0.2, 0.3, 0.5, 0.5, 0.6, 0.7, 1.0, 1.2, 1.5, 2.0};
const int n = sizeof(data) / sizeof(double);
ASSERT_EQ(utils::binary_search(-5.0, n, data), 0);
ASSERT_EQ(utils::binary_search(-2.0, n, data), 0);
ASSERT_EQ(utils::binary_search(-1.9, n, data), 0);
ASSERT_EQ(utils::binary_search(-1.0, n, data), 4);
ASSERT_EQ(utils::binary_search(0.0, n, data), 7);
ASSERT_EQ(utils::binary_search(0.1, n, data), 9);
ASSERT_EQ(utils::binary_search(0.4, n, data), 11);
ASSERT_EQ(utils::binary_search(1.1, n, data), 16);
ASSERT_EQ(utils::binary_search(1.5, n, data), 18);
ASSERT_EQ(utils::binary_search(2.0, n, data), 19);
ASSERT_EQ(utils::binary_search(2.5, n, data), 19);
}
static int compare(int a, int b, void *)
{
if (a < b)