Search code examples
Tags: python, numba

How do I fix numba's "Unsupported use of op_LOAD_CLOSURE encountered" error?


This is my MWE:

from numba import njit
import numpy as np

@njit   
def solve(n):
    """Count the length-``4 * n`` sequences enumerated by the nested ``search``.

    Reproducer for the question: under ``@njit`` this fails during bytecode
    analysis with "Unsupported use of op_LOAD_CLOSURE encountered". Numba
    appears to reject defining the recursive closure ``search`` (which captures
    ``count``, ``res``, ``n`` and itself) inside the compiled function. Without
    ``@njit`` it runs in plain Python.
    """
    # count[v]: how many times value v has been placed so far (count[0] for 0s).
    count = np.zeros(n + 1, dtype=int)
    # One-element array so the nested function can update the total in place.
    res = np.array([0], dtype=int)

    def search(sz=0, max_val=1, single=0, previous=None):
        """Depth-first placement of values; bump res[0] per completed sequence.

        sz: values placed so far; max_val: largest nonzero value allowed next;
        single: number of nonzero values placed exactly once; previous: last
        nonzero value placed, or None at the start / right after a 0.
        """
        nonlocal res  # redundant: res is only mutated (res[0] += 1), never rebound
        if sz == 4 * n:
            res[0] += 1
            return
        # A 0 may be placed (at most 2*n times) only while some value is used
        # exactly once; the recursive call resets previous to None.
        if single and count[0] < 2 * n:
            count[0] += 1
            search(sz + 1, max_val, single)
            count[0] -= 1
        # Values 1..max_val, each at most twice, never equal to the previous one.
        for i in range(1, max_val + 1):
            if i != previous and count[i] < 2:
                count[i] += 1
                # max_val grows by one when the current maximum is used (capped
                # at n); single tracks values used once vs. twice.
                search(sz + 1, max_val + (i == max_val and max_val < n), single + (count[i] == 1) - (count[i] == 2), i)
                count[i] -= 1
 
    search()
    return res[0]

# Print the count for n = 1..5.
for i in range(1, 6):
    print(solve(i))

This gives:

NotImplementedError: Failed in nopython mode pipeline (step: analyzing bytecode)
Unsupported use of op_LOAD_CLOSURE encountered

What's the right way to get this to work with numba? The code runs correctly, if slowly, if you remove the @njit line.


Solution

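  • Apply `@njit` to the recursive search function instead of to `solve`, and pass `count` and `res` to it as arguments rather than capturing them in a closure: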

    import numpy as np
    from numba import njit
    
    
    @njit
    def _search(n, count, res, sz, max_val, single, previous):
        """Recursively place values; add one to ``res[0]`` per length-``4*n`` sequence.

        Defined at module level (not inside ``solve``) so numba compiles it once
        and reuses the machine code for every ``n``. A function created inside
        ``solve`` would be recompiled on each call, with ``n`` frozen in.

        Parameters
        ----------
        n : int
            Largest value that may be placed.
        count : int64 array of length ``n + 1``
            ``count[v]`` is how many times value ``v`` has been placed so far.
        res : int64 array of length 1
            Accumulator; ``res[0]`` is incremented for each completed sequence.
        sz : int
            Number of values placed so far.
        max_val : int
            Largest nonzero value currently allowed.
        single : int
            Number of nonzero values placed exactly once.
        previous : int
            Last nonzero value placed, or 0 when there is none (at the start or
            right after a 0). Tried values are always >= 1, so 0 never blocks a
            choice. This replaces a ``None`` default and keeps the argument a
            plain int for numba.
        """
        if sz == 4 * n:
            res[0] += 1
            return
        # A 0 may be placed (at most 2*n times) only while some value is used
        # exactly once; after a 0 there is no "previous" restriction.
        if single and count[0] < 2 * n:
            count[0] += 1
            _search(n, count, res, sz + 1, max_val, single, 0)
            count[0] -= 1
        # Values 1..max_val, each at most twice, never equal to the previous one.
        for i in range(1, max_val + 1):
            if i != previous and count[i] < 2:
                count[i] += 1
                _search(
                    n,
                    count,
                    res,
                    sz + 1,
                    # Using the current maximum unlocks the next value (up to n).
                    max_val + (i == max_val and max_val < n),
                    single + (count[i] == 1) - (count[i] == 2),
                    i,
                )
                count[i] -= 1


    def solve(n):
        """Return the number of sequences enumerated by ``_search`` for ``n``."""
        # Explicit int64: ``dtype=int`` is 32-bit on Windows with NumPy < 2, and
        # the totals grow quickly (23,783,809 already at n=5).
        count = np.zeros(n + 1, dtype=np.int64)
        res = np.zeros(1, dtype=np.int64)
        _search(n, count, res, 0, 1, 0, 0)
        return res[0]


    if __name__ == "__main__":
        for n in range(1, 6):
            print(solve(n))
    

    Running the script with the `time` command:

    andrej@MyPC:/app$ time python3 script.py
    1
    28
    1816
    180143
    23783809
    
    real    0m3,818s
    user    0m0,015s
    sys     0m0,004s
    

    For comparison, without `@njit`:

    andrej@MyPC:/app$ time python3 script.py
    1
    28
    1816
    180143
    23783809
    
    real    1m42,000s
    user    0m0,011s
    sys     0m0,005s