Search code examples
cuda operator-overloading thrust

Thrust reduction and overloaded operator-(const float3&, const float3&) won't compile


I overload operators to have a vector space over float3 (and similar structs) in vectorspace.cuh:

// Boilerplate vector space over data type Pt
#pragma once

#include <type_traits>


// float3
// Component-wise in-place addition: adds each component of `b` to `a`.
// Returns a copy of the updated `a` (by value, as before).
// Marked `inline` because this is a non-template function defined in a header:
// without it, including the header from more than one translation unit
// produces duplicate-symbol errors at link time.
__device__ __host__ inline float3 operator+=(float3& a, const float3& b) {
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    return a;
}

// In-place scaling: multiplies each component of `a` by the scalar `b`.
// Returns a copy of the updated `a` (by value, as before).
// Marked `inline` because this is a non-template function defined in a header:
// without it, including the header from more than one translation unit
// produces duplicate-symbol errors at link time.
__device__ __host__ inline float3 operator*=(float3& a, const float b) {
    a.x *= b;
    a.y *= b;
    a.z *= b;
    return a;
}

// float4
// Component-wise in-place addition: adds each component of `b` to `a`.
// Returns a copy of the updated `a` (by value, as before).
// Marked `inline` because this is a non-template function defined in a header:
// without it, including the header from more than one translation unit
// produces duplicate-symbol errors at link time.
__device__ __host__ inline float4 operator+=(float4& a, const float4& b) {
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
    return a;
}

// In-place scaling: multiplies each component of `a` by the scalar `b`.
// Returns a copy of the updated `a` (by value, as before).
// Marked `inline` because this is a non-template function defined in a header:
// without it, including the header from more than one translation unit
// produces duplicate-symbol errors at link time.
__device__ __host__ inline float4 operator*=(float4& a, const float b) {
    a.x *= b;
    a.y *= b;
    a.z *= b;
    a.w *= b;
    return a;
}


// Generalize += and *= to +, -=, -, *, /= and /
// Binary addition, built on the operand type's own operator+=.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator+(const Pt& a, const Pt& b) {
    Pt result = a;   // work on a copy so neither operand is modified
    result += b;
    return result;
}

// In-place subtraction, expressed as adding the negated right-hand side.
// Relies on a scalar-times-Pt operator* and on Pt's operator+=.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator-=(Pt& a, const Pt& b) {
    const Pt negated = -1*b;
    a += negated;
    return a;
}

// Binary subtraction, built on operator-= for the operand type.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator-(const Pt& a, const Pt& b) {
    Pt result = a;   // work on a copy so neither operand is modified
    result -= b;
    return result;
}

// Unary negation, expressed as multiplication by -1.
// Relies on a scalar-times-Pt operator* being available for Pt.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator-(const Pt& a) {
    const Pt negated = -1*a;
    return negated;
}

// Right scalar multiplication (Pt * float), built on Pt's operator*=.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator*(const Pt& a, const float b) {
    Pt scaled = a;   // work on a copy so the operand is not modified
    scaled *= b;
    return scaled;
}

// Left scalar multiplication (float * Pt), built on Pt's operator*=.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator*(const float b, const Pt& a) {
    Pt scaled = a;   // work on a copy so the operand is not modified
    scaled *= b;
    return scaled;
}

// In-place scalar division, expressed as multiplication by the reciprocal
// through Pt's operator*=.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator/=(Pt& a, const float b) {
    // Use a float literal: `1./b` promotes the division to double precision,
    // which is needlessly slow on the device when the operand is a float.
    a *= 1.0f / b;
    return a;
}

// Scalar division (Pt / float), built on operator/= for the operand type.
// Participates in overload resolution for every class or enum type Pt.
template<typename Pt> __device__ __host__
typename std::enable_if<std::is_class<Pt>::value || std::is_enum<Pt>::value, Pt>::type
operator/(const Pt& a, const float b) {
    Pt result = a;   // work on a copy so the operand is not modified
    result /= b;
    return result;
}

These overloads break compilation of thrust::reduce. Here is an example:

#include <cstdio>

#include <thrust/execution_policy.h>
#include <thrust/reduce.h>

#include "vectorspace.cuh"


// Minimal reproduction: allocates `n` float3 values on the device and sums
// them with thrust::reduce on the device execution policy.
// Returns 0 on success and 1 if a CUDA runtime call fails.
int main(int argc, char const *argv[]) {
    const int n = 10;
    const size_t bytes = n*sizeof(float3);

    float3* d_arr = nullptr;
    cudaError_t err = cudaMalloc(&d_arr, bytes);
    if (err != cudaSuccess) {
        std::fprintf(stderr, "cudaMalloc failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    // cudaMalloc leaves the buffer uninitialized; zero it so the reduction
    // reads defined values (all-zero bytes are 0.0f for every component).
    err = cudaMemset(d_arr, 0, bytes);
    if (err != cudaSuccess) {
        std::fprintf(stderr, "cudaMemset failed: %s\n", cudaGetErrorString(err));
        cudaFree(d_arr);
        return 1;
    }

    auto sum = thrust::reduce(thrust::device, d_arr, d_arr + n, float3 {0});
    (void)sum;  // the result is not used further; silence unused-variable warnings

    cudaFree(d_arr);  // release the device allocation before exiting
    return 0;
}

Using nvcc -std=c++11 -arch=sm_52 on Ubuntu 16.04 this results in 200+ lines of compiler errors:

$ nvcc -std=c++11 -arch=sm_52 sandbox/mean.cu 
sandbox/mean.cu(26): error: no operator "*" matches these operands
            operand types are: int * const thrust::zip_iterator<thrust::tuple<const float3 *, thrust::pointer<float3, thrust::system::cuda::detail::par_t, thrust::use_default, thrust::use_default>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>
          detected during:
            instantiation of "std::enable_if<<expression>, Pt>::type operator-=(Pt &, const Pt &) [with Pt=thrust::zip_iterator<thrust::tuple<const float3 *, thrust::pointer<float3, thrust::system::cuda::detail::par_t, thrust::use_default, thrust::use_default>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>]" 
(35): here
            instantiation of "std::enable_if<<expression>, Pt>::type operator-(const Pt &, const Pt &) [with Pt=thrust::zip_iterator<thrust::tuple<const float3 *, thrust::pointer<float3, thrust::system::cuda::detail::par_t, thrust::use_default, thrust::use_default>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>>]" 

...

How can I overload the operators without breaking thrust?


Solution

  • (Edited following OP's edits.)

    The problem is with the 'reach' of your operator overloads: you are not overloading only for the classes you're interested in, but for all classes that satisfy your enable_if condition - which is quite relaxed. That would be a serious bug even if the code did compile.

    More specifically, thrust uses arithmetic operations internally, e.g. on "zip iterators" (never mind what they are), and compilation of your operators for such iterators fails, understandably.

    So you must either:

    • Specify exactly which classes the overload is relevant to (e.g., using a disjunction of std::is_same in the enable_if), or
    • use a trait class:

      // Opt-in trait: reports false for every type unless a specialization
      // below says otherwise. Intended (per the surrounding answer) to select
      // which types get the generic arithmetic operators.
      template<typename T>
      struct needs_qivs_arithmetic_operators : std::false_type {};
      
      // Types that opt in.
      template<> struct needs_qivs_arithmetic_operators<float3> : std::true_type {};
      template<> struct needs_qivs_arithmetic_operators<float4> : std::true_type {};
      /* ... etc. You can also add specializations elsewhere in the translation unit. */