compiler-errors cuda overriding virtual-functions nvcc

Can I override a CUDA host-and-device function with a host-only function?


Consider the following program:

class A {
    __host__  __device__ void foo();
};

class B : A {
    __host__ void foo();
};

int main()
{
    A a; (void) a; 
    B b; (void) b;
}

This compiles (GodBolt) with nvcc 10.

Yet, in more complex programs, I sometimes get the following error (line breaks for readability):

whatever.hpp(88): error: execution space mismatch: overridden entity (function
"C::foo") is a __host__ __device__ function, but overriding entity (function "D::foo")
is a __host__ function

So, nvcc is telling me that I'm not supposed to drop an execution space when overriding methods. I'm not asking about my own code (which I haven't cited here), but about the principle:

  • If it's acceptable to override __host__ __device__ functions with just __host__ functions (which I find reasonable) - then how can nvcc even have such an error?
  • Alternatively, if it's not allowed - why is the small program above compiling?

Solution

  • An overriding (virtual) method must keep the execution space specifiers of the method it overrides.

    "Overriding" only applies to virtual methods, so your C::foo() must be marked virtual. The small program in the question compiles because its foo() is not virtual: B::foo() merely hides A::foo(). And indeed, if we mark foo() in the example program as virtual:

    class A {
        virtual __host__  __device__ void foo();
    };
    
    class B : A {
        __host__ void foo(); // can say "override" here; but it doesn't matter
    };
    
    int main()
    {
        A a; (void) a; 
        B b; (void) b;
    }
    

    This will fail to compile:

    <source>(6): error: execution space mismatch: overridden entity (function
    "A::foo") is a __host__ __device__ function, but overriding entity (function
    "B::foo") is a __host__ function
    

    Does this limitation make sense? One could imagine an interpretation in which the base-class method applies to __device__-side calls and the subclass method to __host__-side calls. But that is awkward too: a call made on an object through a base-class pointer has to resolve to some function, whichever side it is made on.
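
The difference between hiding and overriding can be sketched as follows. This is a minimal illustration, not code from the question: the class names, the helper functions, and the HOST_DEVICE / HOST_ONLY macros are hypothetical, and the macros exist only so that the same file also builds with a plain C++ compiler. The virtual half additionally shows one possible way to live with the restriction, assuming it is acceptable to keep both specifiers on the overrider and pick the host behaviour with __CUDA_ARCH__ (which is defined only while device code is being compiled). The device-side return value is a placeholder.

```cpp
#include <cassert>

// Hypothetical helper macros (not from the original post): they expand to the
// CUDA execution space specifiers under nvcc and to nothing in plain C++.
#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#define HOST_ONLY   __host__
#else
#define HOST_DEVICE
#define HOST_ONLY
#endif

// Non-virtual case: the derived foo() only hides the base foo(), so there is
// no overriding relationship and the execution spaces may differ.
struct HidingBase {
    HOST_DEVICE int foo() const { return 1; }
};
struct HidingDerived : HidingBase {
    HOST_ONLY int foo() const { return 2; }
};

// Virtual case: the overrider repeats both specifiers and selects the
// host-side behaviour with __CUDA_ARCH__.
struct VirtualBase {
    virtual HOST_DEVICE int foo() const { return 1; }
};
struct VirtualDerived : VirtualBase {
    HOST_DEVICE int foo() const override {
#ifdef __CUDA_ARCH__
        return -1;  // device compilation pass: placeholder for this sketch
#else
        return 2;   // host compilation pass: the behaviour we actually want
#endif
    }
};

// The static type decides: HidingBase::foo() runs although the object is derived.
int call_through_hiding_base() {
    HidingDerived d;
    const HidingBase* p = &d;
    return p->foo();
}

// Called on the derived type itself, the hiding function is the one found.
int call_hidden_directly() {
    HidingDerived d;
    return d.foo();
}

// The dynamic type decides: the call is dispatched to VirtualDerived::foo().
int call_through_virtual_base() {
    VirtualDerived d;
    const VirtualBase* p = &d;
    return p->foo();
}

// Host-side sanity check of the three cases above.
void self_check() {
    assert(call_through_hiding_base() == 1);
    assert(call_hidden_directly() == 2);
    assert(call_through_virtual_base() == 2);
}
```

On the host, the call through HidingBase* reaches the base version, while the call through VirtualBase* reaches the derived one. Only in the second case does a base-class pointer lead to the subclass function, which is why the execution spaces of the two virtual functions have to line up.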