
Leetcode 1155. dictionary based memoization vs LRU Cache


I was solving LeetCode 1155, which asks for the number of dice rolls that reach a target sum. I was using dictionary-based memoization. Here's the exact code:

class Solution:
    def numRollsToTarget(self, dices: int, faces: int, target: int) -> int:
        
        dp = {}
        def ways(t, rd):
            if t == 0  and rd == 0: return 1
            if t <= 0 or rd <= 0: return 0
            if dp.get((t,rd)): return dp[(t,rd)]
            dp[(t,rd)] = sum(ways(t-i, rd-1) for i in range(1,faces+1))
            return dp[(t,rd)]
        
        return ways(target, dices)

But this solution invariably times out when the combination of faces and dice is around 15*15.

Then I found this solution, which uses functools.lru_cache; the rest of it is exactly the same. It runs very fast.

class Solution:
    def numRollsToTarget(self, dices: int, faces: int, target: int) -> int:
        from functools import lru_cache
        @lru_cache(None)
        def ways(t, rd):
            if t == 0  and rd == 0: return 1
            if t <= 0 or rd <= 0: return 0
            return sum(ways(t-i, rd-1) for i in range(1,faces+1))
        
        return ways(target, dices)

I have compared the two before and found that, in most cases, lru_cache does not outperform a dictionary-based cache by such a margin.

Can someone explain the reason why there is such a drastic performance difference between the two approaches?


Solution

  • First, I ran your original code under cProfile. This is the report:

    • Profiling print(numRollsToTarget2(4, 6, 20)), the OP version:

    You can see right away that ways, the generator expression, and sum account for a large number of calls (1075, 1253, and 179 respectively), even for this small input. Those are the places to examine closely and try to reduce. The second report is for a similar memoized version that makes far fewer calls, and that version passed without timing out.

    35
             2864 function calls (366 primitive calls) in 0.018 seconds
    
       Ordered by: standard name
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    0.018    0.018 <string>:1(<module>)
            1    0.000    0.000    0.001    0.001 dice_rolls.py:23(numRollsToTarget2)
       1075/1    0.001    0.000    0.001    0.001 dice_rolls.py:25(ways)
       1253/7    0.001    0.000    0.001    0.000 dice_rolls.py:30(<genexpr>)
            1    0.000    0.000    0.018    0.018 dice_rolls.py:36(main)
           21    0.000    0.000    0.000    0.000 rpc.py:153(debug)
            3    0.000    0.000    0.017    0.006 rpc.py:216(remotecall)
            3    0.000    0.000    0.000    0.000 rpc.py:226(asynccall)
            3    0.000    0.000    0.016    0.005 rpc.py:246(asyncreturn)
            3    0.000    0.000    0.000    0.000 rpc.py:252(decoderesponse)
            3    0.000    0.000    0.016    0.005 rpc.py:290(getresponse)
            3    0.000    0.000    0.000    0.000 rpc.py:298(_proxify)
            3    0.000    0.000    0.016    0.005 rpc.py:306(_getresponse)
            3    0.000    0.000    0.000    0.000 rpc.py:328(newseq)
            3    0.000    0.000    0.000    0.000 rpc.py:332(putmessage)
            2    0.000    0.000    0.001    0.000 rpc.py:559(__getattr__)
            3    0.000    0.000    0.000    0.000 rpc.py:57(dumps)
            1    0.000    0.000    0.001    0.001 rpc.py:577(__getmethods)
            2    0.000    0.000    0.000    0.000 rpc.py:601(__init__)
            2    0.000    0.000    0.016    0.008 rpc.py:606(__call__)
            4    0.000    0.000    0.000    0.000 run.py:412(encoding)
            4    0.000    0.000    0.000    0.000 run.py:416(errors)
            2    0.000    0.000    0.017    0.008 run.py:433(write)
            6    0.000    0.000    0.000    0.000 threading.py:1306(current_thread)
            3    0.000    0.000    0.000    0.000 threading.py:222(__init__)
            3    0.000    0.000    0.016    0.005 threading.py:270(wait)
            3    0.000    0.000    0.000    0.000 threading.py:81(RLock)
            3    0.000    0.000    0.000    0.000 {built-in method _struct.pack}
            3    0.000    0.000    0.000    0.000 {built-in method _thread.allocate_lock}
            6    0.000    0.000    0.000    0.000 {built-in method _thread.get_ident}
            1    0.000    0.000    0.018    0.018 {built-in method builtins.exec}
            6    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
            9    0.000    0.000    0.000    0.000 {built-in method builtins.len}
            1    0.000    0.000    0.017    0.017 {built-in method builtins.print}
        179/1    0.000    0.000    0.001    0.001 {built-in method builtins.sum}
            3    0.000    0.000    0.000    0.000 {built-in method select.select}
            3    0.000    0.000    0.000    0.000 {method '_acquire_restore' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method '_is_owned' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method '_release_save' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.RLock' objects}
            6    0.016    0.003    0.016    0.003 {method 'acquire' of '_thread.lock' objects}
            3    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
            2    0.000    0.000    0.000    0.000 {method 'decode' of 'bytes' objects}
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
            3    0.000    0.000    0.000    0.000 {method 'dump' of '_pickle.Pickler' objects}
            2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
          201    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
            3    0.000    0.000    0.000    0.000 {method 'getvalue' of '_io.BytesIO' objects}
            3    0.000    0.000    0.000    0.000 {method 'release' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method 'send' of '_socket.socket' objects}
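
A likely explanation for the extra calls in this report: the check `if dp.get((t,rd))` tests the cached value for truthiness, and a cached 0 is falsy. Every state with zero ways is therefore recomputed on each visit, while lru_cache returns a cached 0 like any other value. The sketch below counts calls under both checks. The names count_ways and use_membership_check are illustrative and not from the original code.

```python
def count_ways(dice, faces, target, use_membership_check):
    """Return (number of ways, number of calls to ways).

    use_membership_check=False mimics the original `if dp.get(key):` test,
    which treats a cached 0 as a cache miss.
    use_membership_check=True uses `key in memo`, which also hits on 0.
    """
    memo = {}
    calls = 0

    def ways(t, rd):
        nonlocal calls
        calls += 1
        if t == 0 and rd == 0:
            return 1
        if t <= 0 or rd <= 0:
            return 0
        key = (t, rd)
        if use_membership_check:
            hit = key in memo
        else:
            hit = bool(memo.get(key))  # a stored 0 looks like "not cached"
        if hit:
            return memo[key]
        memo[key] = sum(ways(t - i, rd - 1) for i in range(1, faces + 1))
        return memo[key]

    return ways(target, dice), calls


if __name__ == "__main__":
    for flag in (False, True):
        result, calls = count_ways(4, 6, 20, use_membership_check=flag)
        print(f"membership check={flag}: result={result}, calls={calls}")
```

Both variants return the same answer. Only the number of calls differs, and the gap grows quickly with more dice and faces because many (t, rd) states have zero ways.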
    
    

    Then I ran a modified, simplified version (numRollsToTarget in the code at the end) and compared the results:

    35
             387 function calls (193 primitive calls) in 0.006 seconds
    
       Ordered by: standard name
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    0.006    0.006 <string>:1(<module>)
            1    0.000    0.000    0.006    0.006 dice_rolls.py:36(main)
            1    0.000    0.000    0.000    0.000 dice_rolls.py:5(numRollsToTarget)
        195/1    0.000    0.000    0.000    0.000 dice_rolls.py:8(dp)
           21    0.000    0.000    0.000    0.000 rpc.py:153(debug)
            3    0.000    0.000    0.006    0.002 rpc.py:216(remotecall)
            3    0.000    0.000    0.000    0.000 rpc.py:226(asynccall)
            3    0.000    0.000    0.006    0.002 rpc.py:246(asyncreturn)
            3    0.000    0.000    0.000    0.000 rpc.py:252(decoderesponse)
            3    0.000    0.000    0.006    0.002 rpc.py:290(getresponse)
            3    0.000    0.000    0.000    0.000 rpc.py:298(_proxify)
            3    0.000    0.000    0.006    0.002 rpc.py:306(_getresponse)
            3    0.000    0.000    0.000    0.000 rpc.py:328(newseq)
            3    0.000    0.000    0.000    0.000 rpc.py:332(putmessage)
            2    0.000    0.000    0.001    0.000 rpc.py:559(__getattr__)
            3    0.000    0.000    0.000    0.000 rpc.py:57(dumps)
            1    0.000    0.000    0.001    0.001 rpc.py:577(__getmethods)
            2    0.000    0.000    0.000    0.000 rpc.py:601(__init__)
            2    0.000    0.000    0.005    0.003 rpc.py:606(__call__)
            4    0.000    0.000    0.000    0.000 run.py:412(encoding)
            4    0.000    0.000    0.000    0.000 run.py:416(errors)
            2    0.000    0.000    0.006    0.003 run.py:433(write)
            6    0.000    0.000    0.000    0.000 threading.py:1306(current_thread)
            3    0.000    0.000    0.000    0.000 threading.py:222(__init__)
            3    0.000    0.000    0.006    0.002 threading.py:270(wait)
            3    0.000    0.000    0.000    0.000 threading.py:81(RLock)
            3    0.000    0.000    0.000    0.000 {built-in method _struct.pack}
            3    0.000    0.000    0.000    0.000 {built-in method _thread.allocate_lock}
            6    0.000    0.000    0.000    0.000 {built-in method _thread.get_ident}
            1    0.000    0.000    0.006    0.006 {built-in method builtins.exec}
            6    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
            9    0.000    0.000    0.000    0.000 {built-in method builtins.len}
           34    0.000    0.000    0.000    0.000 {built-in method builtins.max}
            1    0.000    0.000    0.006    0.006 {built-in method builtins.print}
            3    0.000    0.000    0.000    0.000 {built-in method select.select}
            3    0.000    0.000    0.000    0.000 {method '_acquire_restore' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method '_is_owned' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method '_release_save' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.RLock' objects}
            6    0.006    0.001    0.006    0.001 {method 'acquire' of '_thread.lock' objects}
            3    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
            2    0.000    0.000    0.000    0.000 {method 'decode' of 'bytes' objects}
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
            3    0.000    0.000    0.000    0.000 {method 'dump' of '_pickle.Pickler' objects}
            2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
            2    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
            3    0.000    0.000    0.000    0.000 {method 'getvalue' of '_io.BytesIO' objects}
            3    0.000    0.000    0.000    0.000 {method 'release' of '_thread.RLock' objects}
            3    0.000    0.000    0.000    0.000 {method 'send' of '_socket.socket' objects}
    
    

    The profiling code is here:

    import cProfile
    
    def numRollsToTarget(d, f, target):
        memo = {}
    
        def dp(d, target):
            if d == 0:
                return 0 if target > 0 else 1
            if (d, target) in memo:
                return memo[(d, target)]
    
            result = 0
            
            for k in range(max(0, target-f), target):
                result += dp(d-1, k)
            memo[(d, target)] = result
            return result 
        
        return dp(d, target) % (10**9 + 7)
        
    def numRollsToTarget2(dices: int, faces: int, target: int) -> int:
        dp = {}
        def ways(t, rd):
            if t == 0  and rd == 0: return 1
            if t <= 0 or rd <= 0: return 0
            if dp.get((t,rd)): return dp[(t,rd)]
            
            dp[(t,rd)] = sum(ways(t-i, rd-1) for i in range(1,faces+1))
            return dp[(t,rd)]
            
        return ways(target, dices)
    
    def numRollsToTarget3(dices: int, faces: int, target: int) -> int:
        from functools import lru_cache
        @lru_cache(None)
        def ways(t, rd):
            if t == 0  and rd == 0: return 1
            if t <= 0 or rd <= 0: return 0
            return sum(ways(t-i, rd-1) for i in range(1,faces+1))
            
        return ways(target, dices)
    
    def main():
        print(numRollsToTarget(4, 6, 20))
        #print(numRollsToTarget2(4, 6, 20))
        #print(numRollsToTarget3(4, 6, 20))  # not faster than first
    
    
    
    if __name__ == '__main__':
        cProfile.run('main()')
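
Most rows in the reports above come from the IDLE shell (rpc.py, threading.py) rather than from the solution. As a sketch, the standard pstats module can limit the output to the interesting rows. The helper name profile_report and the regex pattern are assumptions chosen for illustration, and num_rolls mirrors the lru_cache version from the question.

```python
import cProfile
import io
import pstats
from functools import lru_cache


def num_rolls(dice, faces, target):
    # Same recursion as the lru_cache version in the question.
    @lru_cache(maxsize=None)
    def ways(t, rd):
        if t == 0 and rd == 0:
            return 1
        if t <= 0 or rd <= 0:
            return 0
        return sum(ways(t - i, rd - 1) for i in range(1, faces + 1))

    return ways(target, dice)


def profile_report(func, *args, pattern=r"ways|genexpr|sum"):
    """Profile func(*args) and return (result, filtered report text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # A string restriction is treated as a regex matched against the
    # "filename:lineno(function)" column of each row.
    stats.sort_stats("cumulative").print_stats(pattern)
    return result, stream.getvalue()


if __name__ == "__main__":
    result, report = profile_report(num_rolls, 4, 6, 20)
    print(result)
    print(report)
```

Sorting by cumulative time and filtering by name keeps the ncalls figures for ways, the generator expression, and sum next to each other, which makes the versions easier to compare.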