
When does numba just-in-time compile if recursion takes place?


When I have a function that calls itself recursively without changing the argument types, but works on a large object (which gets passed around as an argument), how does this work in numba behind the scenes?

Is my function (which is small but called a lot) just-in-time compiled the first time it is called, even though that first call has not finished yet (because of the recursion)? Or does the compilation only finish once the first call to the function has returned?

For example, in

@njit
def myfct(large_object):
    a, tail_condition = do_things_1(large_object)
    if tail_condition:
        return a
    intermediate_result = myfct(large_object)
    b = do_things_2(a, intermediate_result)
    return b

final_result = myfct(ref_to_large_object)

when is myfct compiled? Is it already compiled before it gets called the second time (the recursive call on line 6), or is it only compiled when I get final_result, after everything is already done anyway? If the latter is the case, how can I avoid that?


Solution

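  • Functions are compiled once per signature. If you don't provide an explicit signature, they are compiled the first time each signature is used. Compilation finishes before the function body starts executing, so by the time the recursion begins, the recursive calls already run the compiled code: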

    import numba as nb
    
    @nb.njit
    def f(n):
        return n * f(n-1) if n > 1 else 1
    
    f(4)           # Compiled here (int64), before the body runs
    f(5)           # Already compiled
    f(5.1)         # Compiled again (float64)
    

    If you provide an explicit signature, functions are compiled eagerly, when the decorator runs at definition time:

    @nb.njit([nb.int32(nb.int32)])
    def f(n):
        return n * f(n-1) if n > 1 else 1
    
    f(5)           # Already compiled
    f(5.1)         # No float64 version is compiled: the call can only be matched
                   # to the existing int32 signature (an error or an unsafe cast).
    

    Providing more than one explicit signature compiles all of them up front:

    @nb.njit([nb.int32(nb.int32), nb.float64(nb.float64)])
    def f(n):
        return n * f(n-1) if n > 1 else 1
    
    f(5)           # Already compiled
    f(5.1)         # Already compiled
    

    You can check this yourself by placing a breakpoint in numba.core.dispatcher.Dispatcher.compile().