
Counting inversions in an array of 2D pairs


Problem Description:
Let there be an array of 2D pairs $((x_1, y_1), \dots, (x_n, y_n))$. With a fixed constant $y'$, a pair $(i, j)$ is called half-inverted if $i < j$, $x_i > x_j$, and $y_i \ge y' > y_j$. Devise an algorithm that counts the number of half-inverted pairs. You will get full marks if your algorithm is correct and has complexity no more than $O(n \log n)$.

My idea is to treat this with a method similar to counting inversions in a normal array, but my problem is: how do we maintain the order during the Merge-and-Count step?
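To pin down the definition before optimizing, here is a direct O(n²) check that simply tests every pair. It is only a reference for testing, not an answer to the O(n log n) requirement; the name `count_half_inversions_brute` is illustrative, and the parameter `y_fixed` stands for $y'$.

```python
from typing import List, Tuple


def count_half_inversions_brute(pairs: List[Tuple[int, int]], y_fixed: int) -> int:
    """O(n^2) reference: count (i, j) with i < j, x_i > x_j and y_i >= y' > y_j."""
    count = 0
    n = len(pairs)
    for i in range(n):
        xi, yi = pairs[i]
        for j in range(i + 1, n):
            xj, yj = pairs[j]
            # chained comparison: yi >= y_fixed and y_fixed > yj
            if xi > xj and yi >= y_fixed > yj:
                count += 1
    return count


print(count_half_inversions_brute(
    [(1, 1), (6, 4), (6, 3), (1, 2), (1, 2), (3, 3), (6, 2), (0, 1)], 2))  # prints 6
```

On this input only the last element (0, 1) has $y < 2$ with earlier elements before it, and the six elements just before it all have a larger x and $y \ge 2$, which gives 6.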


Solution

  • This is a simple modification of the familiar merge-sort inversion-counting algorithm, so make sure you fully understand that algorithm first.

    In the merge step of that algorithm we have two sorted halves and two pointers, one into each half. Let the left pointer be i and the right pointer j. Under the traditional definition of an inversion, if the value at i is larger than the value at j, then, because each half is sorted and every element of the left half comes before every element of the right half in the original array, every element from i to the end of the left half forms an inversion with the value at j. So we increase our count by mid - i, where mid is the (exclusive) end of the left half.
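For reference, a minimal sketch of that classic merge-sort inversion count on a plain list of numbers, written with list slices for readability rather than with in-place index ranges. The helper name `count_inversions` is illustrative; the `len(left) - i` term plays the role of `mid - i` in the description above.

```python
from typing import List, Tuple


def count_inversions(a: List[int]) -> Tuple[List[int], int]:
    """Return (sorted copy of a, number of pairs i < j with a[i] > a[j])."""
    if len(a) <= 1:
        return list(a), 0
    mid = len(a) // 2
    left, inv_left = count_inversions(a[:mid])
    right, inv_right = count_inversions(a[mid:])
    merged: List[int] = []
    inv_split = 0
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] > right[j]:
            # left is sorted, so every remaining left element (left[i:]) is > right[j]
            inv_split += len(left) - i
            merged.append(right[j])
            j += 1
        else:
            merged.append(left[i])
            i += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, inv_left + inv_right + inv_split


print(count_inversions([3, 1, 2]))  # prints ([1, 2, 3], 2)
```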

    Switching back to your problem, we are dealing with pairs (x, y). If we keep the x values sorted, then with the approach above we can count the inversions that consider only the x values. By your definition of a half inversion, however, counting every pair with $x_i > x_j$ would overcount, because it ignores the additional constraint $y_i \ge y' > y_j$, which must be used to filter the count.

    So, going back to the traditional algorithm: when the value at i is greater than the value at j, we also need to check that the y value at j is less than y'. If it is not, then none of the elements from i to mid form a half inversion with j, and we cannot count them. Now assume the y at j is smaller than y'. If we simply counted all the elements from i to mid, we would still overcount, because some of them may have $y_i < y'$.
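To see the overcounting concretely, this brute-force sketch applies the filters one at a time to the answer's test case with $y' = 2$. The helper `count_with_filters` and its two flags are illustrative only; it is O(n²) and meant just to show which pairs each condition removes.

```python
from typing import List, Tuple

Pair = Tuple[int, int]


def count_with_filters(pairs: List[Pair], y_fixed: int,
                       check_yj: bool, check_yi: bool) -> int:
    """Brute-force count of pairs i < j with x_i > x_j, plus optional y filters."""
    count = 0
    n = len(pairs)
    for i in range(n):
        for j in range(i + 1, n):
            (xi, yi), (xj, yj) = pairs[i], pairs[j]
            if xi <= xj:
                continue  # not an inversion on x
            if check_yj and not yj < y_fixed:
                continue  # right element fails y_j < y'
            if check_yi and not yi >= y_fixed:
                continue  # left element fails y_i >= y'
            count += 1
    return count


data = [(1, 1), (6, 4), (6, 3), (1, 2), (1, 2), (3, 3), (6, 2), (0, 1)]
print(count_with_filters(data, 2, False, False))  # x only: 13
print(count_with_filters(data, 2, True, False))   # x and y_j < y': 7
print(count_with_filters(data, 2, True, True))    # full definition: 6
```

The one pair that survives the $y_j < y'$ check but is not half-inverted is the first element (1, 1) against the last element (0, 1): its $y_i = 1$ is below $y'$, which is exactly why the $y_i \ge y'$ condition is still needed.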

    One way to fix this is to count only those elements from i to mid in the left half whose y is >= y', and add that number to our count. During the merge step we can track how many y >= y' we have already passed over up to i, and subtract that from the total number of y >= y' in the left half. To get that total, the recursive function can also return it (total = left + right), and when merging we use only the number that came from the left half. The base case needs a small change too: a single pair contributes 1 to the total if its y >= y', and 0 otherwise.
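A sketch of just this merge step in isolation, written over two separate x-sorted lists rather than the in-place index ranges used in the code below. The names `merge_and_count` and `remaining_high` are illustrative; `remaining_high` is the running "total minus the ones already passed" described above.

```python
from typing import List, Tuple

Pair = Tuple[int, int]


def merge_and_count(left: List[Pair], right: List[Pair],
                    y_fixed: int) -> Tuple[List[Pair], int]:
    """Merge two x-sorted halves; count half-inverted pairs with i in left, j in right."""
    # left elements not yet output whose y >= y'
    remaining_high = sum(1 for _, y in left if y >= y_fixed)
    merged: List[Pair] = []
    count = 0
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i][0] > right[j][0]:
            # all of left[i:] have x > right[j]'s x; if right[j]'s y < y',
            # exactly the remaining_high of them with y >= y' are half-inverted
            if right[j][1] < y_fixed:
                count += remaining_high
            merged.append(right[j])
            j += 1
        else:
            # left[i] leaves the window left[i:], so drop it from the tally
            if left[i][1] >= y_fixed:
                remaining_high -= 1
            merged.append(left[i])
            i += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, count


print(merge_and_count([(1, 3), (4, 1), (5, 2)], [(2, 0), (6, 1)], 2))
# prints ([(1, 3), (2, 0), (4, 1), (5, 2), (6, 1)], 1)
```

Here the only counted pair is (5, 2) with (2, 0): (4, 1) also has a larger x than (2, 0) but fails $y_i \ge 2$.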

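    def count_half_inversions(l, y):
        # sort a copy so the caller's list is not reordered as a side effect
        data = list(l)
        if not data:
            # an empty input has no pairs (and count_rec would index data[0])
            return 0
        return count_rec(data, 0, len(data), data.copy(), y)[0]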
    
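    def count_rec(l, begin, end, copy, y):
        # Sorts l[begin:end] by x and returns
        # (half-inversion count, number of pairs in [begin, end) with y >= y').
        if end-begin <= 1:
            # a single pair: no inversions; it counts toward the total if its y >= y'
            return (0, 1 if l[begin][1] >= y else 0)
        mid = begin + ((end-begin) // 2)
        # sort each half into `copy` (the two buffers swap roles at each level),
        # so the merge below reads from `copy` and writes into `l`
        left = count_rec(copy, begin, mid, l, y)
        right = count_rec(copy, mid, end, l, y)
        between = merge_count(l, begin, mid, end, copy, left[1], y)
        return (left[0] + right[0] + between, left[1] + right[1])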
    
    def merge_count(l, begin, mid, end, copy, left_y_count, y):
        result = 0
        i,j = begin, mid
        k = begin
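        while i < mid and j < end:
            if copy[i][0] > copy[j][0]:
                # every left pair from i to mid has a larger x than the pair at j;
                # if y_j < y', those with y >= y' (left_y_count) are half-inverted
                if y > copy[j][1]:
                    result += left_y_count
                smaller = copy[j]
                j += 1
            else:
                # the pair at i leaves the range i..mid, so drop it from the tally
                if copy[i][1] >= y:
                    left_y_count -= 1
                smaller = copy[i]
                i += 1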
            l[k] = smaller
            k += 1
        while i < mid:
            l[k] = copy[i]
            i += 1
            k += 1
        while j < end:
            l[k] = copy[j]
            j += 1
            k += 1
        return result
                    
    test_case = [(1,1), (6,4), (6,3), (1,2), (1,2), (3,3), (6,2), (0,1)]
    fixed_y = 2
    
    print(count_half_inversions(test_case, fixed_y))  # prints 6
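As a sanity check, the sketch below re-expresses the same idea with list slices (each call returns its sorted sublist instead of using the ping-pong `l`/`copy` buffers) and compares it against an O(n²) brute force on random inputs, including the empty list. The names `half_inversions` and `brute_force` are illustrative; the slice-based version allocates more memory but is still O(n log n).

```python
import random
from typing import List, Tuple

Pair = Tuple[int, int]


def half_inversions(pairs: List[Pair], y_fixed: int) -> Tuple[List[Pair], int, int]:
    """Return (pairs sorted by x, half-inversion count, number of pairs with y >= y')."""
    if len(pairs) <= 1:
        high = sum(1 for _, y in pairs if y >= y_fixed)
        return list(pairs), 0, high
    mid = len(pairs) // 2
    left, count_left, high_left = half_inversions(pairs[:mid], y_fixed)
    right, count_right, high_right = half_inversions(pairs[mid:], y_fixed)
    merged: List[Pair] = []
    count = 0
    remaining_high = high_left  # unmerged left pairs with y >= y'
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i][0] > right[j][0]:
            if right[j][1] < y_fixed:
                count += remaining_high
            merged.append(right[j])
            j += 1
        else:
            if left[i][1] >= y_fixed:
                remaining_high -= 1
            merged.append(left[i])
            i += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, count_left + count_right + count, high_left + high_right


def brute_force(pairs: List[Pair], y_fixed: int) -> int:
    n = len(pairs)
    return sum(1 for i in range(n) for j in range(i + 1, n)
               if pairs[i][0] > pairs[j][0] and pairs[i][1] >= y_fixed > pairs[j][1])


random.seed(0)
for _ in range(500):
    n = random.randint(0, 12)
    pairs = [(random.randint(0, 5), random.randint(0, 5)) for _ in range(n)]
    y_fixed = random.randint(0, 6)
    assert half_inversions(pairs, y_fixed)[1] == brute_force(pairs, y_fixed)
print("all random checks passed")
```

Small value ranges (0 to 5) are chosen on purpose so that ties in x and y, the cases most likely to expose an off-by-one in the `>` versus `>=` checks, come up often.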