How to write a Numba function used both in CPU mode and in CUDA device mode?


I want to write a Numba function used both in CPU mode and in CUDA device mode. Of course, I can write two identical functions with and without the cuda.jit decorator. For example:

from numba import cuda, njit

@njit("i4(i4, i4)")
def func_cpu(a, b):
    return a + b

@cuda.jit("i4(i4, i4)", device=True)
def func_gpu(a, b):
    return a + b

But duplicating the function body like this is poor software engineering. Is there a more elegant way, i.e., writing the code once and using it for both targets?


Solution

  • A decorator is essentially a function that takes a function as input and returns an (often modified) function as output. Passing arguments and keyword arguments to the decorator, as is done with Numba, makes it slightly more complicated internally, but you can think of it as a nested function in which the outer call returns the actual decorator.

    So instead of using it as a decorator like you do now (with the @), you can call it like any other function and capture the output, which is itself a callable function.

    This allows you to write your function in pure Python and then apply as many "decorators" to it as you'd like. For example:

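    import numpy as np
    from numba import cuda, njit

    def func_py(a, b):
        return a + b

    func_njit = njit("i4(i4, i4)")(func_py)
    func_gpu = cuda.jit("i4(i4, i4)", device=True)(func_py)

    # The CPU version can be called directly from host code.
    assert func_py(4, 3) == func_njit(4, 3)

    # A device function is only callable from GPU code, so test it via a kernel.
    @cuda.jit
    def kernel(out, a, b):
        out[0] = func_gpu(a, b)

    out = np.zeros(1, dtype=np.int32)
    kernel[1, 1](out, np.int32(4), np.int32(3))
    assert func_py(4, 3) == out[0]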
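
The "decorator with arguments is a nested function" idea described above can be illustrated without Numba. The following is a minimal pure-Python sketch; `scaled` is a hypothetical decorator factory invented for this illustration and is not part of Numba. It only shows that `@factory(args)` and `factory(args)(func)` are the same thing, and that the manual form leaves the plain function free to be wrapped more than once.

```python
import functools


def scaled(factor):
    """Hypothetical decorator factory: the outer call returns the decorator."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Multiply whatever the wrapped function returns by `factor`.
            return factor * func(*args, **kwargs)
        return wrapper
    return decorator


def add(a, b):
    return a + b


# Decorator syntax: the name is bound to the wrapped function only.
@scaled(2)
def add_doubled(a, b):
    return a + b


# Manual form: call the factory, then call the returned decorator.
# The undecorated `add` stays available and can be wrapped several times.
add_x2 = scaled(2)(add)
add_x10 = scaled(10)(add)

print(add(4, 3), add_doubled(4, 3), add_x2(4, 3), add_x10(4, 3))  # 7 14 14 70
```

In the Numba case, `njit("i4(i4, i4)")` and `cuda.jit("i4(i4, i4)", device=True)` play the role of `scaled(2)` and `scaled(10)`: each call returns a decorator that is then applied to the same pure-Python function.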