
Making type trait work for all derived types


I have a type trait and a concept that check whether a std::variant can hold a given type T. Now I have a type variant2 that derives from std::variant, and I want to use the same trait with variant2. How can I accomplish this elegantly?

This doesn't work (Demo):

#include <variant>
#include <string>
#include <iostream>
#include <concepts>
#include <type_traits>

template<typename T, typename Variant>
struct variant_type;

template<typename T, typename... Args>
struct variant_type<T, std::variant<Args...>> 
    : public std::disjunction<std::is_same<T, Args>...> {};

template<typename T, typename Variant>
concept is_variant_type = variant_type<T, Variant>::value;

template <typename... Ts>
struct variant2 : public std::variant<Ts...> {
};

variant2<std::monostate, int, bool, std::string> var;

int main() {
    using T = std::string;

    if constexpr (is_variant_type<T, decltype(var)>) {
        std::cout << "Worked!" << std::endl;
    }
}

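"Worked!" never appears on the screen, which is expected: variant2<...> is a different type from std::variant<...>, and matching a partial specialization does not consider base classes, so only the primary template of variant_type is selected. Because that primary template is declared but never defined, variant_type<T, Variant>::value fails to substitute inside the concept, and is_variant_type is simply false (SFINAE'd).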


Solution

  • You can fix your specialization by additionally requiring that your template inherits from std::variant:

    #include <concepts>
    #include <iostream>
    #include <string>
    #include <type_traits>
    #include <variant>
    
    template <typename T, typename Variant>
    struct variant_type;
    
    template <typename T, template <typename...> typename Var, typename... Args>
    struct variant_type<T, Var<Args...>>
        : public std::conjunction<
              std::disjunction<std::is_same<T, Args>...>,
              std::is_base_of<std::variant<Args...>, Var<Args...>>> {};
    
    template <typename T, typename Variant>
    concept is_variant_type = variant_type<T, Variant>::value;
    
    template <typename... Ts>
    struct variant2 : public std::variant<Ts...> {};
    
    variant2<std::monostate, int, bool, std::string> var;
    
    int main() {
        using T = std::string;
    
        if constexpr (is_variant_type<T, decltype(var)>) {
            std::cout << "Worked!" << std::endl;
        }
    }
    

    I just added a conjunction to achieve that, and we're done.
    [EDIT] Sorry, it's a bit more complicated than that: I also added a template template parameter, which implies that your derived class must have exactly the same template parameters as the base one. That might be restrictive.