I am trying to launch a kernel in PyCUDA and then terminate it by writing to a location in GPU global memory. Here is a simple example kernel that I would like to be able to terminate at some point after it enters the infinite while loop:
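__global__ void countUp(volatile u16 *inShot, u64 *counter) {
    // volatile forces the flag to be re-read from memory on every pass
    // spin while the flag is 0; the host script uploads 0 and later writes 1
    while (!inShot[0]) {
        counter[0]++;
    }
}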
From what I have read about CUDA streams, a kernel launched into a stream should not block the host, i.e. I should be able to keep doing work on the host while the kernel is running. I compile the above kernel to a cubin file and launch it in PyCUDA like so:
import numpy as np
from pycuda import driver, compiler, gpuarray, tools
# -- initialize the device
import pycuda.autoinit
strm1 = driver.Stream()
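# the flag starts at 0 on the device
h_inShot = np.zeros((1,1))
d_inShot = gpuarray.to_gpu_async(h_inShot.astype(np.uint16), stream = strm1)
# host-side value that will later be written to the flag to stop the kernel
h_inShot = np.ones((1,1))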
h_counter = np.zeros((1,1))
d_counter = gpuarray.to_gpu_async(h_counter.astype(np.uint64), stream = strm1)
testCubin = "testKernel.cubin"
mod = driver.module_from_file(testCubin)
countUp = mod.get_function("countUp")
countUp(d_inShot, d_counter,
        grid = (1, 1, 1),
        block = (1, 1, 1),
        stream = strm1
        )
Running this script causes the kernel to enter an infinite while loop, as expected. When I launch the script from IPython, control does not seem to return to the host after the kernel launch (I can't enter new commands, I guess because it is waiting for the kernel to finish). I would like control to return to the host so that I can change the value at the GPU global memory location d_inShot and make the kernel exit the while loop. Is this even possible and, if so, how do I do it in PyCUDA? Thanks.
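I figured this out, so I am posting my solution. Even though asynchronous memcpys do not block the host, a memcpy issued on the same stream as a running kernel never takes effect: operations within one stream execute in the order they were issued, so the copy waits behind the kernel, and the kernel never finishes. My solution was to create another stream: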
strm2 = driver.Stream()
and then change d_inShot like so:
d_inShot.set_async(h_inShot.astype(np.uint16), stream = strm2)
And this has worked for me.
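For anyone who wants to try the two-stream approach end to end, here is a minimal sketch. It is not my exact script, and it rests on a few assumptions: the kernel is compiled at run time with SourceModule instead of being loaded from a cubin, the flag is declared volatile and the loop spins while the flag is 0, the stop value is sent from page-locked host memory, and the names KERNEL_SOURCE, run_and_stop_kernel, stopFlag, kernel_stream and copy_stream are made up for illustration.

```python
import time

# CUDA C source compiled at run time (assumption: SourceModule instead of
# a prebuilt cubin). SourceModule wraps the source in extern "C" by
# default, so the kernel keeps the plain name "countUp".
KERNEL_SOURCE = r"""
__global__ void countUp(volatile unsigned short *stopFlag,
                        unsigned long long *counter)
{
    // volatile makes every pass re-read the flag from global memory
    while (!stopFlag[0]) {
        counter[0]++;
    }
}
"""


def run_and_stop_kernel(spin_seconds=0.5):
    """Launch countUp on one stream and stop it from a second stream.

    Returns the final counter value. This needs a CUDA device and PyCUDA,
    so both are imported here rather than at module level.
    """
    import numpy as np
    import pycuda.autoinit  # noqa: F401  (creates the CUDA context)
    from pycuda import driver, gpuarray
    from pycuda.compiler import SourceModule

    count_up = SourceModule(KERNEL_SOURCE).get_function("countUp")

    kernel_stream = driver.Stream()  # runs the spinning kernel
    copy_stream = driver.Stream()    # carries the stop signal

    # Synchronous uploads: both finish before the kernel is launched.
    d_flag = gpuarray.to_gpu(np.zeros(1, dtype=np.uint16))
    d_counter = gpuarray.to_gpu(np.zeros(1, dtype=np.uint64))

    # Page-locked host memory is what CUDA expects for asynchronous copies.
    h_stop = driver.pagelocked_empty(1, dtype=np.uint16)
    h_stop[0] = 1

    count_up(d_flag, d_counter, grid=(1, 1, 1), block=(1, 1, 1),
             stream=kernel_stream)

    time.sleep(spin_seconds)  # the host stays free while the kernel spins

    # A copy on kernel_stream would queue behind the kernel forever;
    # copy_stream is independent, so this write can reach the device.
    d_flag.set_async(h_stop, stream=copy_stream)
    copy_stream.synchronize()
    kernel_stream.synchronize()  # returns once the loop has seen the flag

    return int(d_counter.get()[0])
```

On a machine with a CUDA device, print(run_and_stop_kernel()) should return after roughly spin_seconds and print a count that depends on the GPU. Keep the spin time short on a GPU that also drives a display, since the driver's watchdog may kill long-running kernels.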