Tags: algorithm, optimization, time-complexity, scheduling, integer-programming

Find optimal points to cut a set of intervals


Given a set of intervals on the real line and some parameter d > 0, find a sequence of points whose gaps between neighbors are less than or equal to d, such that the number of intervals that contain any of the points is minimized. To prevent trivial solutions, we require that the first point of the sequence lies before the first interval and the last point lies after the last interval. The intervals can be thought of as right-open.

Does this problem have a name? Maybe even an algorithm and a complexity bound?

Some background: This is motivated by a question from topological data analysis, but it seems general enough to be interesting for other topics, e.g. task scheduling (a factory has to shut down at least once a year and wants to minimize the number of tasks affected by the maintenance). We were thinking of integer programming and minimum cuts, but the d-parameter does not quite fit. We also implemented approximate greedy solutions in O(n^2) and O(n log n) time, but they can run into very bad local optima.

Show me a picture

I draw intervals as lines. The following diagram shows 7 intervals. d is such that you have to cut at least every fourth character. At the bottom of the diagram you see two solutions (marked with x and y). x cuts through the four intervals at the top, whereas y cuts through the three intervals at the bottom. y is optimal.

 ——— ———
 ——— ———
   ———
   ———
   ———
x x   x x
y   y   y
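
To make the picture concrete, here is a minimal sketch that encodes the diagram in character columns and counts how many intervals each solution cuts. The helper names count_cut and max_gap, and the choice of column indices as coordinates, are assumptions made for illustration only.

```python
def count_cut(intervals, points):
    """Count right-open intervals [a, b) that contain at least one point."""
    return sum(any(a <= p < b for p in points) for a, b in intervals)


def max_gap(points):
    """Largest distance between neighboring points of a sorted sequence."""
    return max(q - p for p, q in zip(points, points[1:]))


# Column indices of the drawing: each bar spans three character cells,
# so a bar starting in column 1 is the right-open interval [1, 4).
top = [(1, 4), (5, 8), (1, 4), (5, 8)]   # two rows with two bars each
bottom = [(3, 6), (3, 6), (3, 6)]        # three rows with one bar each
diagram = top + bottom

d = 4                 # "cut at least every fourth character"
x = [0, 2, 6, 8]      # columns of the x marks
y = [0, 4, 8]         # columns of the y marks

print(count_cut(diagram, x), max_gap(x))  # 4 4
print(count_cut(diagram, y), max_gap(y))  # 3 4
```

Both sequences respect the gap limit, but y only hits the three bars in the middle, which matches the claim that y is the better of the two.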

Show me some code: How should we define fun in the following snippet?

intervals = [(0, 1), (0.5, 1.5), (0.5, 1.5)]
d = 1.1
fun(intervals, d)
# expected: [-0.55, 0.45, 1.55], or something close to it

In this small example the optimal solution will cut the first interval, but not the second and third. Obviously, the algorithm should work with more complicated examples as well.
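
The constraints from the problem statement can be checked mechanically. The following is a minimal sketch, assuming right-open intervals and a small floating-point tolerance; is_feasible and cut_intervals are hypothetical helper names, not part of any library.

```python
def is_feasible(points, intervals, d, tol=1e-9):
    """Check the three constraints: the first point lies before the first
    interval, the last point lies at or after the last (right-open) end,
    and neighboring points are at most d apart."""
    pts = sorted(points)
    first_begin = min(a for a, _ in intervals)
    last_end = max(b for _, b in intervals)
    gaps_ok = all(q - p <= d + tol for p, q in zip(pts, pts[1:]))
    return pts[0] < first_begin and pts[-1] >= last_end and gaps_ok


def cut_intervals(points, intervals):
    """Return the right-open intervals [a, b) that contain some point."""
    return [(a, b) for a, b in intervals if any(a <= p < b for p in points)]


intervals = [(0, 1), (0.5, 1.5), (0.5, 1.5)]
d = 1.1
candidate = [-0.55, 0.45, 1.55]

print(is_feasible(candidate, intervals, d))   # True
print(cut_intervals(candidate, intervals))    # [(0, 1)]
```

A candidate such as [-0.55, 1.55] fails the check because its single gap of 2.1 exceeds d.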

A tougher test is the following: with interval start times drawn uniformly from [0, 100] and lengths drawn uniformly from [0, d], the expected number of intervals cut by a regular grid [0, d, 2d, 3d, ...] can be computed to be slightly below 0.5*n. The optimal solution should do better:

import numpy as np

n = 10000
delta = 1
starts = np.random.uniform(low=0., high=99, size=n)
lengths = np.random.uniform(low=0., high=1, size=n)
rand_intervals = np.array([starts, starts + lengths]).T
regular_grid = np.arange(0, 101, 1)
optimal_grid = fun(rand_intervals, delta)  # pass the gap limit as well

# This computes the number of intervals being cut by one of the points
def cuts(intervals, grid):
    bins = np.digitize(intervals, grid)
    return sum(bins[:,0] != bins[:,1])

cuts(rand_intervals, regular_grid)
# e.g. 4987, expected to be slightly below 0.5*n
assert cuts(rand_intervals, optimal_grid) <= cuts(rand_intervals, regular_grid)

Solution

  • You can solve this optimally through dynamic programming by maintaining an array S[k], where S[k] is the best partial solution (the one whose last point reaches furthest to the right) among those whose points are contained in exactly k intervals. Then repeatedly remove the entry with the lowest k, extend it in all relevant ways (it is enough to consider the endpoints of the nearby intervals plus the last point of S[k] + d), and update S with the resulting solutions. As soon as the entry with the lowest k covers the entire range, you are done.

    A Python 3 solution using intervaltree from pip:

    from intervaltree import Interval, IntervalTree
    
    def optimal_points(intervals, d, epsilon=1e-9):
        intervals = [Interval(lr[0], lr[1]) for lr in intervals]
        tree = IntervalTree(intervals)
        # Start just before the first interval so that the first point does not cut it.
        start = min(iv.begin for iv in intervals) - epsilon
        stop = max(iv.end for iv in intervals)
    
        # The best partial solution with k intervals containing a point.
        # We also store the intervals that these points are contained in as a set.
        sols = {0: ([start], set())}
    
        while True:
            lowest_k = min(sols.keys())
            s, contained = sols.pop(lowest_k)
            # print(lowest_k, s[-1])  # For tracking progress in slow instances.
            if s[-1] >= stop:
                return s
    
            relevant_intervals = tree[s[-1]:s[-1] + d]
            relevant_points = [iv.begin - epsilon for iv in relevant_intervals]
            relevant_points += [iv.end + epsilon for iv in relevant_intervals]
            extensions = {s[-1] + d} | {p for p in relevant_points if s[-1] < p < s[-1] + d}
    
            for ext in sorted(extensions, reverse=True):
                new_s = s + [ext]
                new_contained = set(tree[ext]) | contained
                new_k = len(new_contained)
                if new_k not in sols or new_s[-1] > sols[new_k][0][-1]:
                    sols[new_k] = (new_s, new_contained)
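
To experiment with the idea without installing intervaltree, the same table of best partial solutions can be sketched with plain lists. This is an illustrative variant, not the answer's code: optimal_points_plain and count_cut are hypothetical names, tree queries are replaced by linear scans (so it is much slower on large inputs), and the first point is placed epsilon before the first interval as an assumption taken from the problem statement.

```python
def count_cut(intervals, points):
    """Count right-open intervals [a, b) that contain at least one point."""
    return sum(any(a <= p < b for p in points) for a, b in intervals)


def optimal_points_plain(intervals, d, epsilon=1e-9):
    """Sketch of the S[k] table using linear scans instead of an interval tree."""
    intervals = [(float(a), float(b)) for a, b in intervals]
    start = min(a for a, _ in intervals) - epsilon  # strictly before everything
    stop = max(b for _, b in intervals)

    def hit(p):
        # Indices of the right-open intervals [a, b) that contain point p.
        return {i for i, (a, b) in enumerate(intervals) if a <= p < b}

    # sols[k] = (points so far, indices of intervals cut so far); for each k
    # only the partial solution whose last point reaches furthest is kept.
    sols = {0: ([start], set())}
    while True:
        k = min(sols)
        points, cut = sols.pop(k)
        last = points[-1]
        if last >= stop:
            return points
        # Candidate next points: the maximal step, plus positions just before
        # an interval begins or just after it ends, if within reach.
        candidates = {last + d}
        for a, b in intervals:
            for p in (a - epsilon, b + epsilon):
                if last < p < last + d:
                    candidates.add(p)
        for p in sorted(candidates, reverse=True):
            new_cut = cut | hit(p)
            new_k = len(new_cut)
            if new_k not in sols or p > sols[new_k][0][-1]:
                sols[new_k] = (points + [p], new_cut)


small = [(0, 1), (0.5, 1.5), (0.5, 1.5)]
grid = optimal_points_plain(small, 1.1)
print(count_cut(small, grid))  # 1
```

On the small example from the question this cuts only the first interval, in line with the solution the question expects.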