
How to make a function switch with numba


I have 3 jitted functions, a(x), b(x) and c(x).

I need a switch function that does this:

import numba as nb

# a, b and c are the jitted functions defined elsewhere
@nb.njit
def switch(i, x):
    if i == 0:
        return a(x)
    elif i == 1:
        return b(x)
    elif i == 2:
        return c(x)

But I would like to write it more concisely, without losing performance, for example:

functions = (a, b, c)

@nb.njit
def switch(i, x):
    return functions[i](x)

However, nopython jit compilation can't handle tuples of functions. How can I do this?


Solution

  • As long as the functions share the same type signature, it should be possible; see the example at:
    https://numba.discourse.group/t/typed-list-of-jitted-functions-in-jitclass/413/4

    For example, with three jitted functions that all have the signature float64(float64, float64):

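    from numba import njit
    from numba.typed import List
    from numba.types import float64
    
    # Three jitted functions sharing the signature float64(float64, float64)
    a = njit()(lambda x, y: x - y)
    b = njit()(lambda x, y: x + y)
    c = njit()(lambda x, y: x / y)
    
    # First-class function type for that shared signature
    function_type = float64(float64, float64).as_type()
    
    # A typed list holds items of a single type, hence the shared signature
    funcs = List.empty_list(function_type)
    funcs.append(a)
    funcs.append(b)
    funcs.append(c)
    
    @njit
    def switch(i, funcs, *args):
        # Pick the i-th function at runtime and forward the arguments to it
        return funcs[i](*args)
    
    switch(2, funcs, 1.0, 2.0)
    # 0.5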