python recursion dynamic-programming memoization

Issue with Memoization in Recursive Function for Finding Combinations Summing to Target


I need to write the following function:

Write a function that takes a target (int) and a list of ints. The function should return a list of any combination of elements that add up to the target; elements may be used more than once. If no combination adds up to the target, return None.

This is my initial solution using recursion:

def how_sum(target: int, nums: list[int]) -> list[int] | None:
    if target == 0: return []
    if target < 0: return None

    for num in nums:
        remainder = target - num
        rv = how_sum(remainder, nums)
        
        if rv is not None:
            return rv + [num]

    return None
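The plain recursion above explores the same remainders over and over, which is exactly what memoization is meant to avoid. A small sketch makes the difference measurable by counting calls. The helper count_calls and its use_memo flag are hypothetical, and the cache here is created fresh for each call:

```python
from __future__ import annotations


def count_calls(target: int, nums: list[int], use_memo: bool) -> tuple[list[int] | None, int]:
    """Run the same how_sum-style search and return (result, number of calls)."""
    calls = 0
    memo: dict[int, list[int] | None] = {}  # fresh cache for this call only

    def solve(t: int) -> list[int] | None:
        nonlocal calls
        calls += 1
        if use_memo and t in memo:
            return memo[t]
        if t == 0:
            return []
        if t < 0:
            return None
        result = None
        for num in nums:
            rv = solve(t - num)
            if rv is not None:
                result = rv + [num]
                break
        if use_memo:
            memo[t] = result  # cache successes and failures alike
        return result

    return solve(target), calls


print(count_calls(60, [7, 14], use_memo=False))  # (None, 177)
print(count_calls(60, [7, 14], use_memo=True))   # (None, 19)
```

For this failing input the uncached call count grows like the Fibonacci numbers as the target increases, while the cached count grows only linearly.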

Then I tried to reduce the time complexity with memoization, so the code stays efficient even for large targets:

def how_sum(target: int, nums: list[int], memo: dict[int, list[int]] = {}) -> list[int] | None:
    if target in memo: return memo[target]
    if target == 0: return []
    if target < 0: return None

    for num in nums:
        remainder = target - num
        rv = how_sum(remainder, nums, memo)

        if rv is not None:
            memo[target] = rv + [num] # Note: if I comment this line everything works fine!
            return rv + [num]

    memo[target] = None
    return None


def main():
    print(how_sum(7, [2, 3]))  # [3, 2, 2]
    print(how_sum(7, [5, 3, 4, 7]))  # [3, 2, 2]
    print(how_sum(7, [2, 4]))  # [3, 2, 2]
    print(how_sum(8, [2, 3, 5]))  # [2, 2, 2, 2]
    print(how_sum(500, [7, 14]))  # [3, 7, 7, 7, 7, 7, ..., 7]


main()

As the comments show, several of these outputs are wrong.

These are the correct outputs:

def main():
    print(how_sum(7, [2, 3]))  # [3, 2, 2]
    print(how_sum(7, [5, 3, 4, 7]))  # [4, 3]
    print(how_sum(7, [2, 4]))  # None
    print(how_sum(8, [2, 3, 5]))  # None
    print(how_sum(500, [7, 14]))  # None
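Because the task accepts any combination, an output is easier to judge with a check than by comparing it against a single expected list. A minimal sketch, assuming nums holds positive integers that may be reused; can_sum and is_acceptable are hypothetical helper names, and can_sum uses a bottom-up table rather than recursion:

```python
from __future__ import annotations


def can_sum(target: int, nums: list[int]) -> bool:
    """Bottom-up check: can `target` be written as a sum of values from `nums` (reuse allowed)?"""
    if target < 0:
        return False
    reachable = [False] * (target + 1)
    reachable[0] = True  # the empty combination sums to 0
    for t in range(1, target + 1):
        # Only positive values can extend a smaller reachable sum up to t.
        reachable[t] = any(0 < num <= t and reachable[t - num] for num in nums)
    return reachable[target]


def is_acceptable(target: int, nums: list[int], result: list[int] | None) -> bool:
    """Accept any valid combination, or None only when no combination exists."""
    if result is None:
        return not can_sum(target, nums)
    allowed = set(nums)
    return all(x in allowed for x in result) and sum(result) == target


print(is_acceptable(7, [5, 3, 4, 7], [7]))  # True
print(is_acceptable(8, [2, 3, 5], None))    # False, since 2 + 2 + 2 + 2 == 8
```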

When I comment out the line memo[target] = rv + [num], everything works fine, but I can't figure out why it doesn't work when I leave it in.
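One way to test whether the commented-out variant really works is to run how_sum(8, [2, 3, 5]) both with a fresh memo and after the earlier calls from main(). The sketch below reproduces that variant under the hypothetical name how_sum_failures_only: only failures are cached, and the default dict is shared between calls exactly as in the memoized version above.

```python
from __future__ import annotations


def how_sum_failures_only(target: int, nums: list[int],
                          memo: dict[int, list[int] | None] = {}) -> list[int] | None:
    # Same as the memoized how_sum, but with `memo[target] = rv + [num]` removed,
    # so only failures (None) are ever stored. The default dict is shared.
    if target in memo:
        return memo[target]
    if target == 0:
        return []
    if target < 0:
        return None
    for num in nums:
        rv = how_sum_failures_only(target - num, nums, memo)
        if rv is not None:
            return rv + [num]
    memo[target] = None
    return None


# With a fresh dict, a valid combination is found.
print(how_sum_failures_only(8, [2, 3, 5], {}))  # [2, 2, 2, 2]

# Replay the first three calls from main() against the shared default dict.
how_sum_failures_only(7, [2, 3])
how_sum_failures_only(7, [5, 3, 4, 7])
how_sum_failures_only(7, [2, 4])

# Stale None entries (e.g. for 2, cached while nums was [5, 3, 4, 7]) now block the search.
print(how_sum_failures_only(8, [2, 3, 5]))  # None
```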


Solution

  • I see two issues here. First, your claim about the correct output for this input:

    print(how_sum(8, [2, 3, 5]))  # None
    

    is incorrect. Given the problem statement, either [3, 5] or [2, 2, 2, 2] is a valid answer. Similarly for:

    print(how_sum(7, [5, 3, 4, 7]))  # [4, 3]
    

    where [7] is also a valid result. As for the code itself, the problem is a common one: you're using a mutable default argument where it isn't warranted. The default dict is created once, when the def statement runs, and every call that doesn't pass memo explicitly reuses that same dict:

    def how_sum(target, nums, memo={})
    

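    Since nums differs between top-level calls, the memo cache has to be reinitialized at the start of each top-level call. Otherwise it still holds results from a previous run where nums was different. This is also why commenting out memo[target] = rv + [num] only seems to fix things: the None entries cached by earlier calls (for example for 2, which is unreachable with [5, 3, 4, 7]) are what make how_sum(8, [2, 3, 5]) return None. One possible approach: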

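    def how_sum(target: int, nums: list[int], memo: dict[int, list[int] | None] | None = None) -> list[int] | None:
        # Create a fresh cache on every top-level call; a target of 0 is always reachable.
        if memo is None:
            memo = {0: []}
    
        if target not in memo:
            # Assume failure until a combination is found; negative targets stay None.
            memo[target] = None
    
            if target > 0:
                for num in nums:
                    remainder = target - num
                    rv = how_sum(remainder, nums, memo)
    
                    if rv is not None:
                        # Build a new list so the cached list for `remainder` is not mutated.
                        memo[target] = [*rv, num]
                        break
    
        return memo[target]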
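A variation on the same fix keeps the cache local by construction: the recursion moves into an inner function (solve is a hypothetical name) that closes over a dict created on every top-level call, so there is no memo parameter left for callers to share or forget. A minimal sketch, assuming the same semantics as the code above:

```python
from __future__ import annotations


def how_sum(target: int, nums: list[int]) -> list[int] | None:
    # A local dict is created on every top-level call, so results computed
    # for one `nums` can never leak into a call with a different `nums`.
    memo: dict[int, list[int] | None] = {0: []}

    def solve(t: int) -> list[int] | None:
        if t < 0:
            return None
        if t not in memo:
            memo[t] = None  # assume failure until a combination is found
            for num in nums:
                rv = solve(t - num)
                if rv is not None:
                    memo[t] = [*rv, num]
                    break
        return memo[t]

    return solve(target)


print(how_sum(7, [2, 3]))        # [3, 2, 2]
print(how_sum(7, [5, 3, 4, 7]))  # [4, 3]
print(how_sum(7, [2, 4]))        # None
print(how_sum(8, [2, 3, 5]))     # [2, 2, 2, 2]
print(how_sum(500, [7, 14]))     # None
```

The trade-off is that callers can no longer supply a cache of their own; in exchange, a result can never depend on what earlier calls computed.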