This is my MWE:
from numba import njit
import numpy as np
@njit
def solve(n):
    """Count, by backtracking, the length-``4 * n`` sequences ``search`` accepts.

    Minimal example as posted. With ``@njit`` on ``solve`` this fails to
    compile with ``NotImplementedError: ... Unsupported use of
    op_LOAD_CLOSURE``. The reason is that ``search`` below is a recursive
    closure (over ``n``, ``count``, ``res`` and itself) that would have to
    be created inside nopython-mode code. Without ``@njit`` it runs as
    plain Python.
    """
    # count[0] tallies 0s placed; count[i] tallies uses of value i (kept < 2
    # before each new use by the guard in the loop below).
    count = np.zeros(n + 1, dtype=int)
    # Single-element array so the nested function can increment it in place.
    res = np.array([0], dtype=int)

    def search(sz=0, max_val=1, single=0, previous=None):
        # NOTE(review): ``nonlocal`` is not needed; ``res`` is only mutated
        # in place (``res[0] += 1``), never rebound.
        nonlocal res
        if sz == 4 * n:  # sequence of full length reached
            res[0] += 1
            return
        # A 0 may be placed while ``single`` is non-zero and fewer than
        # 2 * n zeros are down; the recursive call leaves ``previous`` at
        # its default None, so any value may follow a 0.
        if single and count[0] < 2 * n:
            count[0] += 1
            search(sz + 1, max_val, single)
            count[0] -= 1
        for i in range(1, max_val + 1):
            # Value i must differ from the value just placed and have been
            # used fewer than twice.
            if i != previous and count[i] < 2:
                count[i] += 1
                # max_val grows by one (up to n) when the current largest
                # value is used; single rises on a first use of i and falls
                # on a second use.
                search(sz + 1, max_val + (i == max_val and max_val < n), single + (count[i] == 1) - (count[i] == 2), i)
                count[i] -= 1

    search()
    return res[0]
# Print the count for n = 1 through 5.
for size in range(1, 6):
    total = solve(size)
    print(total)
This gives:
NotImplementedError: Failed in nopython mode pipeline (step: analyzing bytecode)
Unsupported use of op_LOAD_CLOSURE encountered
What's the right way to get this to work with numba? The code runs correctly, if slowly, if you remove the @njit line.
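I'd put `@njit` on the inner function instead, and pass `count` and `res` to it as arguments. `solve` itself then stays ordinary Python, so numba never has to build the closure in nopython mode; it only compiles `search`: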
import numpy as np
from numba import njit
def solve(n):
    """Count the length-``4 * n`` sequences accepted by a backtracking search.

    Positions are filled one at a time with values from ``0..n``:

    * ``0`` may be placed while some value is "open" (used exactly once,
      tracked by ``single``) and fewer than ``2 * n`` zeros are placed.
    * A value ``i`` in ``1..max_val`` may be placed if it has been used
      fewer than twice and differs from the value placed immediately before
      it (placing a ``0`` clears that restriction).
    * ``max_val`` grows by one, up to ``n``, whenever the current largest
      value is used, so new values appear in increasing order.

    Only the recursive ``search`` helper is compiled with numba. ``solve``
    itself is ordinary Python, so the closure is created by the interpreter
    rather than inside nopython-mode code.

    Parameters
    ----------
    n : int
        Largest value that may appear; sequences have length ``4 * n``.

    Returns
    -------
    int
        Number of complete sequences found.
    """

    # NOTE: the decorator runs on every call to ``solve``, so a fresh numba
    # dispatcher is created and ``search`` is JIT-compiled again for each
    # ``n``. That compile time is included in any timing of ``solve``.
    @njit
    def search(count, res, sz=0, max_val=1, single=0, previous=0):
        # ``previous == 0`` means "no restriction": 0 is never a candidate
        # in 1..max_val. An int sentinel (instead of None) keeps the
        # argument a plain integer in every recursive call.
        if sz == 4 * n:
            res[0] += 1
            return
        # Place a 0 while some value is still open and zeros remain.
        if single > 0 and count[0] < 2 * n:
            count[0] += 1
            search(count, res, sz + 1, max_val, single, 0)
            count[0] -= 1
        for i in range(1, max_val + 1):
            if i != previous and count[i] < 2:
                count[i] += 1
                # Using the current largest value unlocks the next one.
                if i == max_val and max_val < n:
                    next_max = max_val + 1
                else:
                    next_max = max_val
                # First use of i opens it; second use closes it.
                if count[i] == 1:
                    next_single = single + 1
                else:
                    next_single = single - 1
                search(count, res, sz + 1, next_max, next_single, i)
                count[i] -= 1

    # Explicit int64: ``dtype=int`` is only 32 bits on Windows with
    # NumPy < 2, and the number of sequences grows rapidly with n.
    count = np.zeros(n + 1, dtype=np.int64)
    res = np.zeros(1, dtype=np.int64)
    search(count, res)
    return int(res[0])
# Print the count for n = 1 through 5.
for size in range(1, 6):
    total = solve(size)
    print(total)
Running this script with the `time` command:
andrej@MyPC:/app$ time python3 script.py
1
28
1816
180143
23783809
real 0m3,818s
user 0m0,015s
sys 0m0,004s
For comparison, without the `@njit` decorator:
andrej@MyPC:/app$ time python3 script.py
1
28
1816
180143
23783809
real 1m42,000s
user 0m0,011s
sys 0m0,005s