Search code examples
c++ struct cuda nvcc

Launching Cuda Call from struct


Given a simple struct that wraps the CUDA code, one can write something like

func<float> s;
s.val = 3.f;
start_correct<<<1, 2>>>(s);

However, I would like to move the block, grid, and shared-memory computation into the struct and call the kernel like

func<float> s;
s.val = 3.f;
s.launch();

While the first is working, the second gives me an illegal memory access error.

A minimal example to reproduce my problem is

#include <stdio.h>

// Forward declaration so the kernel below can name func<T> before its definition.
template<typename T>
struct func;

// Kernel that receives the functor by const reference, prints its val, and then
// invokes it in every thread of the launch.
// NOTE(review): func<T>::launch() passes *this, a host-side object, to this
// reference parameter. The answer below identifies passing kernel parameters by
// reference (without managed memory) as the cause of the illegal memory access;
// compare start_correct, which takes the struct by value.
template<typename T>
__global__ void start(const func<T>& s){
  printf("host access val %f \n",s.val);
  s();
}

// Functor wrapper that carries a single value and knows how to launch itself.
template <typename T>
struct func
{
  // Number of threads per block used by launch().
  enum { C_N = 2 };

  // Payload printed by the device-side call operator.
  T val;

  // Device-side body: each thread prints val together with its threadIdx.x.
  __device__ void operator()() const
  {
    printf("device access val %f [%d]\n", this->val, threadIdx.x);
  }

  // Host-side wrapper: launches the start kernel with one block of C_N threads,
  // passing this object as the kernel argument.
  void launch()
  {
    start<<<1, C_N>>>(*this);
  }
};

// Kernel that takes the functor by value: prints its val, then invokes it in
// every thread of the launch.
template <typename T>
__global__ void start_correct(const func<T> s)
{
  printf("host access val %f \n", s.val);
  s();
}

// Driver for the reproduction: launches the by-value kernel directly, then the
// by-reference kernel through func<float>::launch(), reporting errors after each.
int main(int argc, char const *argv[])
{
  // Declared without an initializer; first assigned by cudaGetLastError() further down.
  cudaError_t err;

  func<float> s;
  s.val = 3.f;

  // launch cuda kernel <-- WORKS
  start_correct<<<1, 2>>>(s);
  cudaDeviceSynchronize();
  // NOTE(review): err has not been assigned at this point, so this test reads an
  // indeterminate value; the status returned by cudaDeviceSynchronize() above is
  // discarded. The answer below suggests `err = cudaDeviceSynchronize();`.
  if (err != cudaSuccess) printf("Error: %s\n", cudaGetErrorString(err));


  // launch cuda kernel <-- DOES NOT WORK
  s.launch();
  cudaDeviceSynchronize();
  // Here err is assigned before being tested.
  err = cudaGetLastError();
  if (err != cudaSuccess) printf("Error: %s\n", cudaGetErrorString(err));


  return 0;
}

The output is

host access val 3.000000 
host access val 3.000000 
device access val 3.000000 [0]
device access val 3.000000 [1]
host access val 0.000000 
host access val 0.000000 
device access val 0.000000 [0]
device access val 0.000000 [1]
Error: an illegal memory access was encountered

Shouldn't both ways be equivalent? Are there any alternatives that also do the shared-memory and grid calculations inside the struct?


Solution

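  • Unless you are using managed memory (which you aren't), it is not legal to pass kernel parameters by reference. A reference parameter hands the kernel the address of the host object (here `*this`, i.e. the `s` in `main`), which device code cannot dereference; passing by value, as `start_correct` does, copies the struct to the device instead: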

    __global__ void start(const func<T>& s){
                                       ^
    

    When I remove that ampersand, your code runs without any runtime error for me, and gives sensible output:

    $ cuda-memcheck ./t355
    ========= CUDA-MEMCHECK
    host access val 3.000000
    host access val 3.000000
    device access val 3.000000 [0]
    device access val 3.000000 [1]
    host access val 3.000000
    host access val 3.000000
    device access val 3.000000 [0]
    device access val 3.000000 [1]
    ========= ERROR SUMMARY: 0 errors
    $
    

    Note that this doesn't really make sense, because `err` has not been assigned a value before it is tested:

      cudaDeviceSynchronize();
      if (err != cudaSuccess) printf("Error: %s\n", cudaGetErrorString(err));
    

    and throws a compiler warning for me.

    Perhaps you meant:

      err = cudaDeviceSynchronize();
      if (err != cudaSuccess) printf("Error: %s\n", cudaGetErrorString(err));