Tags: c++, templates, operator-overloading, c++17, operator-keyword

How to overload `operator==` in a `std::variant` wrapper class to support both Setting vs Setting and T vs T comparisons?


I'm trying to write a templated operator== for a wrapper class around a std::variant. The idea is that the Setting class is comparable with other Setting objects, as well as with the types supported by the variant. I have already solved this problem without templates, since it's easy to just write out each operator== by hand, but it's important for me to learn the templated way.

So, this is how I'd like the Setting to be used:

Setting s1("string");
Setting s2("string");
s1 == s2; // okay, equals true

As well as

Setting s3("string");
s3 == "string"; // should equal true
std::string s4 = "string";
s3 == s4; // also true

Here's what I've got so far, though I'm pretty sure I'm a long way off. The strategy is to template operator== such that, if the template argument T is one of the variant's (setting_t) alternative types, the value is extracted from the variant as a T and compared (T vs T). Alternatively, when T is another Setting, the setting_ member variables can be compared directly (Setting vs Setting).

#include <string>
#include <type_traits>
#include <utility>
#include <variant>

using setting_t = std::variant<std::string, int, double>;

/**
 * Utility which is true when
 * type T is in a variant, false otherwise.
 * For instance,
 *  std::string x("a String");
 *  bool truth = isValidVariantType<decltype(x), setting_t>(); // true
 *  
 *  unsigned long x = 4;
 *  bool truth = isValidVariantType<decltype(x), setting_t>(); // false
 */
template<typename T, typename ALL_T>
struct isValidVariantType;

template<typename T, typename... ALL_T>
struct isValidVariantType<T, std::variant<ALL_T...>>
        : public std::disjunction<std::is_same<T, ALL_T>...> {
};

class Setting {
public:
    explicit Setting(setting_t setting)
            : setting_(std::move(setting)) {}

    template <typename T,
            class = typename std::enable_if<isValidVariantType<T, setting_t>::value>::type>
    bool operator==(const T& setting){
        T val = std::get<T>(setting);
        return val == setting;
    }

private:
    setting_t setting_;
};
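The examples in the trait's doc comment can be checked at compile time with static_assert. Below is a small self-contained sketch that repeats the question's definitions; the last check is an editorial addition showing why s3 == "string" cannot pass the enable_if condition as written, since a string literal is an array type rather than std::string.

```cpp
#include <string>
#include <type_traits>
#include <variant>

using setting_t = std::variant<std::string, int, double>;

// The question's trait, unchanged.
template<typename T, typename ALL_T>
struct isValidVariantType;

template<typename T, typename... ALL_T>
struct isValidVariantType<T, std::variant<ALL_T...>>
        : public std::disjunction<std::is_same<T, ALL_T>...> {
};

// The two examples from the doc comment, checked at compile time.
static_assert(isValidVariantType<std::string, setting_t>::value);
static_assert(!isValidVariantType<unsigned long, setting_t>::value);

// Exact matches only: a string literal has type const char[N], not
// std::string, so the enable_if condition rejects it.
static_assert(!isValidVariantType<std::remove_reference_t<decltype("string")>, setting_t>::value);
```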

I've spent a good number of hours on this now so I'd appreciate any advice you can give me. Thanks, in advance!

edit - compiler errors

As requested, here is what the compiler currently generates

When I build SettingTests.SettingVsSetting

TEST(SettingTests, SettingVsSetting){
    Setting setting1("a String");
    Setting setting2("a String");
    bool truth = setting1 == setting2;
}

generates the following compiler messages:

/home/ciaran/SettingTests/SRC/TemplateTutorialTests.cpp: In member function ‘virtual void SettingTests_SettingVsSetting_Test::TestBody()’:
/home/ciaran/SettingTests/SRC/TemplateTutorialTests.cpp:11:27: error: no match for ‘operator==’ (operand types are ‘Setting’ and ‘Setting’)
   11 |     bool truth = setting1 == setting2;
      |                  ~~~~~~~~ ^~ ~~~~~~~~
      |                  |           |
      |                  Setting     Setting
In file included from /home/ciaran/SettingTests/SRC/TemplateTutorialTests.cpp:2:
/home/ciaran/SettingTests/SRC/TermplateTutorial.hpp:32:10: note: candidate: ‘template<class T, class> bool Setting::operator==(const T&)’
   32 |     bool operator==(const T& setting){
      |          ^~~~~~~~
/home/ciaran/SettingTests/SRC/TermplateTutorial.hpp:32:10: note:   template argument deduction/substitution failed:
/home/ciaran/SettingTests/SRC/TermplateTutorial.hpp:31:13: error: no type named ‘type’ in ‘struct std::enable_if<false, void>’
   31 |             class = typename std::enable_if<isValidVariantType<T, setting_t>::value>::type>
      |             ^~~~~

whilst SettingTests.SettingVsString

TEST(SettingTests, SettingVsString){
    Setting setting1("a String");
    std::string setting2("a String");
    bool truth = setting1 == setting2;
}

generates

/home/ciaran/SettingTests/SRC/TermplateTutorial.hpp: In instantiation of ‘bool Setting::operator==(const T&) [with T = std::__cxx11::basic_string<char>; <template-parameter-1-2> = void]’:
/home/ciaran/SettingTests/SRC/TemplateTutorialTests.cpp:17:30:   required from here
/home/ciaran/SettingTests/SRC/TermplateTutorial.hpp:33:28: error: no matching function for call to ‘get<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >(const std::__cxx11::basic_string<char>&)’
   33 |         T val = std::get<T>(setting);
      |                 ~~~~~~~~~~~^~~~~~~~~
In file included from /usr/include/c++/10/bits/unique_ptr.h:36,
                 from /usr/include/c++/10/memory:83,
                 from /home/ciaran/SettingTests/googletest/googletest/include/gtest/gtest.h:57,
                 from /home/ciaran/SettingTests/SRC/TemplateTutorialTests.cpp:1:
/usr/include/c++/10/utility:223:5: note: candidate: ‘template<long unsigned int _Int, class _Tp1, class _Tp2> constexpr typename std::tuple_element<_Int, std::pair<_Tp1, _Tp2> >::type& std::get(std::pair<_Tp1, _Tp2>&)’
  223 |     get(std::pair<_Tp1, _Tp2>& __in) noexcept
      |     ^~~
/usr/include/c++/10/utility:223:5: note:   template argument deduction/substitution failed:
/usr/include/c++/10/utility:228:5: note: candidate: ‘template<long unsigned int _Int, class _Tp1, class _Tp2> constexpr typename std::tuple_element<_Int, std::pair<_Tp1, _Tp2> >::type&& std::get(std::pair<_Tp1, _Tp2>&&)’
  228 |     get(std::pair<_Tp1, _Tp2>&& __in) noexcept
      |     ^~~

... (it goes on like this for a while)

Edit 3 - alternative operator==

    template<typename T,
            class = typename std::enable_if<isValidVariantType<T, setting_t>::value>::type>
    bool operator==(const T &setting) {
        if (auto val = std::get_if<T>(&setting_)){
            return *val == setting;
        }
        return false;
    }

Solution

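  • You have a typo: you are calling std::get on the setting argument instead of this->setting_. Fixing that makes the Setting == std::string comparison compile. Setting == Setting still needs an overload of its own, because Setting is not one of the variant's alternatives, and std::get throws std::bad_variant_access whenever the variant holds a different alternative.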
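For reference, here is a minimal C++17 sketch of the question's class with that fix applied. It is an editorial sketch rather than the answerer's code: it uses std::get_if, as in the question's Edit 3, so a mismatched alternative compares unequal instead of throwing, and it adds a plain non-template overload for Setting vs Setting, which std::variant's own operator== reduces to one line. The function name setting_examples is hypothetical. Note that s3 == "string" is still rejected, because T is deduced as an array-of-char type that isValidVariantType does not accept.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

using setting_t = std::variant<std::string, int, double>;

// The question's trait: true when T is exactly one of the variant's alternatives.
template <typename T, typename ALL_T>
struct isValidVariantType;

template <typename T, typename... ALL_T>
struct isValidVariantType<T, std::variant<ALL_T...>>
    : std::disjunction<std::is_same<T, ALL_T>...> {};

class Setting {
public:
    explicit Setting(setting_t setting) : setting_(std::move(setting)) {}

    // Setting vs T: read from the member setting_, not from the argument.
    // get_if returns nullptr when another alternative is active, so a
    // mismatched alternative compares unequal instead of throwing.
    template <typename T,
              typename = std::enable_if_t<isValidVariantType<T, setting_t>::value>>
    bool operator==(const T& other) const {
        if (const T* val = std::get_if<T>(&setting_)) {
            return *val == other;
        }
        return false;
    }

    // Setting vs Setting: std::variant's operator== checks that both hold
    // the same alternative and then compares the values.
    bool operator==(const Setting& other) const {
        return setting_ == other.setting_;
    }

private:
    setting_t setting_;
};

// Usage, mirroring the examples in the question.
void setting_examples() {
    Setting s1(std::string("string"));
    Setting s2(std::string("string"));
    assert(s1 == s2);
    assert(s1 == std::string("string"));
    assert(!(s1 == 42));  // s1 holds a std::string, not an int
}
```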

    But you can do better.

    #include <string>
    #include <utility>
    #include <variant>

    template<class D, class T>
    class SettingEqual {
        D const& self() const { return *static_cast<D const*>(this); }
        D & self() { return *static_cast<D*>(this); }
        // Parenthesized so decltype(auto) deduces a reference, not a copy.
        decltype(auto) setting() { return (self().setting_); }
        decltype(auto) setting() const { return (self().setting_); }
    
        friend bool operator==( SettingEqual const& self, T const& t ) {
          if (!std::holds_alternative<T>(self.setting())) return false;
          return std::get<T>(self.setting()) == t;
        }
        friend bool operator==( T const& t, SettingEqual const& self ) {
          return (self==t);
        }
        friend bool operator!=( T const& t, SettingEqual const& self ) {
          return !(self==t);
        }
        friend bool operator!=( SettingEqual const& self, T const& t ) {
          return !(self==t);
        }
    };
    template<class...Ts>
    class SettingT:
        public SettingEqual<SettingT<Ts...>, Ts>...
    {
        template<class D, class T>
        friend class SettingEqual;
    public:
        explicit SettingT(std::variant<Ts...> setting)
                : setting_(std::move(setting)) {}
    
    private:
        std::variant<Ts...> setting_;
    };
    using Setting = SettingT<std::string, int, double>;
    

    This introduces == and != overloads, with the Setting on either the left or the right, that participate in ordinary overload resolution.

    It also checks whether the variant currently holds a T, and treats a mismatched alternative as not equal instead of throwing std::bad_variant_access.

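    The techniques used are the CRTP (curiously recurring template pattern), where the implementation lives in a parent class that downcasts to the derived class with static_cast, and Koenig (ADL) friend operators, which inject non-template operator== overloads into the lookup for Setting == something so they participate in overload resolution. Because those operators are not templates, the usual implicit conversions apply to their arguments, which is what lets a string literal match the std::string overload.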
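The hidden-friend part is easy to try in isolation. Here is a minimal sketch with a hypothetical Box wrapper (box_examples is likewise a made-up name): the comparison is a non-template friend defined inside the class template, so Box<std::string> == "abc" works through the ordinary const char* to std::string conversion.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical wrapper used only to illustrate the technique.
template <class T>
class Box {
public:
    explicit Box(T value) : value_(std::move(value)) {}

    // Hidden friend: a non-template function defined inside the class.
    // It is found only by ADL when a Box<T> is an operand, and since it is
    // not a template, its T const& parameter accepts implicit conversions.
    friend bool operator==(Box const& box, T const& t) { return box.value_ == t; }

private:
    T value_;
};

void box_examples() {
    Box<std::string> box(std::string("abc"));
    assert(box == "abc");                // "abc" converts to std::string
    assert(!(box == std::string("x")));
}
```

Had the operator instead been a free function template such as template<class U> bool operator==(Box<U> const&, U const&), the same comparison would fail to compile: template argument deduction ignores conversions, so U would be deduced as std::string from the first argument and as an array of char from the second.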


    Now this isn't bad, but it has a problem: it converts the argument to T first. So SettingT<std::string> == "hello" goes off and creates a std::string, copies "hello" into it, and then compares the std::string stored in the SettingT with it.

    Really, we want to dispatch "hello" directly to std::string's operator==, or do something better.
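The cost described here can be made visible with a small experiment. This sketch is not part of the answer: Counted is a hypothetical stand-in for std::string that counts its constructions, compare_converted mimics the SettingEqual approach (the operand is bound to a T const& parameter), and compare_forwarded mimics forwarding the operand with its own type. The counted_examples function is also a made-up name. Only the converting path builds a temporary.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical string-like type that counts its constructions, standing in
// for std::string so the cost of the conversion is observable.
struct Counted {
    static inline int constructions = 0;
    const char* text;

    Counted(const char* t) : text(t) { ++constructions; }  // implicit on purpose

    friend bool operator==(Counted const& a, Counted const& b) {
        return std::strcmp(a.text, b.text) == 0;
    }
    // Cheaper overload that compares against a C string directly.
    friend bool operator==(Counted const& a, const char* b) {
        return std::strcmp(a.text, b) == 0;
    }
};

// Conversion first, like SettingEqual: "hello" must become a Counted
// temporary before == runs.
bool compare_converted(Counted const& stored, Counted const& rhs) {
    return stored == rhs;
}

// Forwarding, like the best_conversion version: the operand keeps its own
// type, and overload resolution picks the const char* overload.
template <class U>
bool compare_forwarded(Counted const& stored, U const& rhs) {
    return stored == rhs;
}

void counted_examples() {
    Counted stored("hello");

    Counted::constructions = 0;
    assert(compare_converted(stored, "hello"));
    assert(Counted::constructions == 1);  // one temporary Counted

    Counted::constructions = 0;
    assert(compare_forwarded(stored, "hello"));
    assert(Counted::constructions == 0);  // no temporary
}
```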

    #include <string>
    #include <type_traits>
    #include <utility>
    #include <variant>

    template<class T>
    struct tag_t {using type=T;};
    template<class T>
    constexpr tag_t<T> tag{};
    
    template<class T>
    struct overload_detect {
      auto operator()(T const&){return tag<T>;};
    };
    template<class...Ts>
    struct overload_detector:overload_detect<Ts>... {
      using overload_detect<Ts>::operator()...;
    };
    template<class T0, class...Ts>
    using best_conversion = typename decltype(overload_detector<Ts...>{}( std::declval<T0 const&>() ))::type;
    template<class T, class...Ts>
    concept any_conversion = requires (T a) {
        { (std::void_t<best_conversion<T, Ts...>>)(0) };
    };
    
    template<class...Ts>
    class SettingT
    {
    public:
        explicit SettingT(std::variant<Ts...> setting)
                : setting_(std::move(setting)) {}
        template<any_conversion<Ts...> U>
        friend bool operator==(SettingT const& self, U const& u) {
            using T = best_conversion<U, Ts...>;
            if (!std::holds_alternative<T>(self.setting_)) return false;
            return std::get<T>(self.setting_) == u;
        }
        template<any_conversion<Ts...> U>
        friend bool operator==(U const& u,SettingT const& self) {
            return self==u;
        }
        template<any_conversion<Ts...> U>
        friend bool operator!=(U const& u,SettingT const& self) {
            return !(self==u);
        }
        template<any_conversion<Ts...> U>
        friend bool operator!=(SettingT const& self, U const& u) {
            return !(self==u);
        }
    private:
        std::variant<Ts...> setting_;
    };
    

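    Now best_conversion<X, Ys...> uses ordinary overload resolution to find the best conversion from X among the Ys and returns that type, and we dispatch to its == without doing the conversion first. (This version uses C++20 concepts, so it needs a C++20 compiler even though the question is tagged C++17.)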

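    That == is free to perform the conversion itself, or to do something more efficient; for example, std::string has an operator== taking a const char*, which compares without building a temporary std::string.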
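If you need to stay on C++17, as the question's tags suggest, the same helpers work without concepts, with a SFINAE detection trait standing in for any_conversion. In this sketch, has_best_conversion and has_best_conversion_impl are hypothetical names, and the static_asserts record what overload resolution picks for a few argument types:

```cpp
#include <string>
#include <type_traits>
#include <utility>

template <class T>
struct tag_t { using type = T; };

// One call operator per candidate type; overload resolution picks the best.
template <class T>
struct overload_detect {
    tag_t<T> operator()(T const&) const { return {}; }
};

template <class... Ts>
struct overload_detector : overload_detect<Ts>... {
    using overload_detect<Ts>::operator()...;
};

// The type among Ts that U converts to best. A hard error if no candidate is
// viable or the call is ambiguous, except in a SFINAE context like the one below.
template <class U, class... Ts>
using best_conversion =
    typename decltype(overload_detector<Ts...>{}(std::declval<U const&>()))::type;

// Hypothetical C++17 replacement for the any_conversion concept.
template <class Void, class U, class... Ts>
struct has_best_conversion_impl : std::false_type {};

template <class U, class... Ts>
struct has_best_conversion_impl<std::void_t<best_conversion<U, Ts...>>, U, Ts...>
    : std::true_type {};

template <class U, class... Ts>
constexpr bool has_best_conversion = has_best_conversion_impl<void, U, Ts...>::value;

// "hello" has type const char[6]; it converts best to std::string.
static_assert(std::is_same_v<best_conversion<const char[6], std::string, int, double>, std::string>);
static_assert(std::is_same_v<best_conversion<int, std::string, int, double>, int>);
// float -> double is a promotion, which beats the float -> int conversion.
static_assert(std::is_same_v<best_conversion<float, std::string, int, double>, double>);
// long -> int and long -> double rank equally, so the call is ambiguous.
static_assert(!has_best_conversion<long, std::string, int, double>);
```

A long argument is rejected because both candidate conversions rank equally, so has_best_conversion reports false; used as an enable_if condition, it would remove the operators from overload resolution rather than pick one arbitrarily.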
