Fearless Binary Search

Oct 7, 2020 · 1025 words · 5 minute read · algorithm

I find it hard to write bug-free binary search code when solving LeetCode problems. This post tries to make it easier.

Basic Version

Problem (LeetCode 704):
Given a non-empty integer array sorted in ascending order, find target and return its index, or -1 if target is not in the array.

Example:

Input: nums = [-1,0,3,5,9,12], target = 9
Output: 4 (because 9 exists in nums and its index is 4)
Input: nums = [-1,0,3,5,9,12], target = 2
Output: -1 (2 does not exist in nums so return -1)

Solution:

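def search(self, nums: List[int], target: int) -> int:
    lo = 0                         # inclusive lower bound
    hi = len(nums) - 1             # inclusive upper bound

    while lo <= hi:                # search range [lo, hi] is non-empty
        mid = lo + (hi - lo) // 2
        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            lo = mid + 1           # target can only be in [mid + 1, hi]
        else:
            hi = mid - 1           # target can only be in [lo, mid - 1]

    return -1                      # range is empty: target is absent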

It’s easy. Points to note:

  • Initialize the low and high pointers to 0 and len(nums)-1. Both bounds are inclusive.
  • The loop condition is lo <= hi. Why <=? Try the simple test case nums=[1], target=1. Initially lo == hi == 0, so with lo < hi the loop would never run and we would wrongly return -1.
  • We use mid = lo + (hi - lo) // 2 instead of mid = (lo + hi) // 2 to avoid integer overflow. That matters in languages with fixed-width integers such as Java or C++. Python integers never overflow, so in Python it is just a portable habit.
  • When nums[mid] < target, the target must be in the range [mid+1, hi] (inclusive), so we can safely skip mid with lo = mid + 1. Similarly, when nums[mid] > target we set hi = mid - 1.
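
A cheap way to gain confidence in a binary search is to compare it against a linear scan on many small random inputs. Below is a minimal sketch of that idea. It rewrites the solution as a plain function (no LeetCode `Solution` class). The helper name `check_search` and the fixed seed are illustrative choices, not part of the original post. Values are drawn without repetition because LeetCode 704 assumes distinct numbers.

```python
import random
from typing import List


def search(nums: List[int], target: int) -> int:
    """Basic binary search over the inclusive range [lo, hi]."""
    lo, hi = 0, len(nums) - 1
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return -1


def check_search(trials: int = 1000) -> None:
    """Compare search() with a linear scan on random sorted arrays of distinct values."""
    rng = random.Random(0)  # fixed seed so any failure is reproducible
    for _ in range(trials):
        size = rng.randint(0, 10)
        nums = sorted(rng.sample(range(-20, 20), size))  # distinct values
        target = rng.randint(-22, 22)
        expected = nums.index(target) if target in nums else -1
        assert search(nums, target) == expected, (nums, target)


check_search()
print(search([-1, 0, 3, 5, 9, 12], 9))  # 4
print(search([-1, 0, 3, 5, 9, 12], 2))  # -1
```

Sizes start at 0 here. With hi = -1 the loop never runs, so this version also returns -1 for an empty array.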

Find First

Problem:
Given a non-empty integer array sorted in ascending order, find the index of the first number larger than target.

Example:

Input: nums = [1,3,5,7,9], target = 6
Output: 3 (because 7 is the first number larger than target and its index is 3)
Input: nums = [1,3,5,7,9], target = 9
Output: -1 (because no number is larger than target)
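
Before writing the loop, it helps to have a reference answer to test against. A minimal sketch follows, assuming Python’s standard bisect module (which the original solution does not use). bisect_right returns the number of elements <= target, and that count is exactly the index of the first number larger than target. The name `first_greater_oracle` is hypothetical.

```python
import bisect
from typing import List


def first_greater_oracle(nums: List[int], target: int) -> int:
    """Index of the first number > target, or -1 if none (reference answer)."""
    i = bisect.bisect_right(nums, target)  # count of elements <= target
    return i if i < len(nums) else -1


print(first_greater_oracle([1, 3, 5, 7, 9], 6))  # 3
print(first_greater_oracle([1, 3, 5, 7, 9], 9))  # -1
```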

Solution:

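def search(self, nums: List[int], target: int) -> int:
    lo = 0
    hi = len(nums) - 1

    while lo < hi:                 # stop when lo == hi
        mid = lo + (hi - lo) // 2  # round down
        if nums[mid] > target:
            hi = mid               # mid may be the first match; keep it
        else:
            lo = mid + 1           # mid is not a match; skip it

    if nums[lo] <= target:         # no number is larger than target
        return -1

    return lo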

This problem is to find the first item that meets a certain requirement.

1 3 5 7 9 (find first num > target 6)
F F F T T (F: num <= target, i.e. False; T: num > target, i.e. True)
      ^

You can use the example above to trace the code.
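
The F/T picture suggests that the template does not depend on the comparison itself. It only needs a predicate that is False for a prefix of the range and True for the rest. Below is a minimal sketch of that generalization. The helper `first_true` and its `pred` parameter are hypothetical names, and the predicate is assumed to be monotone (F...F T...T).

```python
from typing import Callable


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest i in [lo, hi] with pred(i) True, assuming pred looks like F...F T...T.
    Returns -1 if pred is False everywhere or the range is empty."""
    if lo > hi:
        return -1
    while lo < hi:
        mid = lo + (hi - lo) // 2  # round down so that hi = mid always shrinks the range
        if pred(mid):
            hi = mid               # mid might be the first True; keep it
        else:
            lo = mid + 1           # mid is False, so the answer lies to its right
    return lo if pred(lo) else -1


nums = [1, 3, 5, 7, 9]
print(first_true(0, len(nums) - 1, lambda i: nums[i] > 6))  # 3
print(first_true(0, len(nums) - 1, lambda i: nums[i] > 9))  # -1
```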

Explanation:

  • The loop condition is lo < hi, so when the loop ends we have lo == hi, and lo is the candidate result.
  • When nums[mid] > target, mid meets the requirement, so we set hi = mid. Why only to mid? Consider the following example:
    1 3 5 7 9 (find first num > target 2)
    F T T T T
    ^
    

    In the first iteration, mid is 2 and nums[mid] is 5, which meets the requirement. Every number after mid also meets it, so the result is in the first half and we move hi. We don’t know whether 5 is the first True, so we can only move hi to mid. (Moving hi to mid-1 would miss the result whenever 5 is the first True, e.g. with target 4.)

  • Otherwise mid does not meet the requirement, so it cannot be the result and we can safely move lo to mid+1.
  • After the loop, the result we are looking for is lo (which equals hi, given the loop condition).
  • We still need a final check (the if nums[lo] <= target line). The loop always stops at some index, even when no number meets the requirement. If the number at lo still doesn’t meet it, we return -1. Consider the following example:
    1 3 5 7 9 (find first num > target 9)
    F F F F F 
            ^ 
    

    Eventually lo is 4 (pointing at 9), which still doesn’t meet the requirement, so we return -1.
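
To see the exception case step by step, the loop can record its pointers on every iteration. Below is a minimal sketch that copies the find-first loop above and adds tracing. The function name `find_first_traced` and the (lo, hi, mid) tuple format are illustrative choices, not from the original post. Like the solution, it assumes a non-empty array.

```python
from typing import List, Tuple


def find_first_traced(nums: List[int], target: int) -> Tuple[int, List[Tuple[int, int, int]]]:
    """Same loop as the find-first solution, but records (lo, hi, mid) each iteration."""
    steps = []
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        mid = lo + (hi - lo) // 2
        steps.append((lo, hi, mid))
        if nums[mid] > target:
            hi = mid
        else:
            lo = mid + 1
    # lo == hi here, but nums[lo] may still fail the requirement
    result = -1 if nums[lo] <= target else lo
    return result, steps


result, steps = find_first_traced([1, 3, 5, 7, 9], 9)
print(steps)   # [(0, 4, 2), (3, 4, 3)]
print(result)  # -1
```

After the second iteration lo becomes 4 and the loop stops. Only the final check reveals that 9 is not larger than target.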

Find Last

Problem:
Given a non-empty integer array sorted in ascending order, find the index of the last number smaller than target.

Example:

Input: nums = [1,3,5,7,9], target = 6
Output: 2 (because 5 is the last number smaller than target and its index is 2)
Input: nums = [1,3,5,7,9], target = 1
Output: -1 (because no number is smaller than target)
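
As with “find first”, a reference answer is handy for testing. A minimal sketch follows, assuming Python’s standard bisect module (not used in the original solution). bisect_left returns the number of elements strictly smaller than target, so the last of them sits one position earlier. When there are none, the expression gives -1 by itself. The name `last_smaller_oracle` is hypothetical.

```python
import bisect
from typing import List


def last_smaller_oracle(nums: List[int], target: int) -> int:
    """Index of the last number < target, or -1 if none (reference answer)."""
    return bisect.bisect_left(nums, target) - 1  # count of elements < target, minus one


print(last_smaller_oracle([1, 3, 5, 7, 9], 6))  # 2
print(last_smaller_oracle([1, 3, 5, 7, 9], 1))  # -1
```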

Solution:

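def search(self, nums: List[int], target: int) -> int:
    lo = 0
    hi = len(nums) - 1

    while lo < hi:                     # stop when lo == hi
        mid = lo + (hi - lo + 1) // 2  # round up to avoid an infinite loop
        if nums[mid] < target:
            lo = mid                   # mid may be the last match; keep it
        else:
            hi = mid - 1               # mid is not a match; skip it

    if nums[lo] >= target:             # no number is smaller than target
        return -1

    return lo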

This problem is to find the last item that meets a certain requirement.

1 3 5 7 9 (find last num < target 6)
T T T F F 
    ^

You can use the example above to trace the code.
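
The mirror image of the predicate view applies here. We want the last True in a range that looks like T...T F...F. Below is a minimal sketch. `last_true` and `pred` are hypothetical names, and the predicate is assumed to be monotone in that direction.

```python
from typing import Callable


def last_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Largest i in [lo, hi] with pred(i) True, assuming pred looks like T...T F...F.
    Returns -1 if pred is False everywhere or the range is empty."""
    if lo > hi:
        return -1
    while lo < hi:
        mid = lo + (hi - lo + 1) // 2  # round up so that lo = mid always shrinks the range
        if pred(mid):
            lo = mid                   # mid might be the last True; keep it
        else:
            hi = mid - 1               # mid is False, so the answer lies to its left
    return lo if pred(lo) else -1


nums = [1, 3, 5, 7, 9]
print(last_true(0, len(nums) - 1, lambda i: nums[i] < 6))  # 2
print(last_true(0, len(nums) - 1, lambda i: nums[i] < 1))  # -1
```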

Explanation:

  • Why is the mid computation, mid = lo + (hi - lo + 1) // 2, different from the one in “find first”? To avoid an infinite loop. Consider the following example:
    1 3 (find last num < target 2)
    T F
    

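    Without the +1, mid rounds down to lo. Here mid = 0 and nums[0] = 1 < 2, so we set lo = mid = 0, the range never shrinks, and the loop runs forever. Rounding mid up guarantees that lo = mid always makes progress. So always verify your code with this two-item test case (or remember the special handling for “find last”).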

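  • Now when mid meets the requirement, the result is in the second half and we need to move lo. We don’t know whether mid is the last number that meets the requirement, so we can only move lo to mid.
  • Otherwise mid cannot be the result, so we set hi = mid - 1. The final check mirrors “find first”: if nums[lo] still doesn’t meet the requirement (it is not smaller than target), return -1.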
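
The infinite loop is easy to reproduce safely if the loop is given an iteration cap. Below is a minimal sketch that runs the find-last loop with either rounding and reports whether it stalls. The function name `find_last_stalls` and the `max_iters` cap are hypothetical and exist only for this demonstration.

```python
from typing import List


def find_last_stalls(nums: List[int], target: int, round_up: bool, max_iters: int = 100) -> bool:
    """Run the find-last loop and return True if it has not finished after max_iters."""
    lo, hi = 0, len(nums) - 1
    for _ in range(max_iters):
        if lo >= hi:
            return False                   # loop finished normally
        if round_up:
            mid = lo + (hi - lo + 1) // 2  # the "find last" formula
        else:
            mid = lo + (hi - lo) // 2      # the "find first" formula
        if nums[mid] < target:
            lo = mid
        else:
            hi = mid - 1
    return True                            # range stopped shrinking


print(find_last_stalls([1, 3], 2, round_up=False))  # True
print(find_last_stalls([1, 3], 2, round_up=True))   # False
```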

Summary

There are many patterns/templates for implementing binary search correctly. This post introduces the one I like best.
The high-level structure is simple: just two pointers plus a while loop. Pay attention to the loop condition, to how the pointers move in each case, and to whether mid rounds down or up. Finally, don’t forget to check edge cases.
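
One way to check the edge cases systematically is a small randomized harness. It compares the two templates against the bisect module on tiny arrays with many duplicates, where off-by-one mistakes tend to show up. Below is a minimal sketch. The functions are the “find first” and “find last” solutions rewritten as plain functions. `check_edge_cases`, the seed, and the value ranges are illustrative choices. Array sizes start at 1 because both templates assume a non-empty array.

```python
import bisect
import random
from typing import List


def find_first(nums: List[int], target: int) -> int:
    """First index with nums[i] > target, or -1 (the "find first" template)."""
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if nums[mid] > target:
            hi = mid
        else:
            lo = mid + 1
    return -1 if nums[lo] <= target else lo


def find_last(nums: List[int], target: int) -> int:
    """Last index with nums[i] < target, or -1 (the "find last" template)."""
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        mid = lo + (hi - lo + 1) // 2
        if nums[mid] < target:
            lo = mid
        else:
            hi = mid - 1
    return -1 if nums[lo] >= target else lo


def check_edge_cases(trials: int = 2000) -> None:
    rng = random.Random(42)  # fixed seed so failures are reproducible
    for _ in range(trials):
        # Sizes 1 to 6 drawn from only 5 values: lots of one- and two-item cases and duplicates.
        nums = sorted(rng.choices(range(5), k=rng.randint(1, 6)))
        target = rng.randint(-1, 5)
        i = bisect.bisect_right(nums, target)
        assert find_first(nums, target) == (i if i < len(nums) else -1), (nums, target)
        assert find_last(nums, target) == bisect.bisect_left(nums, target) - 1, (nums, target)


check_edge_cases()
print("all checks passed")
```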
