python machine-learning cuda numba

How to call a device function inside an object from a CUDA kernel in Python


I am writing a very specific neural network, and I have many classes for different activation functions. Each class has one function for plain Python and one jitted as a CUDA device function. The problem is calling that device function from inside a CUDA kernel. The standalone version looks like this:

from numba import cuda
import numpy as np

@cuda.jit(device=True)
def activation_fn(z):
    return max(0, z)

@cuda.jit
def backprop_kernel(arr):
    arr[cuda.threadIdx.x] = activation_fn(arr[cuda.threadIdx.x])

def backprop_GPU(x, y):
    arr = np.array([-3, -2, -1, 0, 1, 2, 3])
    print(arr)
    backprop_kernel[1, 7](arr)
    print(arr)

backprop_GPU(None, None)

This works perfectly fine, but I want to make the code below work.

class Activation:
    
    @cuda.jit(device=True)
    def fn(z):
        return max(0, z)

class Network:

    def __init__(self):
        self.activation_fn = Activation()
    
    @cuda.jit
    def kernel(arr):
        arr[cuda.threadIdx.x] = activation_fn(arr[cuda.threadIdx.x])

    def backprop(self, x, y):
        arr = np.array([-3, -2, -1, 0, 1, 2, 3])
        self.kernel[1, 7](arr)

net = Network()
net.backprop(None, None)

How do I make "activation_fn" accessible from the kernel?


Solution

  • @cuda.jit has to be applied to plain functions, not to methods of a class. So define the decorated functions inside a method (here, __init__), and bind the activation function to a local variable so that the kernel captures it as a closure variable when the kernel is defined:

    from numba import cuda
    import numpy as np
    
    
    class Activation:
        def __init__(self):
            @cuda.jit(device=True)
            def fn(z):
                return max(0, z)
    
            self.fn = fn
    
    
    class Network:
        def __init__(self):
            self.activation = Activation()
    
            activation_fn = self.activation.fn
    
            @cuda.jit
            def kernel(arr):
                arr[cuda.threadIdx.x] = activation_fn(arr[cuda.threadIdx.x])
    
            self.kernel = kernel
    
        def backprop(self, x, y):
            arr = np.array([-3, -2, -1, 0, 1, 2, 3])
            self.kernel[1, 7](arr)
            print(arr)
    
    
    net = Network()
    net.backprop(None, None)
    

    prints:

    $ python repro.py 
    [0 0 0 0 1 2 3]
    

    (Note: I omitted the performance warnings that Numba prints here, as they are orthogonal to the issue at hand.)