
JCuda: cuModuleLoad() cannot load a file using the path from getClass().getResource().getPath()


I am trying to use cuModuleLoad() in JCuda to load a vectorAdd.ptx file from /src/main/resources. The code is as follows:

cuModuleLoad(module, getClass().getResource("vectorAdd.ptx").getPath());

However, cuModuleLoad() does not pick up this file. It only works when I pass in the absolute path of the PTX file, but I would like to ship the PTX file inside the compiled JAR file. Is there any way to accomplish this?


Solution

  • The cuModuleLoad function in JCuda is a direct mapping to the corresponding cuModuleLoad function in CUDA. It expects a file name as the second argument.

    The problem is that cuModuleLoad cannot load the PTX file because, as far as CUDA is concerned, the PTX file simply does not exist: it is hidden inside the JAR file.


    When you obtain a resource from a JAR file using someClass.getResource(), the returned URL points to the resource inside the JAR file. When you do something like

    System.out.println(getClass().getResource("vectorAdd.ptx").getPath());
    

    and run it from a JAR file, you will see output like this:

    file:/U:/YourWorkspace/YourJarFile.jar!/vectorAdd.ptx
    

    Note the .jar! part: this is not the path of a real file on disk, but only the path of an entry inside the JAR.
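    A quick way to see the difference is to look at the protocol of the URL that getResource() returns. The following is a minimal sketch (ResourceUrlCheck and isPlainFile are hypothetical names, not part of JCuda): only a "file" URL corresponds to a real file that native code such as cuModuleLoad could open, while a resource packed into a JAR has the "jar" protocol, and its getPath() contains the .jar! part shown above.

    ```java
    import java.net.MalformedURLException;
    import java.net.URI;
    import java.net.URL;

    public class ResourceUrlCheck
    {
        /**
         * Returns true only if the URL refers to a plain file on disk.
         * A resource inside a JAR has the "jar" protocol, so its path
         * cannot be opened by native code such as CUDA's cuModuleLoad.
         */
        static boolean isPlainFile(URL url)
        {
            return url != null && "file".equals(url.getProtocol());
        }

        public static void main(String[] args) throws MalformedURLException
        {
            // Example URLs of the two kinds getResource() may return
            URL inJar = URI.create(
                "jar:file:/U:/YourWorkspace/YourJarFile.jar!/vectorAdd.ptx").toURL();
            URL onDisk = URI.create("file:/tmp/vectorAdd.ptx").toURL();

            System.out.println(inJar.getProtocol() + " -> " + inJar.getPath());
            System.out.println(onDisk.getProtocol() + " -> " + onDisk.getPath());
            System.out.println(isPlainFile(inJar) + " " + isPlainFile(onDisk));
        }
    }
    ```

    In the actual program, you would pass getClass().getResource("vectorAdd.ptx") to isPlainFile; when it returns false, the result of getPath() cannot be used with cuModuleLoad.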


    In order to load the PTX file from a JAR, you have to read the PTX file from the JAR into a byte[] array on the Java side, and then pass it to the cuModuleLoadData function of JCuda (which corresponds to the cuModuleLoadData function of CUDA).
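    Reading a resource into such a zero-terminated array can be sketched compactly, assuming Java 9 or later for InputStream.readAllBytes(). PtxResources and its methods are hypothetical helper names. Arrays.copyOf pads the one extra element with 0, which supplies the terminator, and the null that getResourceAsStream returns for a missing resource is turned into an explicit exception instead of a later NullPointerException.

    ```java
    import java.io.FileNotFoundException;
    import java.io.IOException;
    import java.io.InputStream;
    import java.util.Arrays;

    public class PtxResources
    {
        /** Reads all bytes from the stream and appends a single 0 byte. */
        static byte[] toZeroTerminated(InputStream in) throws IOException
        {
            byte[] data = in.readAllBytes(); // requires Java 9+
            // The element added by copyOf is 0: the string terminator
            return Arrays.copyOf(data, data.length + 1);
        }

        /**
         * Loads a classpath resource (which may live inside a JAR) as a
         * zero-terminated byte array and closes the stream afterwards.
         */
        static byte[] readZeroTerminatedResource(Class<?> anchor, String name)
            throws IOException
        {
            try (InputStream in = anchor.getResourceAsStream(name))
            {
                if (in == null)
                {
                    throw new FileNotFoundException("Resource not found: " + name);
                }
                return toZeroTerminated(in);
            }
        }
    }
    ```

    The resulting array is what cuModuleLoadData expects, for example cuModuleLoadData(module, PtxResources.readZeroTerminatedResource(JCudaPtxInJar.class, "JCudaVectorAddKernel.ptx")).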

    Here is an example that loads the PTX data from a JAR file into a byte array, representing the zero-terminated string that can be passed to cuModuleLoadData:

    import static jcuda.driver.JCudaDriver.cuCtxCreate;
    import static jcuda.driver.JCudaDriver.cuDeviceGet;
    import static jcuda.driver.JCudaDriver.cuInit;
    import static jcuda.driver.JCudaDriver.cuModuleGetFunction;
    import static jcuda.driver.JCudaDriver.cuModuleLoadData;
    import static jcuda.runtime.JCuda.cudaDeviceReset;
    
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.InputStream;
    
    import jcuda.driver.CUcontext;
    import jcuda.driver.CUdevice;
    import jcuda.driver.CUfunction;
    import jcuda.driver.CUmodule;
    import jcuda.driver.JCudaDriver;
    
    public class JCudaPtxInJar
    {
        public static void main(String args[]) throws IOException
        {
            // Initialization
            JCudaDriver.setExceptionsEnabled(true);
            cuInit(0);
            CUdevice device = new CUdevice();
            cuDeviceGet(device, 0);
            CUcontext context = new CUcontext();
            cuCtxCreate(context, 0, device);
    
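            // Read the PTX data into a zero-terminated string byte array.
            // The name is resolved relative to the package of JCudaPtxInJar,
            // and try-with-resources closes the stream afterwards.
            byte ptxData[];
            try (InputStream ptxStream = JCudaPtxInJar.class.getResourceAsStream(
                "JCudaVectorAddKernel.ptx"))
            {
                if (ptxStream == null)
                {
                    throw new IOException(
                        "JCudaVectorAddKernel.ptx not found on the classpath");
                }
                ptxData = toZeroTerminatedStringByteArray(ptxStream);
            }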
    
            // Load the module data
            CUmodule module = new CUmodule();
            cuModuleLoadData(module, ptxData);
    
            // Obtain a function pointer to the "add" function
            // and print a simple test/debug message
            CUfunction function = new CUfunction();
            cuModuleGetFunction(function, module, "add");
            System.out.println("Got function "+function);
    
            cudaDeviceReset();
        }
    
        /**
         * Read the contents of the given input stream, and return it
         * as a byte array containing the ZERO-TERMINATED string data 
         * from the stream. The caller is responsible for closing the
         * given stream.
         * 
         * @param inputStream The input stream
         * @return The ZERO-TERMINATED string byte array
         * @throws IOException If an IO error occurs
         */
        private static byte[] toZeroTerminatedStringByteArray(
            InputStream inputStream) throws IOException
        {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            byte buffer[] = new byte[8192];
            while (true)
            {
                int read = inputStream.read(buffer);
                if (read == -1)
                {
                    break;
                }
                baos.write(buffer, 0, read);
            }
            baos.write(0);
            return baos.toByteArray();
        }
    }
    

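    Compile this and pack it into a JAR together with the JCudaVectorAddKernel.ptx file. Because JCudaPtxInJar is in the default package, the PTX file has to be at the root of the JAR (which is where a file from src/main/resources ends up in a standard Maven build). You can then start the program, and it will obtain the example function from the PTX file inside the JAR.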
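    Where exactly the PTX file has to be placed follows from how Class.getResourceAsStream (like Class.getResource) resolves the name: a name starting with "/" is taken relative to the root of the classpath (here, the JAR), and any other name is prefixed with the package path of the class. The helper below is a hypothetical sketch that only mirrors this documented naming rule to show which JAR entry a lookup searches for; it does not open the JAR, and it assumes Java 9+ for Class.getPackageName().

    ```java
    public class ResourceNames
    {
        /**
         * Mirrors the Class.getResource naming rule: returns the JAR entry
         * name that a lookup of 'name' relative to 'anchor' will search for.
         */
        static String jarEntryFor(Class<?> anchor, String name)
        {
            if (name.startsWith("/"))
            {
                // Absolute name: resolved from the root of the JAR
                return name.substring(1);
            }
            // Relative name: prefixed with the package path ("" for the default package)
            String pkg = anchor.getPackageName();
            return pkg.isEmpty() ? name : pkg.replace('.', '/') + "/" + name;
        }

        public static void main(String[] args)
        {
            System.out.println(jarEntryFor(String.class, "vectorAdd.ptx"));
            System.out.println(jarEntryFor(String.class, "/vectorAdd.ptx"));
        }
    }
    ```

    So if the loading class were in a package such as com.example, the relative name "vectorAdd.ptx" would be looked up as com/example/vectorAdd.ptx, whereas "/vectorAdd.ptx" is always looked up at the root of the JAR.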