Tags: c++, operator-overloading, c++17, stdmap, std-variant

Provide an operator== for std::variant


I am trying to create an operator== for a std::variant stored as the value type of a std::map, like this:

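#include <iostream>
#include <map>
#include <string>
#include <variant>
using namespace std::string_literals; // for "linear"s

struct anyType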
{
   template<typename T>
   void operator()(T t) const { std::cout << t; }
   void operator()(const std::string& s) const { std::cout << '\"' << s << '\"'; }
};

template<typename T>
bool operator==(const std::variant<float, int, bool, std::string>& v, const& T t) 
{
   return v == t;
}

int main()
{
   std::map<std::string, std::variant<float, int, bool, std::string>> kwargs;
   kwargs["interface"] = "linear"s;
   kwargs["flag"] = true;
   kwargs["height"] = 5;
   kwargs["length"] = 6;
   //test 
   if (kwargs["interface"] == "linear") // stack overflow error here
   { 
      std::cout << true << '\n';
   }
   else
   {
      std::cout << false << '\n';
   }
}

Can someone tell me why my operator isn't working?


Solution

  • You have a couple of issues in your code:

    • const& T t in your operator== should be T const& t or const T& t.

    • In your if statement you have not said that you want to compare with a std::string rather than with a char array (i.e. "linear"). You need one of the following:

      if (kwargs["interface"] == std::string{ "linear" })
      // or 
      using namespace std::string_literals;
      if (kwargs["interface"] == "linear"s)  // since C++14
      
    • When you do the comparison like this:

      if (kwargs["interface"] == "linear") // checking std::variant == char [7] 
      

      you are comparing the std::variant<float, int, bool, std::string> (i.e. v) with a value of type char[7] (i.e. the type of "linear"). Inside your operator==, you then do exactly the same thing again:

      return v == t; // checking std::variant == char [7] 
      

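      std::variant's own operator== only compares two variants, so the only viable candidate for this comparison is your templated operator== itself. It therefore calls itself recursively, without end, until the stack overflows.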
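A small compile-time check makes the deduction visible. deduced_as_char_array is a hypothetical helper written only for this illustration: it takes its argument the same way as the corrected operator== (T const& t) and reports whether T was deduced as char[7].

```cpp
#include <type_traits>

// Hypothetical helper: same parameter form as the operator== above (T const& t).
// It only reports which type T was deduced as.
template <typename T>
constexpr bool deduced_as_char_array(T const&)
{
   return std::is_same_v<T, char[7]>;
}

// "linear" is 6 characters plus the terminating '\0', so T is deduced as char[7].
// Deduction does not apply user-defined conversions such as const char* -> std::string.
static_assert(deduced_as_char_array("linear"));
```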


    To fix this, you need to explicitly extract the value from the variant, either by index or by type. For example, by checking the type with std::is_same and if constexpr:


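    #include <string>
    #include <type_traits> // std::is_same_v
    #include <variant>
    
    // Helper so the static_assert below only fires when that branch is instantiated
    template<typename> inline constexpr bool always_false = false;
    
    template<typename T>
    bool operator==(const std::variant<float, int, bool, std::string>& v, T const& t)
    {
       // std::get<X>(v) throws std::bad_variant_access if v does not currently hold an X
       if constexpr (std::is_same_v<T, float>)            // float
          return std::get<float>(v) == t;
       else if constexpr (std::is_same_v<T, int>)         // int
          return std::get<int>(v) == t;
       else if constexpr (std::is_same_v<T, bool>)       // boolean
          return std::get<bool>(v) == t;
       else if constexpr (std::is_same_v<T, std::string>) // std::string
          return std::get<std::string>(v) == t;
       else // e.g. T = char[7]: reject at compile time instead of falling off the end
          static_assert(always_false<T>, "T is not one of the variant's alternatives");
    }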
    

    or, more simply (credit: @Barry):

    template<typename T>
    bool operator==(const std::variant<float, int, bool, std::string>& v, T const& t)
    {
       return std::get<T>(v) == t;
    }
    

    Now, if you pass a type other than the ones v can hold, you get a compile-time error from the templated operator==.
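One runtime caveat remains: std::get<T> throws std::bad_variant_access when v currently holds a different alternative, so with the question's map, kwargs["height"] == 5.0f throws instead of returning false (height holds an int). Below is a sketch that swaps in std::get_if, assuming a mismatch should simply compare unequal; the alias Value is introduced here for brevity and is not part of the original code.

```cpp
#include <string>
#include <variant>

// Assumed alias for the variant type used as the map's value type in the question.
using Value = std::variant<float, int, bool, std::string>;

// Returns false instead of throwing when v holds a different alternative than T.
// std::get_if<T> still fails to compile if T is not one of Value's alternatives,
// so comparing against "linear" (a char[7]) remains a compile-time error.
template<typename T>
bool operator==(const Value& v, T const& t)
{
   const T* p = std::get_if<T>(&v); // nullptr unless v currently holds a T
   return p != nullptr && *p == t;
}
```

With this version, kwargs["interface"] == "linear"s and kwargs["height"] == 5 both return true, while kwargs["height"] == 5.0f returns false.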


    Generic Solution!

    For a generic std::variant<Types...>, one can do the following. In addition, it is SFINAE-constrained so that it only participates in overload resolution when T is one of the variant's alternatives (Types...). I have used the is_one_of trait from this post.


    #include <variant>
    #include <type_traits>
    
    // A trait to check that T is one of 'Types...'
    template <typename T, typename...Types>
    struct is_one_of final : std::disjunction<std::is_same<T, Types>...> {};
    
    template<typename... Types, typename T>
    auto operator==(const std::variant<Types...>& v, T const& t) // not noexcept: std::get can throw std::bad_variant_access
       -> std::enable_if_t<is_one_of<T, Types...>::value, bool>
    {
       return std::get<T>(v) == t;
    }
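The generic operator above also relies on std::get, which throws std::bad_variant_access if v holds a different alternative than T. The sketch below assumes a mismatch should compare unequal: it keeps the same is_one_of constraint but uses std::get_if instead. It also adds an operator!=, since C++17 does not derive a != b from operator== (C++20's rewritten comparison candidates do). is_one_of is repeated so the block compiles on its own.

```cpp
#include <type_traits>
#include <variant>

// Same trait as above: true if T is one of Types...
template <typename T, typename... Types>
struct is_one_of final : std::disjunction<std::is_same<T, Types>...> {};

// Participates only when T is an alternative of the variant
// (std::get_if<T> additionally requires T to appear exactly once in Types...).
// Returns false instead of throwing if v currently holds a different alternative.
template <typename... Types, typename T>
auto operator==(const std::variant<Types...>& v, T const& t)
   -> std::enable_if_t<is_one_of<T, Types...>::value, bool>
{
   const T* p = std::get_if<T>(&v); // nullptr unless v holds a T
   return p != nullptr && *p == t;
}

// C++17 does not rewrite a != b in terms of operator==, so provide it explicitly.
template <typename... Types, typename T>
auto operator!=(const std::variant<Types...>& v, T const& t)
   -> std::enable_if_t<is_one_of<T, Types...>::value, bool>
{
   return !(v == t);
}
```

For example, with std::variant<int, double> v = 42;, v == 42 is true, while v == 42.0 is false because v holds an int, not a double.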