I have 3 jitted functions, `a(x)`, `b(x)` and `c(x)`.
I need a switch function that does this:
```python
import numba as nb

# a, b and c are the three @nb.njit functions defined elsewhere
@nb.njit
def switch(i, x):
    if i == 0:
        return a(x)
    elif i == 1:
        return b(x)
    elif i == 2:
        return c(x)
```
But I would like to write it in a more concise way without performance loss, such as:
```python
functions = (a, b, c)

@nb.njit
def switch(i, x):
    return functions[i](x)
```
However, nopython jit compilation can't handle tuples of functions. How can I do this?
As long as the functions share the same type signature, it should be possible; see the example at:
https://numba.discourse.group/t/typed-list-of-jitted-functions-in-jitclass/413/4
So an example would be:
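```python
from numba import njit
from numba.typed import List
from numba.types import float64

a = njit()(lambda x, y: x - y)
b = njit()(lambda x, y: x + y)
c = njit()(lambda x, y: x / y)

# First-class function type shared by all three: float64(float64, float64)
function_type = float64(float64, float64).as_type()

# The typed List has a single element type (function_type),
# so every function stored in it must share that signature
funcs = List.empty_list(function_type)
funcs.append(a)
funcs.append(b)
funcs.append(c)

@njit
def switch(i, funcs, *args):
    # Pick the i-th function at runtime and forward the arguments to it
    return funcs[i](*args)

switch(2, funcs, 1, 2)
# 0.5
```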