
How to call a device function inside an object from a CUDA kernel in Python

I am writing a very specific neural network, and I have many classes for different activation functions. Each class has one function for plain Python and one JIT-compiled as a CUDA device function. The problem is calling that device function from inside a CUDA kernel.

from numba import cuda
import numpy as np

@cuda.jit(device=True)
def activation_fn(z):
    return max(0, z)

@cuda.jit
def backprop_kernel(arr):
    arr[cuda.threadIdx.x] = activation_fn(arr[cuda.threadIdx.x])

def backprop_GPU(x, y):
    arr = np.array([-3, -2, -1, 0, 1, 2, 3])
    backprop_kernel[1, 7](arr)

backprop_GPU(None, None)

This works perfectly fine, but I want to make the code below work.

from numba import cuda
import numpy as np

class Activation:
    @cuda.jit(device=True)
    def fn(z):
        return max(0, z)

class Network:

    def __init__(self):
        self.activation_fn = Activation()

    @cuda.jit
    def kernel(arr):
        # activation_fn is not visible here: this is the part I can't get to work
        arr[cuda.threadIdx.x] = activation_fn(arr[cuda.threadIdx.x])

    def backprop(self, x, y):
        arr = np.array([-3, -2, -1, 0, 1, 2, 3])
        self.kernel[1, 7](arr)

net = Network()
net.backprop(None, None)

How do I make "activation_fn" accessible from the kernel?


  • @cuda.jit has to be applied to plain functions, not methods, so you need to define the decorated functions inside a method (here, __init__) and capture the activation function in a closure when you define the kernel:

    from numba import cuda
    import numpy as np

    class Activation:
        def __init__(self):
            @cuda.jit(device=True)
            def fn(z):
                return max(0, z)
            self.fn = fn

    class Network:
        def __init__(self):
            self.activation = Activation()
            # Bind the device function to a local name so the kernel closes over it
            activation_fn = self.activation.fn
            @cuda.jit
            def kernel(arr):
                arr[cuda.threadIdx.x] = activation_fn(arr[cuda.threadIdx.x])
            self.kernel = kernel

        def backprop(self, x, y):
            arr = np.array([-3, -2, -1, 0, 1, 2, 3])
            self.kernel[1, 7](arr)
            print(arr)

    net = Network()
    net.backprop(None, None)


    $ python 
    [0 0 0 0 1 2 3]

    (Note: I omitted the performance warnings that are printed here, as they're orthogonal to the issue at hand.)