python, python-3.x, numba

Cannot call stored, jit-compiled functions inside another jit-compiled function


  • Python 3.11.4
  • numba 0.57.1

I am building a "compiler" class that generates numba njit-ed functions from existing code:

import numba
from typing import Callable

class Foo:

    def __init__(self) -> None:
        
        self.compiled: dict[str, Callable] = dict()

        self.compile_bar()

    
    def compile_bar(self):

        @numba.njit
        def func1(a, b):
            return a * b

        @numba.njit
        def func2(a, b):
            return a / b

        l_c: list[Callable] = [func1, func2]

        @numba.njit
        def wrapper(a, b):
            for func in l_c:
                print(func(a, b))

        self.compiled.update({'bar': wrapper})

F = Foo()
print(F.compiled['bar'](1.,2.))

The last line throws an error:

Exception has occurred: NumbaNotImplementedError
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7f996cd00c50>, (List(type(CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>)), True),)
During: lowering "$6load_deref.0 = freevar(l_c: [CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>), CPUDispatcher(<function Foo.compile_bar.<locals>.func2 at 0x7f99264c2de0>)])" at /workdir/porepy/scripts_vl/playground.py (33)
  File "/workdir/scripts/playground.py", line 40, in <module>
    print(F.compiled['bar'](1.,2.))
          ^^^^^^^^^^^^^^^^^^^^^^^^
numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7f996cd00c50>, (List(type(CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>)), True),)
During: lowering "$6load_deref.0 = freevar(l_c: [CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>), CPUDispatcher(<function Foo.compile_bar.<locals>.func2 at 0x7f99264c2de0>)])" at /workdir/scripts/playground.py (33)

func1 and func2 are only examples to keep the code compact. In general, l_c is a list of njit-ed functions created from some input arguments passed to Foo.

I tried making a typed list:

l_c = numba.typed.List.empty_list(numba.float64(numba.float64, numba.float64).as_type())
l_c.append(func1)
l_c.append(func2)

But the outcome did not change.

I think the problem is accessing the functions through the list, because the error only occurs there.

Accessing functions which are not in a list seems to work fine.

The error message is quite cryptic to me.

Does anybody have a solution for this and an explanation?


Solution

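  • There are two different problems. First, numba lists must be homogeneous (every element has the same type), but each jitted function has its own type, so lists don't play well with function objects. Use a tuple instead, since tuples may hold elements of different types: l_c: tuple[Callable, ...] = (func1, func2)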

    The second problem is that numba cannot infer the types of the functions before iterating over them. To fix that, you can either provide an explicit signature for at least one of the functions:

    
            @numba.njit("double(double, double)")
            def func1(a, b):
                return a * b
    

    Or use literal_unroll:

            from numba import literal_unroll

            @numba.njit
            def wrapper(a, b):
                for func in literal_unroll(l_c):
                    print(func(a, b))
    

    All of these options produce a lot of warnings about first-class function types being an experimental feature; however, in this case it seems to work. If you want to build the collection of functions dynamically, you need to do that before wrapper is compiled, because numba treats captured variables as compile-time constants. Just build it up as a list and then convert it to a tuple.

    To silence the warnings, add warnings.simplefilter("ignore", NumbaExperimentalFeatureWarning) somewhere in your code (NumbaExperimentalFeatureWarning can be imported from numba.core.errors).