Keuin's

如何正确地编写二分查找算法

0. 编程挑战

当你看到这篇文章的时候,可以先掏出来记事本,试着完成如下原型的C语言函数:

/**
 * Given an sorted array `arr` (ascending), the lowest pos `low` and highest pos `high`,
 * find out `x` where `arr[x] == key`, in O(log(n)) time.
 * If `key` does not exist in `arr`, return -1.
 * Note: [low, high] is a valid range in array `arr`, i.e., 0 <= low <= high < arr.length.
 */
int binary_search(int arr[], int low, int high, int key) {
    // ...
}

如果你可以全程保持头脑清醒并确信你的实现是bug-free的话,可以关掉这个文章了,否则,请你继续读下去。

如果你不能完成这个算法,或者不确定你写的二分查找是否有BUG,也不用慌,因为Knuth在TAOCP中如是评论二分查找:

Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky…

1. 一种正确的实现:以OpenJDK为例

事实上,二分查找的简洁而正确的实现似乎不多。我们不妨先看一下最新的OpenJDK是如何实现Arrays.binarySearch的:

public class Arrays {
    /* ... */
    public static int binarySearch(int[] a, int key) {
        return binarySearch0(a, 0, a.length, key);
    }

    private static int binarySearch0(int[] a, int fromIndex, int toIndex,
                                     int key) {
        int low = fromIndex;
        int high = toIndex - 1;

        while (low <= high) {
            int mid = (low + high) >>> 1;
            int midVal = a[mid];

            if (midVal < key)
                low = mid + 1;
            else if (midVal > key)
                high = mid - 1;
            else
                return mid; // key found
        }
        return -(low + 1);  // key not found.
    }
    /* ... */
}

这个实现的思路是这样的:

闭区间[low, high]的初始值是[0, a.length - 1]
若闭区间[low, high]非空,则不断尝试:
    取分割点mid为low和high的算术均值  (1)
    记分割点mid处的值为midVal
    情况一:若midVal小于要找的值key,则key一定在[low, high]中割点mid右边(不含)的子区间,
        即[mid + 1, high] // [a[low], ..., midVal, ..., key, ..., a[high]]  (2)
    情况二:若midVal大于要找的值key,则key一定在[low, high]中割点mid左边(不含)的子区间,
        即[low, mid - 1]  // [a[low], ..., key, ..., midVal, ..., a[high]]  (3)
    情况三:若要找的值key等于midVal,则mid即为要寻找的下标,返回mid,算法结束
    (此处再无其他情况,当前状态必为三者之一)
此时闭区间[low, high]为空,返回“未找到”,算法结束

实际上,要理解该实现,关键是理解该算法内部的循环的循环不变量(loop invariant)

让我们分别看伪代码中的(1)、(2)和(3),指出一些常见的陷阱:

(1) 取分割点$mid$为$low$和$high$的算术均值

这条伪代码对应的Java代码是:

int mid = (low + high) >>> 1;

Java的>>>进行的是零扩展(zero extension),而>>进行的是符号扩展(sign extension)。这里的mid = (low + high) >>> 1不能换成mid = (low + high) / 2,因为后者可能因为整数溢出而使得mid变为一个负值,而无符号移位操作是零扩展的,可以保证得到的数是个正数。等价的C++写法是mid = ((unsigned)low + (unsigned)high) >> 1(C++里的>>是符号相关的)。

这一条赋值语句也可以写成mid = low + (high - low) / 2,这里的加减法永远不会溢出。

(2) (3) 缩小区间范围

这一步会将区间更新为$[mid + 1, high]$或者$[low, mid - 1]$。有人可能会觉得更新为$[mid, high]$或$[low, mid]$也是可以的,这不就是多判断了一个数字吗?实际上,这样做会导致算法可能陷入死循环:如果$low = high = 0$,那么$mid$在更新后仍然是$0$,区间将无法缩小到空区间(即循环条件$low <= high$不成立的情况)。

2. 另一种正确的实现:C++ STL

JDK的实现永远是那么的实用、易懂,而C++就不一样了。C++ STL总是能整一些花活,比如这样

template<class ForwardIt, class T>
bool binary_search(ForwardIt first, ForwardIt last, const T& value)
{
    first = std::lower_bound(first, last, value);
    return (!(first == last) && !(value < *first));
}

如果你觉得这个看起来有点难受的话,那么真正的STL实现(GNU C++ STL)就有点令人恶心了:

/**
 *  @brief Determines whether an element exists in a range.
 *  @ingroup binary_search_algorithms
 *  @param  __first   An iterator.
 *  @param  __last    Another iterator.
 *  @param  __val     The search term.
 *  @return True if @p __val (or its equivalent) is in [@p
 *  __first,@p __last ].
 *
 *  Note that this does not actually return an iterator to @p __val.  For
 *  that, use std::find or a container's specialized find member functions.
*/
template<typename _ForwardIterator, typename _Tp>
  bool
  binary_search(_ForwardIterator __first, _ForwardIterator __last, const _Tp& __val)
  {
       // concept requirements
       __glibcxx_function_requires(_ForwardIteratorConcept<_ForwardIterator>)
       __glibcxx_function_requires(_LessThanOpConcept<
       _Tp, typename iterator_traits<_ForwardIterator>::value_type>)
       __glibcxx_requires_partitioned_lower(__first, __last, __val);
       __glibcxx_requires_partitioned_upper(__first, __last, __val);
       _ForwardIterator __i
       = std::__lower_bound(__first, __last, __val,
                   __gnu_cxx::__ops::__iter_less_val());
       return __i != __last && !(__val < *__i);
  }

我在第一次看到这段代码的时候,差点直接关掉这个页面。但是理智告诉我应该去看一下std::__lower_bound的实现(不是吗?),于是我找到了这个玩意:

template<typename _ForwardIterator, typename _Tp, typename _Compare>
    _ForwardIterator
    __lower_bound(_ForwardIterator __first, _ForwardIterator __last,
        const _Tp& __val, _Compare __comp)
{
    typedef typename __mv_iter_traits<_ForwardIterator>::difference_type
        _DistanceType;

    _DistanceType __len = __last - __first;
    _DistanceType __half;
    _ForwardIterator __middle;

    while (__len > 0)
    {
        __half = __len >> 1;
        __middle = __first;
        __middle += __half;
        if (__comp(*__middle, __val))
        {
            __first = __middle;
            ++__first;
            __len = __len - __half - 1;
        }
        else
            __len = __half;
    }
    return __first;
}

好吧,这个看起来也很吓人,不过比较好的一点是,至少我们能看出他似乎是一个二分查找过程。我们不妨把他简化成一个运行在int型数组上的实现:

typedef unsigned int uint;
/**
 * @brief find the smallest x where first <= x < last and arr[x] == val. `arr` is ascending.
 * 
 * @return int -1 if val does not present in arr, otherwise the desired index.
 */
int lower_bound(int arr[], uint first, uint last, int val) {
    uint len = last - first;
    // while range [first, first + len) is valid
    while (len > 0) {
        uint half = len >> 1;
        if (arr[first + half] < val) {
            // discard first $(half + 1)$ elements
            first += half + 1;
            len -= half + 1;
        } else {
            // discard the last $half$ elements
            len = half;
        }
    }
    return first;
}

这个函数实现的功能是:在升序序列$arr$中,寻找最小的$0 <= x < arr.length$,使得$arr[x] == val$,即“下界”。

可以分析一下,当序列中不存在要找的$val$时,有两种情况:

  1. 序列的所有值均小于要找的$val$,此时返回值为$last$,即序列的长度(最后一个有效元素的下一个位置的下标)。这是一个越界的下标。比如:序列$1, 2, 3, 4, 5$,当$val = 10$时,运行结果是$5$。所以,如果返回值为$last$,则立即可以知道$arr$中不存在我们要找的$val$。
  2. 序列的所有值均大于要找的$val$,此时返回值为$first$,即第一个元素的下标。不过这是一个合法的下标,如果$arr[first] = val$,程序也可以返回$first$,但此时序列中显然存在$val$。但是我们知道,序列里的所有元素都比$val$大,因此检查这里的元素是否比$val$大即可知道序列中是否存在$val$。

所以,此两种条件可归纳为一个布尔表达式来判断序列$arr$中是否有$val$:

int ret = lower_bound(/* params */);
bool exists = !((ret == last) || (arr[ret] > val));

到此,我们可以验证C++ STL的写法是正确的。通过学习C++和Java标准库的二分查找实现,我们可以从思想上把握二分查找算法的具体实现:选定一个循环不变量,并总是让他不变。只有对算法有一个清醒的认识,才可能编写出好的算法实现。