LeetCode - Minimum Falling Path Sum - question on memoization

I am trying to solve this LeetCode problem:

Given an n x n array of integers matrix, return the minimum sum of any falling path through matrix.

A falling path starts at any element in the first row and chooses the element in the next row that is either directly below or diagonally left/right. Specifically, the next element from position (row, col) will be (row + 1, col - 1), (row + 1, col), or (row + 1, col + 1).
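To pin down the definition before any memoization, here is a brute-force sketch that tries every falling path. The function name brute_force_min_falling_path_sum is hypothetical and not from LeetCode. It is exponential in n, so it is only useful as a reference for checking faster versions on small inputs.

```python
from typing import List


def brute_force_min_falling_path_sum(matrix: List[List[int]]) -> int:
    """Try every falling path (exponential time); for small test inputs only."""
    n = len(matrix)

    def walk(row: int, col: int) -> int:
        # Best sum of a path that starts at (row, col) and falls to the last row.
        if row == n - 1:
            return matrix[row][col]
        below = [walk(row + 1, c) for c in (col - 1, col, col + 1) if 0 <= c < n]
        return matrix[row][col] + min(below)

    # A falling path may start at any column of the first row.
    return min(walk(0, col) for col in range(n))


print(brute_force_min_falling_path_sum([[2, 1, 3], [6, 5, 4], [7, 8, 9]]))  # 13
```

For the example above, 1 -> 5 -> 7 and 1 -> 4 -> 8 both give the minimum of 13.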

It's a dynamic programming problem that I wanted to solve using recursion and memoization. The editorial provides a Java solution that uses row and col as the memoization key:

class Solution {
    public int minFallingPathSum(int[][] matrix) {
        int minFallingSum = Integer.MAX_VALUE;
        Integer memo[][] = new Integer[matrix.length][matrix[0].length];

        // start a DFS (with memoization) from each cell in the top row
        for (int startCol = 0; startCol < matrix.length; startCol++) {
            minFallingSum = Math.min(minFallingSum,
                findMinFallingPathSum(matrix, 0, startCol, memo));
        }
        return minFallingSum;
    }

    public int findMinFallingPathSum(int[][] matrix, int row, int col, Integer[][] memo) {
        //base cases
        if (col < 0 || col == matrix.length) {
            return Integer.MAX_VALUE;
        }
        //check if we have reached the last row
        if (row == matrix.length - 1) {
            return matrix[row][col];
        }
        //check if the results are calculated before
        if (memo[row][col] != null) {
            return memo[row][col];
        }

        // calculate the minimum falling path sum starting from each possible next step
        int left = findMinFallingPathSum(matrix, row + 1, col - 1, memo);
        int middle = findMinFallingPathSum(matrix, row + 1, col, memo);
        int right = findMinFallingPathSum(matrix, row + 1, col + 1, memo);

        memo[row][col] = Math.min(left, Math.min(middle, right)) + matrix[row][col];
        return memo[row][col];
    }
}

My initial approach in Python was this:

import sys
from typing import List

class Solution:
    def minFallingPathSum(self, matrix: List[List[int]]) -> int:
        d = {}
        min_sum = sys.maxsize
        for i in range(len(matrix)):
            min_sum = min(min_sum, self.recur(matrix, 1, i, matrix[0][i], d))
        return min_sum

    def recur(self, matrix: [], row: int, col: int, sum: int, d: {}):
        if row >= len(matrix):
            return sum
        if (row, col) not in d:
            l = []
            l.append(self.recur(matrix, row + 1, col, sum + matrix[row][col], d))
            if col - 1 >= 0:
                l.append(self.recur(matrix, row + 1, col - 1, sum + matrix[row][col-1], d))
            if col + 1 < len(matrix):
                l.append(self.recur(matrix, row + 1, col + 1, sum + matrix[row][col+1], d))
            d[row,col] = min(l)
        return d[row,col]

But it fails with Wrong Answer after 18/50 test cases. I then changed it to use the sum along with row and col as the memoization key:

import sys
from typing import List

class Solution:
    def minFallingPathSum(self, matrix: List[List[int]]) -> int:
        d = {}
        min_sum = sys.maxsize
        for i in range(len(matrix)):
            min_sum = min(min_sum, self.recur(matrix, 1, i, matrix[0][i], d))
        return min_sum

    def recur(self, matrix: [], row: int, col: int, sum: int, d: {}):
        if row >= len(matrix):
            return sum
        if (row, col, sum) not in d:
            l = []
            l.append(self.recur(matrix, row + 1, col, sum + matrix[row][col], d))
            if col - 1 >= 0:
                l.append(self.recur(matrix, row + 1, col - 1, sum + matrix[row][col-1], d))
            if col + 1 < len(matrix):
                l.append(self.recur(matrix, row + 1, col + 1, sum + matrix[row][col+1], d))
            d[row,col,sum] = min(l)
        return d[row,col,sum]

This version gives correct answers, but it hits Time Limit Exceeded after 43/50 test cases.

Why does my Python code fail when using (row, col) as the memoization key, whereas the Java code in the editorial works with the same key?

Any help would be appreciated.


  • it's working for the Java code in the editorial

    But you didn't actually replicate the Java algorithm in your Python version:

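    • The Java version works bottom up, returning the sum from the given coordinates down to the bottom of the matrix, while your algorithm works top down, accumulating the sum of the path from the top to the current cell. Because your recur returns a value that depends on the sum argument and not only on (row, col), the first value cached under (row, col) is wrongly reused for every later path that reaches that key with a different partial sum.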

    • The Java version uses the cell coordinates as the memoization key, while your second Python version uses the cell coordinates together with the partial sum as the key. That destroys most of the benefit of memoization, since the same cell is reached with many different partial sums.

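    • In l you also collect sums that include matrix[row][col+1] or matrix[row][col-1], and then store the best of these under (row, col) or (row, col, sum). Those values belong to the neighbors of (row, col), so in your code the key really describes the column chosen in the previous row, not the cell matrix[row][col]. That is a different meaning from memo[row][col] in the Java version, which makes the two hard to compare.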
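Both effects can be observed on a small input. The sketch below ports the question's recur into a standalone function; question_version, its include_sum_in_key flag, and the renaming of sum to total are hypothetical choices for this illustration. With a (row, col) key, a value cached for one partial sum is returned for a different partial sum. With a (row, col, sum) key, the answer is correct, but the cache can hold several entries per cell.

```python
from typing import Dict, List, Tuple


def question_version(matrix: List[List[int]], include_sum_in_key: bool) -> Tuple[int, int]:
    """Hypothetical standalone port of the question's recur().

    Returns (answer, number of memo entries).
    """
    n = len(matrix)
    d: Dict[tuple, int] = {}

    def recur(row: int, col: int, total: int) -> int:
        # total already includes the cell chosen in row - 1 at column col
        if row >= n:
            return total
        key = (row, col, total) if include_sum_in_key else (row, col)
        if key not in d:
            options = [recur(row + 1, col, total + matrix[row][col])]
            if col - 1 >= 0:
                options.append(recur(row + 1, col - 1, total + matrix[row][col - 1]))
            if col + 1 < n:
                options.append(recur(row + 1, col + 1, total + matrix[row][col + 1]))
            # The stored value depends on total, which a (row, col) key ignores.
            d[key] = min(options)
        return d[key]

    answer = min(recur(1, i, matrix[0][i]) for i in range(n))
    return answer, len(d)


m = [[100, 0, 100],
     [0,   0, 100],
     [0,   0, 100]]
print(question_version(m, include_sum_in_key=False))  # (100, 6): wrong answer
print(question_version(m, include_sum_in_key=True))   # (0, 9): correct answer
```

The true answer is 0 (start at the 0 in the first row and stay on zeros). With the (row, col) key, the entries for row 2 are first filled while exploring from the 100 in the top-left corner, and those cached 100s are then returned for the cheaper paths as well.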

    Here is a version that applies the same approach as the Java version: it works bottom up, so you don't need to pass a partial sum as an argument. I put the recur function inside the main function, so matrix and d don't need to be passed as arguments either:

    from typing import List

    class Solution:
        def minFallingPathSum(self, matrix: List[List[int]]) -> int:
            d = {}
            def recur(row: int, col: int) -> int:
                if col < 0 or col >= len(matrix[0]):
                    return 10000000  # larger than any value
                if row == len(matrix) - 1:
                    return matrix[row][col]
                if (row, col) not in d:
                    d[row,col] = min(recur(row + 1, col),
                                     recur(row + 1, col - 1),
                                     recur(row + 1, col + 1)) + matrix[row][col]
                return d[row,col]
            return min(recur(0, i) for i in range(len(matrix[0])))
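Because recur in this version depends only on (row, col), the hand-written dict can also be replaced by the standard library's functools.lru_cache. This is a sketch of the same recursion with that swap; the function name min_falling_path_sum_cached is hypothetical, and math.inf is used instead of a large sentinel number for off-grid columns.

```python
import math
from functools import lru_cache
from typing import List


def min_falling_path_sum_cached(matrix: List[List[int]]) -> int:
    n_rows, n_cols = len(matrix), len(matrix[0])

    @lru_cache(maxsize=None)  # memoizes on the arguments, i.e. (row, col) only
    def recur(row: int, col: int) -> float:
        # Best sum of a falling path from (row, col) down to the last row.
        if col < 0 or col >= n_cols:
            return math.inf  # off the grid: can never be the minimum
        if row == n_rows - 1:
            return matrix[row][col]
        return matrix[row][col] + min(
            recur(row + 1, col - 1),
            recur(row + 1, col),
            recur(row + 1, col + 1),
        )

    return min(recur(0, col) for col in range(n_cols))


print(min_falling_path_sum_cached([[2, 1, 3], [6, 5, 4], [7, 8, 9]]))  # 13
```

Defining the cached function inside the outer function gives each call a fresh cache, so results from one matrix are never reused for another. The recursion depth equals the number of rows.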

    You could reduce memory usage by implementing an iterative algorithm (row by row, not depth-first) that only stores the results of the previously visited row. You can also use a list instead of a dict:

    from typing import List

    class Solution:
        def minFallingPathSum(self, matrix: List[List[int]]) -> int:
            n = len(matrix[0])
            dp = [0] * n
            for row in matrix:
                # dp[i]: best sum of a falling path ending at column i of this row
                dp = [
                    min(dp[max(i-1, 0): min(n, i+2)]) + val
                    for i, val in enumerate(row)
                ]
            return min(dp)
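To see what the iterative version keeps in memory, here is a sketch that records dp after each row. The helper name dp_rows is hypothetical and not part of the answer. Each entry dp[i] is the best sum of a falling path from the top row that ends at column i of the current row, and only the previous row's list is needed to compute the next one.

```python
from typing import List


def dp_rows(matrix: List[List[int]]) -> List[List[int]]:
    """Return the dp list after each row of the iterative algorithm."""
    n = len(matrix[0])
    dp = [0] * n
    history = []
    for row in matrix:
        # Each cell adds its value to the best of the (up to) three cells above it.
        dp = [min(dp[max(i - 1, 0): min(n, i + 2)]) + val for i, val in enumerate(row)]
        history.append(dp)
    return history


for step in dp_rows([[2, 1, 3], [6, 5, 4], [7, 8, 9]]):
    print(step)
# [2, 1, 3]
# [7, 6, 5]
# [13, 13, 14]
```

The answer is the minimum of the last list, 13, matching the recursive versions.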