Tags: python, arrays, algorithm, performance, optimization

Optimal way of counting the number of non-overlapping pairs given a list of intervals


I'm trying to count the number of non-overlapping pairs given a list of intervals.

For example:

[(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]

There are 8 pairs:

((1, 8), (11, 13))
((1, 8), (13, 14))
((1, 8), (9, 15))
((7, 9), (11, 13))
((7, 9), (13, 14))
((3, 10), (11, 13))
((3, 10), (13, 14))
((7, 12), (13, 14))
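The listed pairs follow from a strict reading of "non-overlapping": a pair counts only when one interval ends strictly before the other starts, so touching intervals such as (7, 9) and (9, 15), or (11, 13) and (13, 14), count as overlapping. The small enumeration sketch below assumes that strict reading (the helper name `non_overlapping_pairs` is hypothetical) and reproduces the list above in the same order:

```python
from itertools import combinations

def non_overlapping_pairs(intervals):
    """Yield each unordered pair where one interval ends strictly before the other starts."""
    for a, b in combinations(intervals, 2):
        # Touching endpoints (a ends exactly where b starts) count as overlapping.
        if a[1] < b[0] or b[1] < a[0]:
            yield (a, b)

intervals = [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]
pairs = list(non_overlapping_pairs(intervals))
print(len(pairs))  # 8
```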

I can't seem to figure out a better solution than brute-forcing it by comparing every interval with virtually every other one, resulting in an O(n^2) solution.

def count_non_overlapping_pairs(intervals):
    intervals = list(set(intervals))  # deduplicate any intervals
    intervals.sort(key=lambda x: x[1])  # sort by end point
    pairs = 0
    for i in range(len(intervals)):
        for j in range(i+1, len(intervals)):
            # intervals[j] ends no earlier than intervals[i], so only
            # "i lies entirely before j" needs to be checked
            if intervals[i][1] < intervals[j][0]:
                pairs += 1
    return pairs

Is there a more optimal solution than this?


Solution

  • Sort the intervals, once by start point, once by end point.

    Now, for a given interval, binary-search its start point in the intervals sorted by end point. The resulting index tells you how many non-overlapping intervals come before it: all intervals that end strictly before your interval starts are non-overlapping.
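For example, with the question's data the sorted end points are [8, 9, 10, 12, 13, 14, 15]. Python's `bisect_left` returns how many elements are strictly smaller than the search key, which matches the assumption (taken from the question's example) that touching intervals overlap. A minimal sketch of the lookup for (13, 14):

```python
from bisect import bisect_left

intervals = [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]
ends = sorted(end for _, end in intervals)  # [8, 9, 10, 12, 13, 14, 15]

# (13, 14) starts at 13; the ends 8, 9, 10 and 12 are strictly smaller,
# so four intervals finish before it. (11, 13) ends at 13 and is excluded.
before = bisect_left(ends, 13)
print(before)  # 4
```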

    Do the same for the end point: binary-search it in the intervals sorted by start point. All intervals that start strictly after your interval ends are non-overlapping.
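The mirror-image lookup counts the start points strictly greater than an interval's end. `bisect_right` returns how many starts are less than or equal to the end, so subtracting it from the total leaves the intervals that start after it (again assuming touching intervals overlap). A sketch for (1, 8):

```python
from bisect import bisect_right

intervals = [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]
starts = sorted(start for start, _ in intervals)  # [1, 3, 7, 7, 9, 11, 13]

# (1, 8) ends at 8; the starts 9, 11 and 13 lie strictly after it.
after = len(starts) - bisect_right(starts, 8)
print(after)  # 3
```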

    Every other interval overlaps yours: it starts no later than your interval ends and ends no earlier than your interval starts.

    Do this for every interval and sum the results. Each non-overlapping pair is counted twice, once from each of its two intervals, so halve the sum.

    Overall you get O(n log n): two sorts, plus two O(log n) binary searches for each of the n intervals.
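A minimal sketch of this two-sort version, assuming intervals are (start, end) tuples with start <= end and that touching intervals count as overlapping. The function name `count_pairs_two_sorts` is hypothetical:

```python
from bisect import bisect_left, bisect_right

def count_pairs_two_sorts(intervals):
    """Count non-overlapping pairs by counting, for every interval,
    the intervals entirely before it and entirely after it."""
    starts = sorted(start for start, _ in intervals)
    ends = sorted(end for _, end in intervals)
    n = len(intervals)
    total = 0
    for start, end in intervals:
        # Intervals that end strictly before this one starts.
        total += bisect_left(ends, start)
        # Intervals that start strictly after this one ends.
        total += n - bisect_right(starts, end)
    # Every pair was counted once from each side.
    return total // 2

print(count_pairs_two_sorts(
    [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]))  # 8
```

An interval never counts itself: its own end is not smaller than its start, and its own start is not larger than its end.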

    Now observe that half of this work is not even needed: if A and B are two non-overlapping intervals with A before B, it suffices for B to count the intervals before it, including A; A doesn't need to count the intervals after it. This simplifies the solution further: you only need to sort the end points to count the intervals before each interval, and the resulting sum no longer needs to be halved:

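    # Counting the intervals before suffices
    from bisect import bisect_left

    def count_non_overlapping_pairs(intervals):
        # Sorted end points of all intervals
        ends = sorted(interval[1] for interval in intervals)
        def count_before(interval):
            # Number of intervals that end strictly before this one starts
            return bisect_left(ends, interval[0])
        return sum(map(count_before, intervals))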
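One way to gain confidence in the simplified function is to compare it with a direct pairwise count on random inputs. The sketch below assumes integer endpoints and the strict (touching means overlapping) definition; `brute_force_count` is a hypothetical helper, and the simplified function is repeated so the snippet runs on its own. The reference deliberately skips the question's deduplication step, because the bisect version counts duplicate intervals separately.

```python
import random
from bisect import bisect_left
from itertools import combinations

def count_non_overlapping_pairs(intervals):
    # Simplified O(n log n) version: count the ends strictly before each start.
    ends = sorted(interval[1] for interval in intervals)
    return sum(bisect_left(ends, start) for start, _ in intervals)

def brute_force_count(intervals):
    # O(n^2) reference: check every unordered pair directly.
    return sum(
        1
        for a, b in combinations(intervals, 2)
        if a[1] < b[0] or b[1] < a[0]
    )

random.seed(0)
for _ in range(1000):
    intervals = []
    for _ in range(random.randint(0, 20)):
        start = random.randint(0, 30)
        intervals.append((start, start + random.randint(0, 10)))
    assert count_non_overlapping_pairs(intervals) == brute_force_count(intervals)
print("all checks passed")
```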
    

    (Symmetrically, you could also just count the intervals after an interval.)
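A sketch of that symmetric variant, assuming the same strict definition: sort only the start points and, for each interval, count the starts strictly after its end. The name `count_non_overlapping_pairs_after` is hypothetical:

```python
from bisect import bisect_right

def count_non_overlapping_pairs_after(intervals):
    """Count pairs by counting, for each interval, the intervals that start
    strictly after it ends (touching intervals count as overlapping)."""
    starts = sorted(start for start, _ in intervals)
    n = len(starts)
    # bisect_right gives the number of starts <= end; the rest start after it.
    return sum(n - bisect_right(starts, end) for _, end in intervals)

print(count_non_overlapping_pairs_after(
    [(1, 8), (7, 9), (3, 10), (7, 12), (11, 13), (13, 14), (9, 15)]))  # 8
```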