python cuda gpu cupy cuda-wmma

How to use WMMA functions in Cupy kernels?


How do I use WMMA functions such as wmma::load_matrix_sync in cupy.RawKernel or cupy.RawModule? Can someone provide a minimal example?


Solution

  • We can combine information on cupy RawKernel and wmma programming to provide most of the needed material. I don't intend to give a tutorial on wmma programming; there are other resources for that, such as this blog and the cutlass template library.

    Note that the wmma functions require compute capability 7.0 or higher, so you must run on a Volta or newer GPU (for example Volta, Turing, or Ampere).
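
The requirement above can be checked from Python before launching anything. This is a minimal sketch, assuming CuPy's Device.compute_capability attribute, which reports the capability as a string such as '70'. The helper names supports_wmma and current_device_supports_wmma are illustrative and not part of CuPy.

```python
def supports_wmma(compute_capability: str) -> bool:
    """Return True if a capability string such as '70' means 7.0 or higher."""
    return int(compute_capability) >= 70


def current_device_supports_wmma() -> bool:
    """Query the current CUDA device through CuPy (needs CuPy and a GPU)."""
    import cupy as cp  # imported lazily so the pure helper works without a GPU

    return supports_wmma(cp.cuda.Device().compute_capability)
```

Calling current_device_supports_wmma() before building the kernel gives a clear early failure instead of a compile or launch error on an older GPU.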

    Let's take the kernel example given in the programming guide. To put this in a RawKernel, we need to provide it as a string. So that the kernel can be looked up by its unmangled name, I have split the code into a __device__ function (wmma_ker_dev) that can use C++ and a kernel entry point (wmma_ker) exported with C-style linkage. The example code performs a 16x16 matrix multiply using a single warp. Here is a worked example:

    # cat t24.py
    import numpy
    import cupy as cp
    ddim = 16  # matrix dimension: a single 16x16 tile
    bdim = 32  # threads per block: exactly one warp
    gdim = 1   # number of blocks
    a = cp.ones(ddim*ddim, dtype=cp.float16)
    b = cp.ones(ddim*ddim, dtype=cp.float16)
    c = cp.zeros(ddim*ddim, dtype=cp.float32)
    wmma_ker = cp.RawKernel(r'''
      #include <mma.h>
      using namespace nvcuda;
      __device__ void wmma_ker_dev(half *a, half *b, float *c) {
        // Declare the fragments
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
    
        // Initialize the output to zero
        wmma::fill_fragment(c_frag, 0.0f);
    
        // Load the inputs
        wmma::load_matrix_sync(a_frag, a, 16);
        wmma::load_matrix_sync(b_frag, b, 16);
    
        // Perform the matrix multiplication
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    
        // Store the output
        wmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
      }
      extern "C" {
        __global__ void wmma_ker(half *a, half *b, float *c) {
          wmma_ker_dev(a, b, c);
        }
      }
     ''', 'wmma_ker', options=("-restrict","-lineinfo"))
    wmma_ker((gdim,1), (bdim,1), (a,b,c))  # grid, block and arguments
    r_o = cp.asnumpy(c)
    print(r_o)
    # cuda-memcheck python t24.py
    ========= CUDA-MEMCHECK
    [16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16. 16.
     16. 16. 16. 16.]
    ========= ERROR SUMMARY: 0 errors
    #
    

    I used pip install cupy-cuda102 to set up cupy for this, on a machine with CUDA 10.2 installed and a Tesla V100 GPU. The RawKernel options I have provided are unnecessary for this demonstration; you could omit that argument entirely.
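
To see why every element of the output above is 16, and to have something to compare against if the inputs are changed, here is a host-side reference in plain Python. It is a sketch under stated assumptions: the function name wmma_reference is hypothetical, and the indexing mirrors the layouts chosen in the kernel (a read as column-major, b read as row-major, c stored row-major, all with leading dimension 16).

```python
def wmma_reference(a, b, n=16):
    """Host-side reference for the single-tile product computed by the kernel.

    a: flat sequence of n*n values, read column-major (A[i][k] is a[i + k*n])
    b: flat sequence of n*n values, read row-major    (B[k][j] is b[k*n + j])
    Returns a flat row-major list c with c[i*n + j] = sum over k of A[i][k] * B[k][j].
    """
    c = [0.0] * (n * n)
    for i in range(n):
        for j in range(n):
            acc = 0.0
            for k in range(n):
                acc += a[i + k * n] * b[k * n + j]
            c[i * n + j] = acc
    return c


# Same inputs as the example: all ones, so each output sums 16 products of 1*1.
ones = [1.0] * (16 * 16)
expected = wmma_reference(ones, ones)
```

With all-ones inputs the layouts make no difference, which is why the example cannot reveal a layout mistake. Non-uniform inputs compared against a reference like this one would.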

    The purpose of this code is to demonstrate an example method. I'm not suggesting the code is defect-free or suitable for any particular purpose. Use it at your own risk. In particular, I would not expect this code to work correctly if any aspect of it is changed. I am not suggesting that it is a general/flexible/extensible matrix multiply routine.
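
The question also mentions cupy.RawModule. The same kernel source can be compiled that way. The sketch below assumes the documented cupy.RawModule(code=...) constructor and its get_function method. The names WMMA_SOURCE and run_with_raw_module are illustrative, and I have not run this variant on the setup described above.

```python
# Same CUDA source as the RawKernel example, kept as a module-level string.
WMMA_SOURCE = r'''
#include <mma.h>
using namespace nvcuda;
__device__ void wmma_ker_dev(half *a, half *b, float *c) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;
  wmma::fill_fragment(c_frag, 0.0f);
  wmma::load_matrix_sync(a_frag, a, 16);
  wmma::load_matrix_sync(b_frag, b, 16);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  wmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
}
extern "C" {
  __global__ void wmma_ker(half *a, half *b, float *c) {
    wmma_ker_dev(a, b, c);
  }
}
'''


def run_with_raw_module():
    """Compile WMMA_SOURCE as a RawModule and launch one warp (needs a GPU)."""
    import cupy as cp  # imported lazily; requires CuPy and a capable GPU

    a = cp.ones(16 * 16, dtype=cp.float16)
    b = cp.ones(16 * 16, dtype=cp.float16)
    c = cp.zeros(16 * 16, dtype=cp.float32)
    module = cp.RawModule(code=WMMA_SOURCE)
    kernel = module.get_function("wmma_ker")  # looked up by its C-linkage name
    kernel((1,), (32,), (a, b, c))            # grid, block and arguments
    return cp.asnumpy(c)
```

A RawModule is mainly useful when one source string holds several kernels. For a single kernel like this one, RawKernel is the shorter route.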