I was solving leetcode 1155 which is about number of dice rolls with target sum. I was using dictionary-based memorization. Here's the exact code:
class Solution:
def numRollsToTarget(self, dices: int, faces: int, target: int) -> int:
dp = {}
def ways(t, rd):
if t == 0 and rd == 0: return 1
if t <= 0 or rd <= 0: return 0
if dp.get((t,rd)): return dp[(t,rd)]
dp[(t,rd)] = sum(ways(t-i, rd-1) for i in range(1,faces+1))
return dp[(t,rd)]
return ways(target, dices)
But this solution is invariably timing out for a combination of face and dices around 15*15
Then I found this solution which uses functools.lru_cache and the rest of it is exactly the same. This solution works very fast.
class Solution:
def numRollsToTarget(self, dices: int, faces: int, target: int) -> int:
from functools import lru_cache
def ways(t, rd):
if t == 0 and rd == 0: return 1
if t <= 0 or rd <= 0: return 0
return sum(ways(t-i, rd-1) for i in range(1,faces+1))
return ways(target, dices)
Earlier, I have compared and found that in most cases, lru_cache does not outperform dictionary-based cache by such a margin.
Can someone explain the reason why there is such a drastic performance difference between the two approaches?
First, running your OP code with cProfile
and this is the report:
You can spot right away there're some heavy calls in ways
and sum
. That's prob. need close examinations and try to improve/reduce. Next posting is for similar memo
version, but the calls
is much less. And that version has passed w/o timeout.
2864 function calls (366 primitive calls) in 0.018 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.018 0.018 <string>:1(<module>)
1 0.000 0.000 0.001 0.001 dice_rolls.py:23(numRollsToTarget2)
1075/1 0.001 0.000 0.001 0.001 dice_rolls.py:25(ways)
1253/7 0.001 0.000 0.001 0.000 dice_rolls.py:30(<genexpr>)
1 0.000 0.000 0.018 0.018 dice_rolls.py:36(main)
21 0.000 0.000 0.000 0.000 rpc.py:153(debug)
3 0.000 0.000 0.017 0.006 rpc.py:216(remotecall)
3 0.000 0.000 0.000 0.000 rpc.py:226(asynccall)
3 0.000 0.000 0.016 0.005 rpc.py:246(asyncreturn)
3 0.000 0.000 0.000 0.000 rpc.py:252(decoderesponse)
3 0.000 0.000 0.016 0.005 rpc.py:290(getresponse)
3 0.000 0.000 0.000 0.000 rpc.py:298(_proxify)
3 0.000 0.000 0.016 0.005 rpc.py:306(_getresponse)
3 0.000 0.000 0.000 0.000 rpc.py:328(newseq)
3 0.000 0.000 0.000 0.000 rpc.py:332(putmessage)
2 0.000 0.000 0.001 0.000 rpc.py:559(__getattr__)
3 0.000 0.000 0.000 0.000 rpc.py:57(dumps)
1 0.000 0.000 0.001 0.001 rpc.py:577(__getmethods)
2 0.000 0.000 0.000 0.000 rpc.py:601(__init__)
2 0.000 0.000 0.016 0.008 rpc.py:606(__call__)
4 0.000 0.000 0.000 0.000 run.py:412(encoding)
4 0.000 0.000 0.000 0.000 run.py:416(errors)
2 0.000 0.000 0.017 0.008 run.py:433(write)
6 0.000 0.000 0.000 0.000 threading.py:1306(current_thread)
3 0.000 0.000 0.000 0.000 threading.py:222(__init__)
3 0.000 0.000 0.016 0.005 threading.py:270(wait)
3 0.000 0.000 0.000 0.000 threading.py:81(RLock)
3 0.000 0.000 0.000 0.000 {built-in method _struct.pack}
3 0.000 0.000 0.000 0.000 {built-in method _thread.allocate_lock}
6 0.000 0.000 0.000 0.000 {built-in method _thread.get_ident}
1 0.000 0.000 0.018 0.018 {built-in method builtins.exec}
6 0.000 0.000 0.000 0.000 {built-in method builtins.isinstance}
9 0.000 0.000 0.000 0.000 {built-in method builtins.len}
1 0.000 0.000 0.017 0.017 {built-in method builtins.print}
179/1 0.000 0.000 0.001 0.001 {built-in method builtins.sum}
3 0.000 0.000 0.000 0.000 {built-in method select.select}
3 0.000 0.000 0.000 0.000 {method '_acquire_restore' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method '_is_owned' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method '_release_save' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method 'acquire' of '_thread.RLock' objects}
6 0.016 0.003 0.016 0.003 {method 'acquire' of '_thread.lock' objects}
3 0.000 0.000 0.000 0.000 {method 'append' of 'collections.deque' objects}
2 0.000 0.000 0.000 0.000 {method 'decode' of 'bytes' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
3 0.000 0.000 0.000 0.000 {method 'dump' of '_pickle.Pickler' objects}
2 0.000 0.000 0.000 0.000 {method 'encode' of 'str' objects}
201 0.000 0.000 0.000 0.000 {method 'get' of 'dict' objects}
3 0.000 0.000 0.000 0.000 {method 'getvalue' of '_io.BytesIO' objects}
3 0.000 0.000 0.000 0.000 {method 'release' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method 'send' of '_socket.socket' objects}
Then I tried to run modified/simplified version, and compare the results.
387 function calls (193 primitive calls) in 0.006 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.006 0.006 <string>:1(<module>)
1 0.000 0.000 0.006 0.006 dice_rolls.py:36(main)
1 0.000 0.000 0.000 0.000 dice_rolls.py:5(numRollsToTarget)
195/1 0.000 0.000 0.000 0.000 dice_rolls.py:8(dp)
21 0.000 0.000 0.000 0.000 rpc.py:153(debug)
3 0.000 0.000 0.006 0.002 rpc.py:216(remotecall)
3 0.000 0.000 0.000 0.000 rpc.py:226(asynccall)
3 0.000 0.000 0.006 0.002 rpc.py:246(asyncreturn)
3 0.000 0.000 0.000 0.000 rpc.py:252(decoderesponse)
3 0.000 0.000 0.006 0.002 rpc.py:290(getresponse)
3 0.000 0.000 0.000 0.000 rpc.py:298(_proxify)
3 0.000 0.000 0.006 0.002 rpc.py:306(_getresponse)
3 0.000 0.000 0.000 0.000 rpc.py:328(newseq)
3 0.000 0.000 0.000 0.000 rpc.py:332(putmessage)
2 0.000 0.000 0.001 0.000 rpc.py:559(__getattr__)
3 0.000 0.000 0.000 0.000 rpc.py:57(dumps)
1 0.000 0.000 0.001 0.001 rpc.py:577(__getmethods)
2 0.000 0.000 0.000 0.000 rpc.py:601(__init__)
2 0.000 0.000 0.005 0.003 rpc.py:606(__call__)
4 0.000 0.000 0.000 0.000 run.py:412(encoding)
4 0.000 0.000 0.000 0.000 run.py:416(errors)
2 0.000 0.000 0.006 0.003 run.py:433(write)
6 0.000 0.000 0.000 0.000 threading.py:1306(current_thread)
3 0.000 0.000 0.000 0.000 threading.py:222(__init__)
3 0.000 0.000 0.006 0.002 threading.py:270(wait)
3 0.000 0.000 0.000 0.000 threading.py:81(RLock)
3 0.000 0.000 0.000 0.000 {built-in method _struct.pack}
3 0.000 0.000 0.000 0.000 {built-in method _thread.allocate_lock}
6 0.000 0.000 0.000 0.000 {built-in method _thread.get_ident}
1 0.000 0.000 0.006 0.006 {built-in method builtins.exec}
6 0.000 0.000 0.000 0.000 {built-in method builtins.isinstance}
9 0.000 0.000 0.000 0.000 {built-in method builtins.len}
34 0.000 0.000 0.000 0.000 {built-in method builtins.max}
1 0.000 0.000 0.006 0.006 {built-in method builtins.print}
3 0.000 0.000 0.000 0.000 {built-in method select.select}
3 0.000 0.000 0.000 0.000 {method '_acquire_restore' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method '_is_owned' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method '_release_save' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method 'acquire' of '_thread.RLock' objects}
6 0.006 0.001 0.006 0.001 {method 'acquire' of '_thread.lock' objects}
3 0.000 0.000 0.000 0.000 {method 'append' of 'collections.deque' objects}
2 0.000 0.000 0.000 0.000 {method 'decode' of 'bytes' objects}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
3 0.000 0.000 0.000 0.000 {method 'dump' of '_pickle.Pickler' objects}
2 0.000 0.000 0.000 0.000 {method 'encode' of 'str' objects}
2 0.000 0.000 0.000 0.000 {method 'get' of 'dict' objects}
3 0.000 0.000 0.000 0.000 {method 'getvalue' of '_io.BytesIO' objects}
3 0.000 0.000 0.000 0.000 {method 'release' of '_thread.RLock' objects}
3 0.000 0.000 0.000 0.000 {method 'send' of '_socket.socket' objects}
The profiling codes are here:
import cProfile
from typing import List
def numRollsToTarget(d, f, target):
memo = {}
def dp(d, target):
if d == 0:
return 0 if target > 0 else 1
if (d, target) in memo:
return memo[(d, target)]
result = 0
for k in range(max(0, target-f), target):
result += dp(d-1, k)
memo[(d, target)] = result
return result
return dp(d, target) % (10**9 + 7)
def numRollsToTarget2(dices: int, faces: int, target: int) -> int:
dp = {}
def ways(t, rd):
if t == 0 and rd == 0: return 1
if t <= 0 or rd <= 0: return 0
if dp.get((t,rd)): return dp[(t,rd)]
dp[(t,rd)] = sum(ways(t-i, rd-1) for i in range(1,faces+1))
return dp[(t,rd)]
return ways(target, dices)
def numRollsToTarget3(dices: int, faces: int, target: int) -> int:
from functools import lru_cache
def ways(t, rd):
if t == 0 and rd == 0: return 1
if t <= 0 or rd <= 0: return 0
return sum(ways(t-i, rd-1) for i in range(1,faces+1))
return ways(target, dices)
def main():
print(numRollsToTarget(4, 6, 20))
#print(numRollsToTarget2(4, 6, 20))
#print(numRollsToTarget3(4, 6, 20)) # not faster than first
if __name__ == '__main__':