c++ cuda jit nvcc

JIT compilation of CUDA __device__ functions


I have a fixed kernel, and I want to be able to incorporate user-defined device functions that alter its output. The user-defined functions will always take the same input arguments and always return a scalar value. If I knew the user-defined functions at compile time, I could simply pass them to the kernel as function pointers (with a default device function that operates on the input when no function is given). I have access to each user-defined function's PTX code at runtime, and I am wondering whether I could use something like NVIDIA's jitify to compile the PTX at runtime, get a pointer to the device function, and then pass that pointer to the precompiled kernel.
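For reference, the compile-time version described above might look like the following minimal sketch. It assumes every user function takes three floats and returns a float; the names default_fn, product_fn, apply_kernel and run_apply are hypothetical, and error checking is omitted for brevity. Because the address of a __device__ function can only be taken in device code, the pointers are stored in __device__ variables and copied to the host with cudaMemcpyFromSymbol before being passed to the kernel.

```cuda
#include <cuda_runtime.h>
#include <vector>

// Signature shared by every user function (assumed: three floats in, one float out).
typedef float (*user_fn_t)(float, float, float);

// Default used when the caller supplies no function.
__device__ float default_fn(float a, float b, float c) { return a + b + c; }

// Hypothetical user function that happens to be known at compile time.
__device__ float product_fn(float a, float b, float c) { return a * b * c; }

// A __device__ function's address can only be taken in device code,
// so it is kept in __device__ variables that the host can read back.
__device__ user_fn_t d_default_fn = default_fn;
__device__ user_fn_t d_product_fn = product_fn;

// The fixed kernel: it knows the signature, not which function it calls.
__global__ void apply_kernel(const float* a, const float* b, const float* c,
                             float* out, int n, user_fn_t fn) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = fn(a[i], b[i], c[i]);
}

// Runs apply_kernel with product_fn, or with default_fn when use_product is false.
std::vector<float> run_apply(const std::vector<float>& a, const std::vector<float>& b,
                             const std::vector<float>& c, bool use_product) {
    int n = static_cast<int>(a.size());
    size_t bytes = n * sizeof(float);

    // Copy the device-side pointer value into host memory.
    user_fn_t fn = nullptr;
    if (use_product)
        cudaMemcpyFromSymbol(&fn, d_product_fn, sizeof(fn));
    else
        cudaMemcpyFromSymbol(&fn, d_default_fn, sizeof(fn));

    float *da, *db, *dc, *dout;
    cudaMalloc(&da, bytes);
    cudaMalloc(&db, bytes);
    cudaMalloc(&dc, bytes);
    cudaMalloc(&dout, bytes);
    cudaMemcpy(da, a.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(db, b.data(), bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(dc, c.data(), bytes, cudaMemcpyHostToDevice);

    int threads = 256;
    apply_kernel<<<(n + threads - 1) / threads, threads>>>(da, db, dc, dout, n, fn);

    std::vector<float> out(n);
    cudaMemcpy(out.data(), dout, bytes, cudaMemcpyDeviceToHost);  // waits for the kernel
    cudaFree(da);
    cudaFree(db);
    cudaFree(dc);
    cudaFree(dout);
    return out;
}
```

With this sketch, run_apply({1, 2}, {3, 4}, {5, 6}, true) should return {15, 48}, while passing false falls back to the sum and returns {9, 12}. The question is how to obtain such a pointer when the function only exists as PTX at runtime.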

I have seen a few posts that come close to answering this (How to generate, compile and run CUDA kernels at runtime), but most of them suggest compiling the entire kernel together with the device function at runtime. Since the device function has fixed inputs and outputs, I don't see why the kernel couldn't be compiled ahead of time. The piece I am missing is how to compile just the device function at runtime and get a pointer to it that I can then pass to the kernel.


Solution

  • You can do this as follows:

    1. Compile your CUDA project with --keep and look up the generated PTX or cubin for it. The kernel has to be compiled as relocatable device code (-rdc=true) so that it can refer to a symbol that is defined in another module.
    2. At runtime, generate the PTX for your device function. (In our experiment, we needed to store the function pointer in device memory by declaring a global variable.)
    3. Build a new module at runtime: create a link state with cuLinkCreate, add first the PTX or cubin from the --keep output and then your runtime-generated PTX with cuLinkAddData, finish the link with cuLinkComplete, and load the resulting image with cuModuleLoadData.
    4. Finally, launch the kernel. You must launch it from the freshly generated module and not with the <<<>>> syntax: a <<<>>> launch runs the copy of the kernel in the module linked into your executable, where the function pointer is not known. This last phase should be done with the driver API (cuModuleGetFunction, then cuLaunchKernel); you may also want to try the runtime API's cudaLaunchKernel.

    The key point is to launch the kernel from the generated module, not from the module that is implicitly linked into your program.
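A minimal sketch of the precompiled side, under assumptions that are not part of the answer: user functions take three floats and return a float, and the global variable from step 2 is a function pointer named g_user_fn (hypothetical, as is the kernel name apply_kernel). Rather than picking the file out of a --keep build, this sketch compiles the kernel on its own to relocatable PTX, which also keeps the undefined g_user_fn out of the executable's own device link. The -rdc=true flag is what lets the kernel refer to g_user_fn without defining it.

```cuda
// kernel.cu (hypothetical name): the fixed kernel, compiled ahead of time, e.g.
//   nvcc -rdc=true -ptx kernel.cu -o kernel.ptx
// -rdc=true is needed because g_user_fn is only declared in this file.

// Signature every user function must have (assumed for this sketch).
typedef float (*user_fn_t)(float, float, float);

// Declared, not defined: the definition comes from the PTX generated at
// runtime and is resolved when both are linked with the cuLink* functions.
extern __device__ user_fn_t g_user_fn;

// extern "C" keeps the symbol unmangled, so the host can look it up
// with the plain string "apply_kernel".
extern "C" __global__ void apply_kernel(const float* a, const float* b,
                                        const float* c, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = g_user_fn(a[i], b[i], c[i]);
}
```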
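The PTX generated at runtime then has to supply the definition of that global variable. A sketch of the corresponding source, assuming the precompiled kernel declares extern __device__ user_fn_t g_user_fn with a three-float signature; g_user_fn and user_fn are hypothetical names, with user_fn standing in for the user's code. It could be turned into PTX offline with nvcc or at runtime with NVRTC using the option --relocatable-device-code=true.

```cuda
// user_fn.cu (hypothetical name): turned into PTX at runtime, or offline with
//   nvcc -rdc=true -ptx user_fn.cu -o user_fn.ptx

typedef float (*user_fn_t)(float, float, float);

// Hypothetical user-defined function with the agreed signature.
__device__ float user_fn(float a, float b, float c) { return a * b + c; }

// Definition of the global pointer the precompiled kernel calls through.
// Name and type must match the kernel's extern declaration exactly.
__device__ user_fn_t g_user_fn = user_fn;
```

Referring to the function through a plain global variable also sidesteps C++ name mangling: the kernel only needs to agree on the variable name g_user_fn, not on the mangled name of user_fn.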
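Steps 3 and 4 could then look roughly like this with the driver API. This is a sketch under stated assumptions: kernel.ptx is relocatable PTX for a kernel declared extern "C" __global__ void apply_kernel(const float*, const float*, const float*, float*, int) that calls through an extern __device__ function pointer, and user_fn.ptx is runtime-generated PTX that defines that pointer; all file, function and symbol names are hypothetical. The kernel handle comes from cuModuleGetFunction on the freshly linked module, which is the point the answer stresses.

```cuda
// host.cu (hypothetical name). Build, for example: nvcc host.cu -o host -lcuda
#include <cuda.h>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>

// Abort with a readable message if a driver API call fails.
#define CHECK(call)                                                  \
    do {                                                             \
        CUresult r_ = (call);                                        \
        if (r_ != CUDA_SUCCESS) {                                    \
            const char* msg_ = nullptr;                              \
            cuGetErrorString(r_, &msg_);                             \
            std::fprintf(stderr, "%s failed: %s\n", #call,           \
                         msg_ ? msg_ : "unknown error");             \
            std::exit(EXIT_FAILURE);                                 \
        }                                                            \
    } while (0)

// Reads a whole file (e.g. a .ptx file) into a string.
std::string read_file(const char* path) {
    std::ifstream in(path, std::ios::binary);
    std::ostringstream ss;
    ss << in.rdbuf();
    return ss.str();
}

// Links both PTX images into a new module, then launches apply_kernel
// from that module with cuLaunchKernel (not with <<<>>>).
std::vector<float> run_linked(const std::string& kernel_ptx, const std::string& user_ptx,
                              const std::vector<float>& a, const std::vector<float>& b,
                              const std::vector<float>& c) {
    CHECK(cuInit(0));
    CUdevice dev;
    CHECK(cuDeviceGet(&dev, 0));
    CUcontext ctx;
    CHECK(cuDevicePrimaryCtxRetain(&ctx, dev));
    CHECK(cuCtxSetCurrent(ctx));

    // Step 3: precompiled kernel first, then the runtime PTX.
    CUlinkState link;
    CHECK(cuLinkCreate(0, nullptr, nullptr, &link));
    CHECK(cuLinkAddData(link, CU_JIT_INPUT_PTX, (void*)kernel_ptx.c_str(),
                        kernel_ptx.size() + 1, "kernel.ptx", 0, nullptr, nullptr));
    CHECK(cuLinkAddData(link, CU_JIT_INPUT_PTX, (void*)user_ptx.c_str(),
                        user_ptx.size() + 1, "user_fn.ptx", 0, nullptr, nullptr));
    void* image = nullptr;
    size_t image_size = 0;
    CHECK(cuLinkComplete(link, &image, &image_size));
    CUmodule mod;
    CHECK(cuModuleLoadData(&mod, image));  // load before cuLinkDestroy frees the image
    CHECK(cuLinkDestroy(link));

    // Step 4: the kernel must come from the linked module.
    CUfunction kernel;
    CHECK(cuModuleGetFunction(&kernel, mod, "apply_kernel"));

    int n = static_cast<int>(a.size());
    size_t bytes = n * sizeof(float);
    CUdeviceptr da, db, dc, dout;
    CHECK(cuMemAlloc(&da, bytes));
    CHECK(cuMemAlloc(&db, bytes));
    CHECK(cuMemAlloc(&dc, bytes));
    CHECK(cuMemAlloc(&dout, bytes));
    CHECK(cuMemcpyHtoD(da, a.data(), bytes));
    CHECK(cuMemcpyHtoD(db, b.data(), bytes));
    CHECK(cuMemcpyHtoD(dc, c.data(), bytes));

    void* args[] = {&da, &db, &dc, &dout, &n};
    unsigned threads = 256, blocks = (n + threads - 1) / threads;
    CHECK(cuLaunchKernel(kernel, blocks, 1, 1, threads, 1, 1, 0, nullptr, args, nullptr));

    std::vector<float> out(n);
    CHECK(cuMemcpyDtoH(out.data(), dout, bytes));  // waits for the kernel to finish

    cuMemFree(da);
    cuMemFree(db);
    cuMemFree(dc);
    cuMemFree(dout);
    cuModuleUnload(mod);
    cuDevicePrimaryCtxRelease(dev);
    return out;
}
```

If user_fn.ptx defines the pointer as pointing to a function computing a * b + c, then run_linked(read_file("kernel.ptx"), read_file("user_fn.ptx"), {1, 2}, {3, 4}, {5, 6}) should return {8, 14}. Only PTX inputs are used here; a relocatable cubin from the --keep output would presumably be added with CU_JIT_INPUT_CUBIN instead.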