Tags: python, intervals, python-itertools, functools, accumulate

Using itertools.accumulate to find the union of intervals (converting from reduce to accumulate)


I seem to have developed the right reduce operation to find the union of intervals, only to realize that reduce gives you only the final result. So I looked up the documentation and figured out that what I should be using is in fact accumulate.
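For reference, here is a minimal sketch of the difference using plain addition (the numbers are only illustrative data). functools.reduce folds a sequence down to a single value, while itertools.accumulate yields every intermediate value of the same fold.

```python
import operator
from functools import reduce
from itertools import accumulate

data = [1, 2, 3, 4]

# reduce keeps only the final accumulated value
total = reduce(operator.add, data)

# accumulate yields each running value; list() materializes the iterator
running = list(accumulate(data, operator.add))

print(total)    # 10
print(running)  # [1, 3, 6, 10]
```

Note that the last item produced by accumulate is always the value reduce would have returned for the same function and input.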

I need someone to help me convert this reduce to accumulate so that I get the intermediate intervals.

The code below is an example of how I used reduce. I'm assuming that the intermediate values can be stored using accumulate, though I'm not sure if this is even possible. But I looked at examples showing how accumulate gives you a sequence of items where each item is an intermediate calculated result.

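from functools import reduce  # reduce is not a builtin in Python 3

example_interval = [[1,3],[2,6],[6,10],[15,18]]

def main():
    def function(item1, item2):
        # Merge item2 into item1 if the two intervals overlap or touch
        if item1[1] >= item2[0]:
            return item1[0], max(item1[1], item2[1])
        else:
            return item2

    return reduce(function, example_interval)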

To understand the problem: [1, 3] and [2, 6] can be merged into [1, 6] because item1[1] >= item2[0]. [1, 6] is then taken in as item1 and compared with [6, 10] as item2, giving [1, 10]. Finally, [1, 10] is compared with the last item, [15, 18]; these do not overlap, so they are not merged, and the union I want is [1, 10], [15, 18].
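The walkthrough above also shows why reduce alone falls short here: it returns only the last accumulator value, so the merged [1, 10] is discarded once [15, 18] replaces it. A minimal sketch using the same function and data:

```python
from functools import reduce

def function(item1, item2):
    # Overlapping (or touching) intervals: extend item1 to cover item2
    if item1[1] >= item2[0]:
        return item1[0], max(item1[1], item2[1])
    # No overlap: the running interval restarts from item2
    return item2

example_interval = [[1, 3], [2, 6], [6, 10], [15, 18]]
result = reduce(function, example_interval)
print(result)  # [15, 18]
```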

I do know how to solve this problem without reduce or accumulate. I'm just generally interested in understanding how accumulate can replicate this task while storing the intermediate values.


Solution

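  • from itertools import accumulate
    
    def function(item1, item2):
        # Overlapping (or touching) intervals: extend item1 to cover item2
        if item1[1] >= item2[0]:
            return item1[0], max(item1[1], item2[1])
        # No overlap: start a new running interval from item2
        return item2
    
    example_interval = [(1,3),(2,6),(6,10),(15,18)]
    # accumulate yields every intermediate value of the fold, not just the last one
    print(list(accumulate(example_interval, function)))
    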
    

    The result is:

    [(1, 3), (1, 6), (1, 10), (15, 18)]
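The list above is the running state of the fold rather than the union itself. One possible way to recover the merged intervals, sketched here under the assumption that the input is sorted by start, relies on the fact that the start value never changes within a merged run. Grouping the accumulate output by start with itertools.groupby and keeping the last (widest) value of each group then gives the union. The helper name union is hypothetical.

```python
from itertools import accumulate, groupby

def function(item1, item2):
    # Overlapping (or touching) intervals: extend item1 to cover item2
    if item1[1] >= item2[0]:
        return item1[0], max(item1[1], item2[1])
    # No overlap: start a new running interval from item2
    return item2

def union(intervals):
    """Hypothetical helper: merge intervals that are already sorted by start."""
    running = accumulate(intervals, function)
    # Within one merged run the start stays fixed, so each group of equal
    # starts is one run; its last value is the fully merged interval.
    return [list(group)[-1] for _, group in groupby(running, key=lambda iv: iv[0])]

print(union([(1, 3), (2, 6), (6, 10), (15, 18)]))  # [(1, 10), (15, 18)]
```

If the input might be unsorted, sort it first (for example with sorted(intervals)), because the merge condition item1[1] >= item2[0] only makes sense when intervals arrive in order of their start.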
    

    Notice that I changed the items of example_interval from lists to tuples. If you don't, then when item1[1] < item2[0] the returned value is item2, which is a list object, but when item1[1] >= item2[0] the returned expression is item1[0], max(item1[1], item2[1]), which evaluates to a tuple. On top of that, accumulate yields the first item unchanged, so it also stays a list:

    example_interval = [[1,3],[2,6],[6,10],[15,18]]
    print(list(accumulate(example_interval, function)))
    

    Now the output is:

    [[1, 3], (1, 6), (1, 10), [15, 18]]
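A minimal sketch of one way to avoid the mixed types without editing the data itself, assuming you would rather keep example_interval as a list of lists: convert each item to a tuple as it is fed to accumulate.

```python
from itertools import accumulate

def function(item1, item2):
    # Overlapping (or touching) intervals: extend item1 to cover item2
    if item1[1] >= item2[0]:
        return item1[0], max(item1[1], item2[1])
    # No overlap: start a new running interval from item2
    return item2

example_interval = [[1, 3], [2, 6], [6, 10], [15, 18]]
# map(tuple, ...) lazily converts each inner list to a tuple, so every
# value accumulate yields, including the first, has the same type
result = list(accumulate(map(tuple, example_interval), function))
print(result)  # [(1, 3), (1, 6), (1, 10), (15, 18)]
```

Alternatively, function could return [item1[0], max(item1[1], item2[1])] in the merge branch, so that both branches produce lists when the input is a list of lists.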