
Using type trait to ensure a type cannot be derived from itself


I would like to check statically that a class is derived from a base class but not from itself. Below is an example of what I am trying to achieve. Unfortunately, the code compiles (never thought I would say this).

I was hoping the second static assertion would kick in and detect that I tried to derive a class from itself. I would appreciate it if you could help me understand what I am doing wrong. I searched online without success.

#include <type_traits>

struct Base {};

template <typename T>
struct Derived : T {
  static_assert(std::is_base_of<Base, T>::value, "Type must derive from Base");
  static_assert(!(std::is_base_of<Derived, T>::value),
                "Type must not derive from Derived");
};

int main(int argc, char** argv) {
  Derived<Base> d_base;            // should be OK
  Derived<Derived<Base>> d_d_base; // should fail to compile

  return 0;
}

Solution

  • Type must not derive from Derived

    Derived is not a type by itself; it is a template. Inside the class template, however, the bare name Derived is the injected-class-name, so in std::is_base_of<Derived, T>::value it resolves to the current specialization, Derived<T>. That type can never be T, and it can never be a base of T either, because T is a base of it. If you have Derived<Derived<Base>>, then T is Derived<Base> while the bare Derived means Derived<Derived<Base>>, so the second assertion compares against the wrong type and always passes.

    You could add a type trait to check whether T is a specialization of Derived (or a class derived from one):

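    #include <type_traits>  // std::true_type, std::false_type
    #include <utility>      // std::declval

    // True if T is F<U...> for some U..., or is derived from such a type.
    template <template<class...> class F, class T>
    struct is_from_template {
        static std::false_type test(...);
    
        template <class... U>
        static std::true_type test(const F<U...>&);
    
        static constexpr bool value = decltype(test(std::declval<T>()))::value;
    };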
    

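    Now, using that trait prevents the type Derived<Derived<Base>>. When the bare name Derived is passed as an argument for a template template parameter, it refers to the class template itself rather than to the current specialization, which is what makes the check work inside the class: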

    struct Base {};
    
    template <typename T>
    struct Derived : T {
        static_assert(std::is_base_of<Base, T>::value,
                      "Type must derive from Base");
        static_assert(!is_from_template<Derived, T>::value,
                      "Type must not derive from Derived<>");
    };
    
    int main() {
        Derived<Base> d_base;                // OK
        // Derived<Derived<Base>> d_d_base;  // error
    }