Tags: python, list, recursion, sum, suffix

Suffix sum using recursion: Memory limit exceeded error in Python


I found this Suffix Sum challenge on Codeforces while I was practising recursion sums:

Given two numbers 𝑁 and 𝑀, and an array 𝐴 of 𝑁 numbers. Calculate the sum of the last 𝑀 numbers.

Note: solve this problem using recursion.

Input

First line contains two numbers 𝑁 and 𝑀 (1 ≤ 𝑀 ≤ 𝑁 ≤ 10⁵).

Second line contains 𝑁 numbers (−10⁹ ≤ 𝐴ᵢ ≤ 10⁹).

Output

Print the sum of the last 𝑀 numbers of the given array.

I tried to solve the question in Python, and this is my code:

def suffix_sum(arr, m):
    if m <= 0:
        return 0
    else:
        return arr[-1] + suffix_sum(arr[:-1], m - 1)

n, m = map(int,input().split())
arr = list(map(int,input().split()))
print(suffix_sum(arr, m))

Though I get the expected output in my compiler for my sample input, I always end up getting 'Memory limit exceeded on test 2' when I submit my code. Test case 2 is hidden, so I am not sure what it contains, but my guess is that it is a really large array, on the order of 10⁴ or 10⁵ elements. When I checked the other submissions, they were all in C++.

So, is exceeding the memory limit caused by my using Python and PyPy, or is something wrong with my code?

The solution gets accepted when I use a for loop, but as the question is meant for practising recursion, I am not sure where I am going wrong. Any help improving it would be much appreciated!
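A rough count shows why a large hidden test can exhaust memory with the slicing version. This is an illustrative sketch, not part of the original post: suffix_sum_counting and sliced_elements are hypothetical helper names, and the byte estimate assumes a 64-bit CPython build where each list slot holds an 8-byte reference.

```python
def suffix_sum_counting(arr, m, sizes):
    """The question's slicing recursion, instrumented to record len(arr) at every call.

    Hypothetical helper for illustration; `sizes` collects one entry per call.
    """
    sizes.append(len(arr))
    if m <= 0:
        return 0
    # arr[:-1] copies all but the last element into a brand-new list
    return arr[-1] + suffix_sum_counting(arr[:-1], m - 1, sizes)


def sliced_elements(n, m):
    """Total elements copied by the m slices: (n - 1) + (n - 2) + ... + (n - m)."""
    return m * n - m * (m + 1) // 2


# Small check: the instrumented run matches the closed-form count.
sizes = []
print(suffix_sum_counting(list(range(1, 11)), 4, sizes))  # 34
print(sizes)                                             # [10, 9, 8, 7, 6]
print(sum(sizes[1:]) == sliced_elements(10, 4))          # True

# Worst case allowed by the constraints, assuming 8-byte references.
refs = sliced_elements(10**5, 10**5)
print(refs, "references, about", round(refs * 8 / 1e9), "GB")
```

Every caller keeps its own list referenced until the deeper calls return, so all of these copies are alive at the same time at the deepest point of the recursion. For 𝑁 = 𝑀 = 10⁵ that is roughly 5·10⁹ references, far beyond any judge's memory limit.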


Solution

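  • Slicing costs memory: each slice is a new list. Since arr[:-1] copies all but one element, and every caller keeps its list alive until the deeper calls return, the slices together take O(𝑛·𝑚) memory. So you should avoid slicing and only use indexing to retrieve the values you need.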

    As 𝑚 could be as large as 10⁵, you'll also want to avoid a recursion depth of O(𝑚), since CPython's default recursion limit is only 1000. You can achieve that by partitioning the problem into halves and making two recursive calls, passing each call two indices that define the section of the list to sum up. The second index can be a distance from the end of the input list, so it has the same meaning as m and can get a default value of 1:

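    def suffix_sum(lst, m, last=1):
        # Sum of lst[-m], lst[-m+1], ..., lst[-last]; requires m >= last >= 1
        if m == last:
            return lst[-m]  # a single element: no further recursion
        mid = (m + last) // 2
        # Split the range into halves: distances m..mid+1 and mid..last
        return suffix_sum(lst, m, mid + 1) + suffix_sum(lst, mid, last)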

    Now the recursion depth will be ⌈log₂ 𝑚⌉, which for the worst case of 𝑚 = 10⁵ is 17.

    Note that the base case relies on the fact that 𝑚 will be at least 1.
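To check the depth claim and turn the answer's function into a complete submission, here is a sketch under stated assumptions: suffix_sum is the answer's function unchanged, while max_depth (which mirrors its splitting and reports the deepest nesting below the top-level call) and solve (which parses the judge's input from a string) are hypothetical helpers added for illustration.

```python
def suffix_sum(lst, m, last=1):
    """Sum of lst[-m], ..., lst[-last], splitting the range in halves (from the answer)."""
    if m == last:
        return lst[-m]
    mid = (m + last) // 2
    return suffix_sum(lst, m, mid + 1) + suffix_sum(lst, mid, last)


def max_depth(m, last=1, depth=0):
    """Deepest nesting of recursive calls below the top-level suffix_sum(lst, m).

    Hypothetical helper: it repeats suffix_sum's splitting but only tracks depth.
    """
    if m == last:
        return depth
    mid = (m + last) // 2
    return max(max_depth(m, mid + 1, depth + 1), max_depth(mid, last, depth + 1))


def solve(text):
    """Hypothetical driver: parse 'N M' followed by N integers, return the suffix sum."""
    tokens = text.split()
    n, m = int(tokens[0]), int(tokens[1])
    arr = [int(t) for t in tokens[2:2 + n]]
    return suffix_sum(arr, m)


print(solve("5 3\n1 2 3 4 5\n"))  # 12
print(max_depth(10**5))           # 17
```

On the judge, the two print lines would be replaced by something like import sys followed by print(solve(sys.stdin.read())); reading the whole input at once is an assumption of this sketch, whereas the original code used input(). Because the depth stays logarithmic, there is no need to raise the recursion limit.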