
Using binary search to find the duplicate number in an array


The problem:

You are given an array of integers nums containing n + 1 integers, where each integer is in the range [1, n] inclusive.

There is only one repeated number in nums; return this repeated number.

You must solve the problem without modifying the array nums and using only constant extra space.

Here is one possible solution using binary search:

class Solution(object):
    def findDuplicate(self, nums):
        beg, end = 1, len(nums)-1
        
        while beg + 1 <= end:
            mid, count = (beg + end)//2, 0
            for num in nums:
                if num <= mid: count += 1        
            if count <= mid:
                beg = mid + 1
            else:
                end = mid
        return end

Example 1:

Input: nums = [1,3,4,2,2]
Output: 2

Example 2:

Input: nums = [3,1,3,4,2]
Output: 3

Can someone please explain this solution to me? I understand the code, but I don't understand the logic behind it. In particular, I do not understand how to construct the if statements (lines 7 - 13). Why, and how do you know, that when num <= mid I need to do count += 1, and so on? Many thanks.


Solution

  • The solution keeps halving the range of numbers the answer can still be in.

    For example, if the function starts with nums == [1, 3, 4, 2, 2], then the duplicate number must be between 1 and 4 inclusive by definition.

    By counting how many of the numbers are less than or equal to the middle of that range (2), you can decide whether the duplicate must be in the lower or the upper half. Without a duplicate in the lower half, at most 2 numbers could be less than or equal to 2 (namely 1 and 2). Here 3 numbers are, and 3 > 2, so one of them must be repeated and the duplicate is in the lower half.

    Repeating the process with the range 1 to 2 inclusive: only 1 number is less than or equal to the middle of that range (1), so the duplicate is not in the lower half. It must be in the upper half, which leaves 2.

    Consider a slightly longer series: [1, 2, 5, 6, 3, 4, 3, 7]. The middle of 1 to 7 is 4, and 5 numbers are less than or equal to 4, so the duplicate must be between 1 and 4. The middle of 1 to 4 is 2, and only 2 numbers are less than or equal to 2, so the duplicate must be greater than 2, that is, 3 or 4. The middle of 3 to 4 is 3, and 4 numbers are less than or equal to 3, so the duplicate is 3.
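
    The walk-throughs above can be reproduced with a traced version of the question's loop. This is only a sketch: find_duplicate_traced is a hypothetical name, and the loop condition is written as beg < end, which is equivalent to beg + 1 <= end for integers.

```python
def find_duplicate_traced(nums):
    """Same binary search as in the question, but also records each step."""
    beg, end = 1, len(nums) - 1
    steps = []  # one (beg, end, mid, count) tuple per pass over nums
    while beg < end:
        mid = (beg + end) // 2
        # How many values fall in the lower part of the range, 1..mid?
        count = sum(1 for num in nums if num <= mid)
        steps.append((beg, end, mid, count))
        if count <= mid:
            # No surplus in 1..mid, so the duplicate is above mid.
            beg = mid + 1
        else:
            # More than mid values fit in 1..mid, so one is repeated.
            end = mid
    return end, steps


for nums in ([1, 3, 4, 2, 2], [1, 2, 5, 6, 3, 4, 3, 7]):
    answer, steps = find_duplicate_traced(nums)
    for beg, end, mid, count in steps:
        half = "upper" if count <= mid else "lower"
        print(f"range {beg}..{end}: {count} values <= {mid}, go {half}")
    print("duplicate:", answer)
```

    Each printed line corresponds to one pass of the while loop, so the if statement is simply comparing "how many values fit in 1..mid" against "how many distinct values 1..mid can hold".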

    The solution iterates over all n + 1 elements of nums about log2(n) times, since each pass halves the search space. That gives roughly O(n log n) time with constant extra space, which is certainly more efficient than the naive O(n^2) approach:

        def findDuplicate(self, nums):
            for i, n in enumerate(nums):
                for j, m in enumerate(nums):
                    if i != j and n == m:
                        return n
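
    To put rough numbers on that comparison, here is a small sketch that counts how many elements each approach looks at. The function names and the test input are made up for illustration, and the input is chosen so that the duplicate sits at the very end, which is the worst case for the nested loops.

```python
import math


def binary_search_inspections(nums):
    """Binary search from the question, counting elements looked at."""
    beg, end = 1, len(nums) - 1
    inspections = 0
    while beg < end:
        mid = (beg + end) // 2
        count = sum(1 for num in nums if num <= mid)
        inspections += len(nums)  # one full pass over nums per iteration
        if count <= mid:
            beg = mid + 1
        else:
            end = mid
    return end, inspections


def naive_inspections(nums):
    """Nested-loop search, counting how many pairs are compared."""
    inspections = 0
    for i, n in enumerate(nums):
        for j, m in enumerate(nums):
            inspections += 1
            if i != j and n == m:
                return n, inspections
    return None, inspections


# Hypothetical input: 1..999 followed by a second 999.
nums = list(range(1, 1000)) + [999]
print(binary_search_inspections(nums))
print(naive_inspections(nums))
# The binary search needs at most ceil(log2(n)) passes over nums.
print(math.ceil(math.log2(len(nums) - 1)))
```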
    

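    But as user @fas suggests in the comments, a single pass with XOR is faster. Note that it only works if every value from 1 to n appears in nums and the duplicate appears exactly twice; the binary search does not need that assumption:

        def findDuplicate(nums):
            # p: a power of two (at least 4) that is >= len(nums),
            # so that the XOR of 0..p-1 is 0
            p = 4
            while p < len(nums):
                p <<= 1
            r = 0
            # XOR of all values: 1..n plus one extra copy of the duplicate
            for n in nums:
                r ^= n
            # XOR in n+1..p-1, so everything except the duplicate cancels
            for n in range(len(nums), p):
                r ^= n
            return r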
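
    To see where the two approaches differ, here is a side-by-side sketch. The names are hypothetical, and the XOR version is a simplified variant that XORs against 1..n directly instead of padding up to a power of two; the cancellation idea is the same.

```python
from functools import reduce
from operator import xor


def find_duplicate_binary(nums):
    """Binary search on the value range, as in the question."""
    beg, end = 1, len(nums) - 1
    while beg < end:
        mid = (beg + end) // 2
        count = sum(1 for num in nums if num <= mid)
        if count <= mid:
            beg = mid + 1
        else:
            end = mid
    return end


def find_duplicate_xor(nums):
    """XOR all values with 1..n; every value seen once cancels out.

    Assumption: nums is exactly 1..n plus one extra copy of one value.
    """
    n = len(nums) - 1
    return reduce(xor, nums, 0) ^ reduce(xor, range(1, n + 1), 0)


once_repeated = [3, 1, 3, 4, 2]   # 1..4 plus a second 3
many_repeated = [2, 2, 2, 2, 2]   # 2 repeated; 1, 3 and 4 are missing

print(find_duplicate_binary(once_repeated), find_duplicate_xor(once_repeated))
print(find_duplicate_binary(many_repeated), find_duplicate_xor(many_repeated))
```

    On the first input both functions agree. On the second, where the repeated value pushes other values out of the array, only the binary search still returns the repeated number, so which one is preferable depends on how strictly the input follows the "1..n plus one extra" shape.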