如何正确地编写二分查找算法
0. 编程挑战
当你看到这篇文章的时候,可以先掏出来记事本,试着完成如下原型的C语言函数:
/**
* Given an sorted array `arr` (ascending), the lowest pos `low` and highest pos `high`,
* find out `x` where `arr[x] == key`, in O(log(n)) time.
* If `key` does not exist in `arr`, return -1.
* Note: [low, high] is a valid range in array `arr`, i.e., 0 <= low <= high < arr.length.
*/
int binary_search(int arr[], int low, int high, int key) {
// ...
}
如果你可以全程保持头脑清醒并确信你的实现是bug-free的话,可以关掉这个文章了,否则,请你继续读下去。
如果你不能完成这个算法,或者不确定你写的二分查找是否有BUG,也不用慌,因为Knuth在TAOCP中如是评论二分查找:
Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky…
1. 一种正确的实现:以OpenJDK为例
事实上,二分查找的简洁而正确的实现似乎不多。我们不妨先看一下最新的OpenJDK是如何实现Arrays.binarySearch的:
public class Arrays {
/* ... */
public static int binarySearch(int[] a, int key) {
return binarySearch0(a, 0, a.length, key);
}
private static int binarySearch0(int[] a, int fromIndex, int toIndex,
int key) {
int low = fromIndex;
int high = toIndex - 1;
while (low <= high) {
int mid = (low + high) >>> 1;
int midVal = a[mid];
if (midVal < key)
low = mid + 1;
else if (midVal > key)
high = mid - 1;
else
return mid; // key found
}
return -(low + 1); // key not found.
}
/* ... */
}
这个实现的思路是这样的:
闭区间[low, high]的初始值是[0, a.length - 1]
若闭区间[low, high]非空,则不断尝试:
取分割点mid为low和high的算术均值 (1)
记分割点mid处的值为midVal
情况一:若midVal小于要找的值key,则key一定在[low, high]中割点mid右边(不含)的子区间,
即[mid + 1, high] // [a[low], ..., midVal, ..., key, ..., a[high]] (2)
情况二:若midVal大于要找的值key,则key一定在[low, high]中割点mid左边(不含)的子区间,
即[low, mid - 1] // [a[low], ..., key, ..., midVal, ..., a[high]] (3)
情况三:若要找的值key等于midVal,则mid即为要寻找的下标,返回mid,算法结束
(此处再无其他情况,当前状态必为三者之一)
此时闭区间[low, high]为空,返回“未找到”,算法结束
实际上,要理解该实现,关键是理解该算法内部的循环的循环不变量(loop invariant):
- 在循环内部,闭区间$[low, high]$一定非空。
- 若$key$存在于数组$a$中,且进入循环时$key$存在于闭区间$[low, high]$内,则当前这一遍循环结束时,$key$仍然存在于闭区间$[low, high]$内。
让我们分别看伪代码中的(1)、(2)和(3),指出一些常见的陷阱:
(1) 取分割点$mid$为$low$和$high$的算术均值
这条伪代码对应的Java代码是:
int mid = (low + high) >>> 1;
Java的>>>
进行的是零扩展(zero extension),而>>
进行的是符号扩展(sign extension)。这里的mid = (low + high) >>> 1
不能换成mid = (low + high) / 2
,因为后者可能因为整数溢出而使得mid变为一个负值,而无符号移位操作是零扩展的,可以保证得到的数是个正数。等价的C++写法是mid = ((unsigned)low + (unsigned)high) >> 1
(C++里的>>
是符号相关的)。
这一条赋值语句也可以写成mid = low + (high - low) / 2
,这里的加减法永远不会溢出。
(2) (3) 缩小区间范围
这一步会将区间更新为$[mid + 1, high]$或者$[low, mid - 1]$。有人可能会觉得更新为$[mid, high]$或$[low, mid]$也是可以的,这不就是多判断了一个数字吗?实际上,这样做会导致算法可能陷入死循环:如果$low = high = 0$,那么$mid$在更新后仍然是$0$,区间将无法缩小到空区间(即循环条件$low <= high$不成立的情况)。
2. 另一种正确的实现:C++ STL
JDK的实现永远是那么的实用、易懂,而C++就不一样了。C++ STL总是能整一些花活,比如这样:
template<class ForwardIt, class T>
bool binary_search(ForwardIt first, ForwardIt last, const T& value)
{
first = std::lower_bound(first, last, value);
return (!(first == last) && !(value < *first));
}
如果你觉得这个看起来有点难受的话,那么真正的STL实现(GNU C++ STL)就有点令人恶心了:
/**
* @brief Determines whether an element exists in a range.
* @ingroup binary_search_algorithms
* @param __first An iterator.
* @param __last Another iterator.
* @param __val The search term.
* @return True if @p __val (or its equivalent) is in [@p
* __first,@p __last ].
*
* Note that this does not actually return an iterator to @p __val. For
* that, use std::find or a container's specialized find member functions.
*/
template<typename _ForwardIterator, typename _Tp>
bool
binary_search(_ForwardIterator __first, _ForwardIterator __last, const _Tp& __val)
{
// concept requirements
__glibcxx_function_requires(_ForwardIteratorConcept<_ForwardIterator>)
__glibcxx_function_requires(_LessThanOpConcept<
_Tp, typename iterator_traits<_ForwardIterator>::value_type>)
__glibcxx_requires_partitioned_lower(__first, __last, __val);
__glibcxx_requires_partitioned_upper(__first, __last, __val);
_ForwardIterator __i
= std::__lower_bound(__first, __last, __val,
__gnu_cxx::__ops::__iter_less_val());
return __i != __last && !(__val < *__i);
}
我在第一次看到这段代码的时候,差点直接关掉这个页面。但是理智告诉我应该去看一下std::__lower_bound
的实现(不是吗?),于是我找到了这个玩意:
template<typename _ForwardIterator, typename _Tp, typename _Compare>
_ForwardIterator
__lower_bound(_ForwardIterator __first, _ForwardIterator __last,
const _Tp& __val, _Compare __comp)
{
typedef typename __mv_iter_traits<_ForwardIterator>::difference_type
_DistanceType;
_DistanceType __len = __last - __first;
_DistanceType __half;
_ForwardIterator __middle;
while (__len > 0)
{
__half = __len >> 1;
__middle = __first;
__middle += __half;
if (__comp(*__middle, __val))
{
__first = __middle;
++__first;
__len = __len - __half - 1;
}
else
__len = __half;
}
return __first;
}
好吧,这个看起来也很吓人,不过比较好的一点是,至少我们能看出他似乎是一个二分查找过程。我们不妨把他简化成一个运行在int型数组上的实现:
typedef unsigned int uint;
/**
* @brief find the smallest x where first <= x < last and arr[x] == val. `arr` is ascending.
*
* @return int -1 if val does not present in arr, otherwise the desired index.
*/
int lower_bound(int arr[], uint first, uint last, int val) {
uint len = last - first;
// while range [first, first + len) is valid
while (len > 0) {
uint half = len >> 1;
if (arr[first + half] < val) {
// discard first $(half + 1)$ elements
first += half + 1;
len -= half + 1;
} else {
// discard the last $half$ elements
len = half;
}
}
return first;
}
这个函数实现的功能是:在升序序列$arr$中,寻找最小的$0 <= x < arr.length$,使得$arr[x] == val$,即“下界”。
可以分析一下,当序列中不存在要找的$val$时,有两种情况:
- 序列的所有值均小于要找的$val$,此时返回值为$last$,即序列的长度(最后一个有效元素的下一个位置的下标)。这是一个越界的下标。比如:序列$1, 2, 3, 4, 5$,当$val = 10$时,运行结果是$5$。所以,如果返回值为$last$,则立即可以知道$arr$中不存在我们要找的$val$。
- 序列的所有值均大于要找的$val$,此时返回值为$first$,即第一个元素的下标。不过这是一个合法的下标,如果$arr[first] = val$,程序也可以返回$first$,但此时序列中显然存在$val$。但是我们知道,序列里的所有元素都比$val$大,因此检查这里的元素是否比$val$大即可知道序列中是否存在$val$。
所以,此两种条件可归纳为一个布尔表达式来判断序列$arr$中是否有$val$:
int ret = lower_bound(/* params */);
bool exists = !((ret == last) || (arr[ret] > val));
到此,我们可以验证C++ STL的写法是正确的。通过学习C++和Java标准库的二分查找实现,我们可以从思想上把握二分查找算法的具体实现:选定一个循环不变量,并总是让他不变。只有对算法有一个清醒的认识,才可能编写出好的算法实现。