I am building a "compiler" class that generates Numba njit-ed functions from existing code:
```python
import numba
from typing import Callable


class Foo:
    def __init__(self) -> None:
        self.compiled: dict[str, Callable] = dict()
        self.compile_bar()

    def compile_bar(self):
        @numba.njit
        def func1(a, b):
            return a * b

        @numba.njit
        def func2(a, b):
            return a / b

        l_c: list[Callable] = [func1, func2]

        @numba.njit
        def wrapper(a, b):
            for func in l_c:
                print(func(a, b))

        self.compiled.update({'bar': wrapper})


F = Foo()
print(F.compiled['bar'](1.,2.))
```
The last line throws an error:

```
Exception has occurred: NumbaNotImplementedError
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7f996cd00c50>, (List(type(CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>)), True),)
During: lowering "$6load_deref.0 = freevar(l_c: [CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>), CPUDispatcher(<function Foo.compile_bar.<locals>.func2 at 0x7f99264c2de0>)])" at /workdir/porepy/scripts_vl/playground.py (33)
  File "/workdir/scripts/playground.py", line 40, in <module>
    print(F.compiled['bar'](1.,2.))
          ^^^^^^^^^^^^^^^^^^^^^^^^
numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7f996cd00c50>, (List(type(CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>)), True),)
During: lowering "$6load_deref.0 = freevar(l_c: [CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>), CPUDispatcher(<function Foo.compile_bar.<locals>.func2 at 0x7f99264c2de0>)])" at /workdir/scripts/playground.py (33)
```
Now, `func1` and `func2` are only examples to keep the code compact. In general, `l_c` is a list of njit-ed functions created from some input arguments of `Foo`.
I tried making a typed list:

```python
l_c = numba.typed.List.empty_list(numba.float64(numba.float64, numba.float64).as_type())
l_c.append(func1)
l_c.append(func2)
```
But the outcome did not change.
I think the problem is accessing the functions through the list, because the error only occurs there.
Accessing functions which are not in a list seems to work fine.
The error message is quite cryptic to me.
Does anybody have a solution for this and an explanation?
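There are two different problems. First, Numba treats lists as homogeneous (every element must have the same type), which doesn't play well with function objects: each njit-ed function is its own Numba type. Use a tuple instead, since a tuple may hold elements of different types:

```python
l_c: tuple[Callable, ...] = (func1, func2)
```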
The second problem is that Numba can't infer a common type for the functions before iterating over them, and a plain `for` loop over a tuple needs all elements to share one type. To fix that, you can either provide an explicit signature for at least one of the functions:

```python
@numba.njit("double(double, double)")
def func1(a, b):
    return a * b
```
Or use `literal_unroll`:

```python
from numba import literal_unroll

@numba.njit
def wrapper(a, b):
    for func in literal_unroll(l_c):
        print(func(a, b))
```
All of these options produce a lot of warnings about first-class function types being an experimental feature; however, in this case it seems to work. If you want to build the list of functions dynamically, you need to do that before `wrapper` is compiled, because Numba freezes closure variables at compile time. Just build it up as a list and then convert it to a tuple.
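To silence the warnings, add the following somewhere in your code before the functions are compiled:

```python
import warnings
from numba.core.errors import NumbaExperimentalFeatureWarning

warnings.simplefilter("ignore", NumbaExperimentalFeatureWarning)
```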