Tags: algorithm, multidimensional-array, binary-search

If you're doing binary search on a 2D array, how does matrix[mid/n][mid%n] give you the middle value?


I'm trying to understand the following solution from LeetCode:

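```python
def searchMatrix(self, matrix, target):
    # Treat the m x n matrix as one flat, sorted list of m * n elements.
    if not matrix or not matrix[0]:
        return False
    n = len(matrix[0])                  # number of columns
    lo, hi = 0, len(matrix) * n         # search the half-open range [lo, hi)
    while lo < hi:
        mid = (lo + hi) // 2            # integer division (what `/` did on ints in Python 2)
        x = matrix[mid // n][mid % n]   # flat index -> (row, column)
        if x < target:
            lo = mid + 1
        elif x > target:
            hi = mid
        else:
            return True
    return False
```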

How does matrix[mid/n][mid%n] give you the middle value?


Solution

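  • mid is an index into the m*n matrix viewed as one flat list, with the rows laid end to end (row-major order). To read the element at that position, you have to convert the flat index back into a row index and a column index. What you see is the standard conversion for a flat index q in a matrix with n columns: row = q // n (integer division), col = q % n. The binary search itself is valid only because the matrix is sorted so that reading its rows left to right, top to bottom gives one ascending sequence.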

    It may help to view this with n = 10; in that case, mid is just an ordinary base-ten number: the tens digit is the row and the ones digit is the column (single-digit values such as 7 sit in row 0). Visualize:

     0  1  2  3  4  5  6  7  8  9
    10 11 12 13 14 15 16 17 18 19
    20 21 22 23 24 25 26 27 28 29
    30 31 32 ...
    

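    Do you see how that works? The tens digit (row = mid // 10) and the ones digit (col = mid % 10) form the indices into the 10-column matrix; for example, mid = 23 gives row 2, column 3. For any other column count n, the same arithmetic applies with n in place of 10: row = mid // n, col = mid % n.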
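A minimal Python 3 sketch of the conversion, assuming a row-major matrix with n columns whose rows, read end to end, form one ascending sequence. The helper names `to_row_col`, `to_flat`, and `search_matrix` are hypothetical, chosen for illustration. `search_matrix` mirrors the loop in the question but uses `//`, since in Python 3 `/` always returns a float, which cannot be used as a list index.

```python
def to_row_col(q, n):
    """Map flat row-major index q to (row, col) in a matrix with n columns."""
    return q // n, q % n  # equivalent to divmod(q, n)


def to_flat(row, col, n):
    """Inverse mapping: (row, col) back to the flat row-major index."""
    return row * n + col


def search_matrix(matrix, target):
    """Binary search over flat indices; assumes rows read end to end are ascending."""
    if not matrix or not matrix[0]:
        return False
    n = len(matrix[0])
    lo, hi = 0, len(matrix) * n  # half-open range [lo, hi) of flat indices
    while lo < hi:
        mid = (lo + hi) // 2
        row, col = to_row_col(mid, n)
        x = matrix[row][col]
        if x < target:
            lo = mid + 1
        elif x > target:
            hi = mid
        else:
            return True
    return False


# With n = 10, the flat index is just the two-digit number from the grid above.
grid = [[10 * r + c for c in range(10)] for r in range(4)]
print(to_row_col(23, 10))  # (2, 3)
print(grid[2][3])          # 23

# Any column count works the same way, e.g. a 3 x 4 matrix.
sample = [[1, 3, 5, 7],
          [10, 11, 16, 20],
          [23, 30, 34, 60]]
print(to_row_col(6, 4))          # (1, 2), and sample[1][2] is 16
print(search_matrix(sample, 16))  # True
print(search_matrix(sample, 13))  # False

# Round trip: every flat index maps to a unique (row, col) and back again.
assert all(to_flat(*to_row_col(q, 4), 4) == q for q in range(12))
```

Because the flat index and the (row, col) pair are interchangeable, the search never has to build the flattened list; it only does the // and % arithmetic on the single index it is probing.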