Tags: python, cuda, cublas

Using cublas GEMM in a Python CUDA kernel


I have a simple matrix-matrix multiplication code, shown below:

from numba import cuda
import numpy as np

TPB = 32

@cuda.jit('void(double[:, :], double[:, :], double[:, :])', device=True)
def GPU_Mat2(A, B, C):
    bx = cuda.blockIdx.x
    by = cuda.blockIdx.y
    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    ROW = bx * TPB + tx
    COL = by * TPB + ty
    res = 0.0

    for k in range(A.shape[1]):
        if ROW < A.shape[0] and COL < B.shape[1]:
            res += A[ROW, k] * B[k, COL]
    cuda.syncthreads()

    if ROW < A.shape[0] and COL < B.shape[1]:
        C[ROW, COL] = res
    cuda.syncthreads()

I then call this device function twice inside another kernel:

@cuda.jit('void(double[:, :], double[:, :], double[:, :], double[:, :])')
def call_Mat2(A, B, C, D):
    for _ in range(200):
        GPU_Mat2(A, B, C)
        GPU_Mat2(C, B, D)        # Is this correct?

Unfortunately, this procedure does not give me the correct answer when compared with the same calculation on the host. Even when I add cuda.syncthreads() after each GPU_Mat2 call, the answer is still wrong. My question is: is it possible to use the output of one kernel call (here C) as an input to another?

def main():
    N = 300
    A = np.asfortranarray(np.random.random_sample((N, N)))
    B = np.asfortranarray(np.random.random_sample((N, N)))
    C_GPU = np.zeros((N, N), dtype=np.double, order='F')
    D_GPU = np.zeros((N, N), dtype=np.double, order='F')

    numThreads = [TPB, TPB]
    numBlocks = [(A.shape[0] + TPB - 1) // TPB, (B.shape[1] + TPB - 1) // TPB]

    d_A = cuda.to_device(A)
    d_B = cuda.to_device(B)
    d_C = cuda.to_device(C_GPU)
    d_D = cuda.to_device(D_GPU)

    call_Mat2[numBlocks, numThreads](d_A, d_B, d_C, d_D)
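
For reference, the host-side comparison at the end of main() looks roughly like this (a minimal sketch of my check; np.allclose with default tolerances stands in for the exact comparison I run):

    C_ref = A @ B            # what GPU_Mat2(A, B, C) should produce
    D_ref = C_ref @ B        # what GPU_Mat2(C, B, D) should produce
    print(np.allclose(d_C.copy_to_host(), C_ref))    # prints False for me
    print(np.allclose(d_D.copy_to_host(), D_ref))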

Second, based on this, it is possible to call BLAS GEMM inside a kernel, but I could not find a similar example in a Python script. Is this type of call supported from Python? Your help is appreciated.


Solution

  • As per the documentation:

    Note: newer CUDA devices support device-side kernel launching; this feature is called dynamic parallelism, but Numba does not currently support it.

    So no: you cannot call cuBLAS (or any other device-side library), nor launch another kernel, from Numba-compiled CUDA Python at the moment; a host-side workaround is sketched below.
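
What does work is driving the loop from the host: kernel launches issued to the same stream execute in order, so the second launch always sees the C produced by the first, and the output of one launch can safely feed the next. Below is a minimal sketch of that restructuring, reusing the array and launch-configuration names from the question (matmul is a hypothetical stand-in for GPU_Mat2, written as a regular kernel rather than a device function):

from numba import cuda

@cuda.jit
def matmul(A, B, C):
    # Naive GEMM: one thread computes one element of C.
    row, col = cuda.grid(2)
    if row < C.shape[0] and col < C.shape[1]:
        res = 0.0
        for k in range(A.shape[1]):
            res += A[row, k] * B[k, col]
        C[row, col] = res

# Drive the iteration from the host; each launch on the default
# stream starts only after the previous one has finished.
for _ in range(200):
    matmul[numBlocks, numThreads](d_A, d_B, d_C)
    matmul[numBlocks, numThreads](d_C, d_B, d_D)

If plain GEMM is all that is needed, a host-side library call (for example CuPy's matmul, which can consume Numba device arrays through the CUDA array interface) will generally be much faster than a hand-written kernel.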