Tags: python, arrays, algorithm, math, time-complexity

Sum of the [maximums of all subarrays multiplied by their lengths], in linear time


Given an array arr of length n, I should compute the following sum in linear time:

S = \sum_{0 \le i \le j < n} (j - i + 1) \cdot \max(arr[i], \dots, arr[j])

My most naive implementation is O(n^3):

n = len(arr)
sum_ = 0

for i in range(n):
    for j in range(n, i, -1):
        sum_ += max(arr[i:j]) * (j-i)
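
A small worked example pins down the definition. The input [3, 1, 2] below is an illustrative assumption and does not come from the question. The sketch lists every subarray together with its term (max times length):

```python
arr = [3, 1, 2]  # illustrative example input (assumption, not from the question)

terms = []
for i in range(len(arr)):
    for j in range(i + 1, len(arr) + 1):
        sub = arr[i:j]
        # each subarray contributes its maximum times its length
        terms.append((sub, max(sub) * len(sub)))

for sub, term in terms:
    print(sub, term)
print(sum(term for _, term in terms))
```

The six terms are 3, 6, 9, 1, 4 and 2, so the sum is 25.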

I have no idea what to do. I have tried many algorithms, but they were at best O(n*log(n)), and I need to solve it in linear time. Also, I don't get the idea: is there a mathematical way of just looking at an array and telling the result of the above sum?


Solution

  • Keep a stack of (indices of) non-increasing values: before appending a new value, pop the smaller ones. Whenever you pop one, add its contribution to the total.

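    def solution(arr):
        arr.append(float('inf'))   # temporary sentinel/cleaner, removed below
        I = [-1]                   # stack of indices; -1 refers to the sentinel
        total = 0
        for i in range(len(arr)):
            # Pop every index whose value is smaller than the new one.
            while arr[i] > arr[I[-1]]:
                j = I.pop()
                a = j - I[-1]      # number of possible left ends
                b = i - j          # number of possible right ends
                # a*b subarrays with maximum arr[j], total length a*b*(a+b)/2
                total += (a+b)*a*b//2 * arr[j]
            I.append(i)
        arr.pop()                  # restore the caller's list
        return total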
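
    The next sketch makes the bookkeeping concrete. It runs the same loop on a copy of the input and records each pop as (j, a, b, contribution). Both the name solution_trace and the example input [3, 1, 2] are illustrative assumptions. Index 1 (value 1) is popped when the 2 arrives. The infinity sentinel then pops indices 2 and 0.

```python
def solution_trace(arr):
    """Same stack algorithm as solution(), but it works on a copy and
    records every pop as (j, a, b, contribution)."""
    vals = list(arr) + [float('inf')]  # copy of the input plus the infinity sentinel
    I = [-1]                           # -1 refers to the sentinel
    total = 0
    pops = []
    for i in range(len(vals)):
        while vals[i] > vals[I[-1]]:
            j = I.pop()
            a = j - I[-1]   # left ends available: I[-1]+1 .. j
            b = i - j       # right ends available: j .. i-1
            contribution = (a + b) * a * b // 2 * vals[j]
            pops.append((j, a, b, contribution))
            total += contribution
        I.append(i)
    return total, pops

total, pops = solution_trace([3, 1, 2])
for p in pops:
    print(p)
print(total)
```

    This prints (1, 1, 1, 1), (2, 2, 1, 6), (0, 1, 3, 18) and then 25, the same total as the brute-force sum over all six subarrays.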
    

    [Illustration: the array values drawn as bars, with positions I[-1], j and i marked]

    The bars represent values; larger values are taller bars. The value at i is about to be added. The light grey ones come later. The green ones are on the stack. The brown ones no longer play a role. First the one at i-1 gets popped, but that case is less informative. Then the one at j gets popped. It dominates the range strictly between I[-1] and i: it's the maximum of every subarray in that range that contains it. These subarrays contain j plus 0 to a-1 more elements to the left and 0 to b-1 more elements to the right. That's a*b subarrays with an average length of (a+b)/2, so together they contribute a*b*(a+b)/2 * arr[j] to the total.
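
    You can check the contribution formula by summing the lengths directly. Take x extra elements on the left (0 <= x <= a-1) and y extra elements on the right (0 <= y <= b-1). Each such subarray then has length 1 + x + y, so:

```latex
\sum_{x=0}^{a-1}\sum_{y=0}^{b-1}(1+x+y)
  = ab + b\,\frac{a(a-1)}{2} + a\,\frac{b(b-1)}{2}
  = \frac{ab\,(a+b)}{2}
```

    The product ab(a+b) is always even: if a and b are both odd, then a+b is even. So the integer division in (a+b)*a*b//2 is exact.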

    I temporarily append infinity to the list. Since I starts as [-1] and arr[-1] is that infinity, it works in two ways. On the left it is a sentinel, so the while condition needs no extra check for an empty stack. At the end it is a cleaner: when it arrives, it pops all remaining values from the stack. The final arr.pop() removes it again, so the caller's list is restored. Non-Python coders: Python supports negative indexes; -1 means "last element" (1st from the end).
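
    Sometimes mutating the caller's list is undesirable, for example when the list is shared or the input is a tuple. In that case the sentinel can be replaced by the two checks it saves. The sketch below makes that assumption and uses the hypothetical name solution_no_sentinel. The emptiness test "stack and ..." replaces the left sentinel, and one extra iteration at i == n replaces the cleaner at the end.

```python
def solution_no_sentinel(arr):
    """Stack algorithm without a sentinel; the input is never modified."""
    n = len(arr)
    stack = []  # indices of a non-increasing run of values
    total = 0
    for i in range(n + 1):
        # At i == n there is no element: pop everything, as infinity would.
        while stack and (i == n or arr[i] > arr[stack[-1]]):
            j = stack.pop()
            left = stack[-1] if stack else -1  # -1 plays the left sentinel's role
            a = j - left
            b = i - j
            total += (a + b) * a * b // 2 * arr[j]
        stack.append(i)
    return total

print(solution_no_sentinel([3, 1, 2]))
```

    On [3, 1, 2] it performs the same three pops as the sentinel version and returns 25.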

    Correctness test against a brute-force reference, using random lists of 500 values:

    import random
    
    def reference(arr):
        n = len(arr)
        return sum(max(arr[L : R+1]) * (R - (L-1))
                   for L in range(n)
                   for R in range(L, n))
    
    for _ in range(5):
        arr = random.choices(range(10000), k=500)
        expect = reference(arr)
        result = solution(arr)
        print(result == expect, result)
    

    Sample output (results for five lists, True means it's correct):

    True 207276773131
    True 208127393653
    True 208653950227
    True 208073567605
    True 206924015682
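
    The random test above draws values from range(10000), so equal values are uncommon. The pop condition uses a strict >, which means a subarray with several equal maxima is credited only to its leftmost maximum and is never counted twice. The complementary check below exercises that case. It assumes short lists with values from range(3), so ties are frequent, and it includes the empty list. It also confirms that the caller's list is restored.

```python
import random

def solution(arr):
    # Stack algorithm from the answer above (sentinel version).
    arr.append(float('inf'))
    I = [-1]
    total = 0
    for i in range(len(arr)):
        while arr[i] > arr[I[-1]]:
            j = I.pop()
            a = j - I[-1]
            b = i - j
            total += (a+b)*a*b//2 * arr[j]
        I.append(i)
    arr.pop()
    return total

def reference(arr):
    # Brute force over all subarrays arr[L..R], O(n^3).
    n = len(arr)
    return sum(max(arr[L:R+1]) * (R - L + 1)
               for L in range(n)
               for R in range(L, n))

random.seed(1)
mismatches = 0
for n in range(0, 25):                        # includes the empty list
    for _ in range(40):
        arr = random.choices(range(3), k=n)   # values 0..2, so many ties
        copy = list(arr)
        if solution(arr) != reference(arr) or arr != copy:
            mismatches += 1
print(mismatches)
```

    Because the algorithm counts each subarray exactly once, this should print 0.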