
Finding a C++ object in a set by object comparison instead of using functors


I want to populate a std::set of GraphNode objects and check whether another GraphNode with the same value exists in the set. In Java, objects can be compared by overriding equals and implementing compareTo, instead of creating a separate functor object. I implemented operator== for GraphNode and expected to find the object in the set like this:

std::find(nodesSet->begin(), nodesSet->end(), new GraphNode<T>(1)) != nodesSet->end()

But the breakpoints in my operator== and operator() functions are never hit. Why? Is there a way to find the object by comparing objects?

#include <algorithm>
#include <iostream>
#include <set>
#include <vector>

template<class T> class Graph;  // forward declaration needed for the friend below

template<class T>
class GraphNode
{
    friend class Graph<T>;

    private:
        T t;
        std::vector<GraphNode<T>*> adjNodes;

    public:
        explicit GraphNode(const T& value) : t(value) { }

        // node == value
        bool operator==(const T& value) const
        {
            return t == value;
        }

        // node == node, compared by the stored value
        friend bool operator==(const GraphNode& node1, const GraphNode& node2)
        {
            return node1.t == node2.t;
        }
};

void populate()
{
     std::set<GraphNode<int>*>* nodesSet = new std::set<GraphNode<int>*>;
     nodesSet->insert(new GraphNode<int>(1));
     nodesSet->insert(new GraphNode<int>(2));
     if ( std::find( nodesSet->begin(), nodesSet->end(),
                     new GraphNode<int>(1) ) != nodesSet->end() )
     {
         std::cout << "found value";
     }
}

Solution

  • As aschepler pointed out, the problem with your code is that you end up comparing pointers, not objects. When called without a predicate, std::find uses operator== to compare the value you pass against what the iterators yield when dereferenced. In your case, nodesSet points to a std::set<GraphNode<T>*>, so the type of *nodesSet->begin() is GraphNode<T>*, not GraphNode<T> (note the missing star). Comparing two GraphNode<T>* values compares addresses, and the node created by new GraphNode<T>(1) has a different address from every node already in the set, so your operator== is never called. To use the operator== defined for GraphNode, your set has to be a std::set<GraphNode<T>>, i.e. a set of objects of your type rather than of pointers. GraphNode then also needs an operator< (or the set needs a custom comparator), because std::set orders its elements with std::less.
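
    A minimal sketch of the difference, assuming a simplified class Node as a hypothetical stand-in for GraphNode<int>. With a set of objects, std::find ends up calling Node::operator==. With a set of pointers, it compares addresses, so a freshly created node with the same value is not found.

    ```cpp
    #include <algorithm>
    #include <iostream>
    #include <set>

    // Hypothetical stand-in for GraphNode<int>; only the stored value matters here.
    class Node {
      int value;
    public:
      explicit Node(int v) : value(v) { }
      // Called by std::find when the set holds Node objects.
      bool operator==(const Node& o) const { return value == o.value; }
      // Called through std::less<Node>, which std::set<Node> uses for ordering.
      bool operator<(const Node& o) const { return value < o.value; }
    };

    int main()
    {
      // Set of objects: *it is a Node, so == compares stored values.
      std::set<Node> objects;
      objects.insert(Node(1));
      objects.insert(Node(2));
      bool foundObject =
          std::find(objects.begin(), objects.end(), Node(1)) != objects.end();

      // Set of pointers: *it is a Node*, so == compares addresses.
      Node one(1), two(2);
      std::set<Node*> pointers;
      pointers.insert(&one);
      pointers.insert(&two);
      Node probe(1);  // same value as `one`, different address
      bool foundPointer =
          std::find(pointers.begin(), pointers.end(), &probe) != pointers.end();

      std::cout << std::boolalpha << foundObject << ' ' << foundPointer << '\n';  // prints "true false"
      return 0;
    }
    ```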

    If you have to store pointers in your set (e.g. because you don't want to copy the objects), you can write a wrapper around the pointer that forwards the comparison operators to the pointed-to object. Here's an example:

    #include <iostream>
    #include <set>
    #include <algorithm>
    
    class obj {
      int i;
    public:
      obj(int i): i(i) { }
      bool operator<(const obj& o) const { return i < o.i; }
      bool operator==(const obj& o) const { return i == o.i; }
      int get() const { return i; }
    };
    
    // Wraps a pointer and forwards < and == to the pointed-to objects.
    template <typename T>
    class ptr_cmp {
      T* p;
    public:
      ptr_cmp(T* p): p(p) { }
      // used via std::less, i.e. by std::set to order its elements
      bool operator<(const ptr_cmp& o) const { return *p < *o.p; }
      // used by std::find
      bool operator==(const ptr_cmp& o) const { return *p == *o.p; }
      T& operator*() const { return *p; }
      T* operator->() const { return p; }
    };
    
    int main(int argc, char* argv[])
    {
      obj five(5), seven(7);
    
      std::set<ptr_cmp<obj>> s;
      s.insert(&five);    // obj* converts implicitly to ptr_cmp<obj>
      s.insert(&seven);
    
      obj x(7);
    
      auto it = std::find(s.begin(), s.end(), ptr_cmp<obj>(&x));
      if (it != s.end())  // never dereference end()
        std::cout << (*it)->get() << std::endl;   // prints 7
    
      return 0;
    }
    

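    Note that the wrapper needs both operator== and operator< (this is what I ran into with gcc 6.2.0). std::find itself only uses operator==, but std::set<ptr_cmp<obj>> orders its elements with std::less<ptr_cmp<obj>>, which calls operator<.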
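
    Because the wrapper orders the set by the pointed-to values, the member function std::set::find can locate x as well, in logarithmic rather than linear time. It does not use operator== at all: two elements count as equivalent when neither compares less than the other. The sketch below is self-contained and assumes a trimmed copy of obj and ptr_cmp that keeps only operator<.

    ```cpp
    #include <iostream>
    #include <set>

    class obj {
      int i;
    public:
      obj(int i): i(i) { }
      bool operator<(const obj& o) const { return i < o.i; }
      int get() const { return i; }
    };

    // Trimmed pointer wrapper: orders by the pointed-to objects only.
    template <typename T>
    class ptr_cmp {
      T* p;
    public:
      ptr_cmp(T* p): p(p) { }
      bool operator<(const ptr_cmp& o) const { return *p < *o.p; }
      T* operator->() const { return p; }
    };

    int main()
    {
      obj five(5), seven(7);
      std::set<ptr_cmp<obj>> s;
      s.insert(&five);
      s.insert(&seven);

      obj x(7);
      // set::find treats a and b as equivalent when !(a < b) && !(b < a);
      // no operator== is involved, and the search is O(log n).
      auto it = s.find(ptr_cmp<obj>(&x));
      if (it != s.end())
        std::cout << (*it)->get() << '\n';  // prints 7
      return 0;
    }
    ```

    The returned iterator refers to the wrapper around seven, the element stored in the set, not to x.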

    What is wrong with using a predicate though? It is a more generalizable approach. Here's an example:

    #include <iostream>
    #include <set>
    #include <algorithm>
    
    class obj {
      int i;
    public:
      obj(int i): i(i) { }
      bool operator==(const obj& o) const { return i == o.i; }
      int get() const { return i; }
    };
    
    // Predicate: true when the pointed-to object equals *l.
    template <typename T>
    struct ptr_cmp {
      const T *l;
      ptr_cmp(const T* p): l(p) { }
      template <typename R>
      bool operator()(const R* r) const { return *l == *r; }
    };
    // Helper that deduces T from its argument.
    template <typename T>
    ptr_cmp<T> make_ptr_cmp(const T* p) { return ptr_cmp<T>(p); }
    
    int main(int argc, char* argv[])
    {
      obj five(5), seven(7);
    
      std::set<obj*> s;   // ordered by address; only find_if compares values
      s.insert(&five);
      s.insert(&seven);
    
      obj x(7);
    
      auto it = std::find_if(s.begin(), s.end(), make_ptr_cmp(&x));
      if (it != s.end())  // never dereference end()
        std::cout << (*it)->get() << std::endl;   // prints 7
    
      return 0;
    }
    

    Note that make_ptr_cmp deduces T from its argument, so you don't have to state the type explicitly, which makes it easier to write generic code.
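
    With C++17, class template argument deduction can deduce T from the constructor argument, so ptr_cmp(&x) works without the make_ptr_cmp helper. A sketch assuming a C++17 compiler and the same obj and ptr_cmp as above (with operator() made const):

    ```cpp
    #include <algorithm>
    #include <iostream>
    #include <set>

    class obj {
      int i;
    public:
      obj(int i): i(i) { }
      bool operator==(const obj& o) const { return i == o.i; }
      int get() const { return i; }
    };

    template <typename T>
    struct ptr_cmp {
      const T* l;
      ptr_cmp(const T* p): l(p) { }
      template <typename R>
      bool operator()(const R* r) const { return *l == *r; }
    };

    int main()
    {
      obj five(5), seven(7);
      std::set<obj*> s{&five, &seven};

      obj x(7);
      // C++17: ptr_cmp(&x) deduces ptr_cmp<obj> from the constructor parameter const T*.
      auto it = std::find_if(s.begin(), s.end(), ptr_cmp(&x));
      if (it != s.end())
        std::cout << (*it)->get() << '\n';  // prints 7
      return 0;
    }
    ```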

    If you can use C++11, you can simply use a lambda instead of ptr_cmp:

    std::find_if(s.begin(),s.end(), [&x](const obj* p){ return *p == x; } )
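
    A complete, self-contained version of the lambda approach, assuming the same obj class as in the predicate example. Unlike the one-liner, it compares the result with s.end() before dereferencing, since dereferencing end() is undefined behavior when nothing matches. The second search, for a hypothetical value y(9), shows that case.

    ```cpp
    #include <algorithm>
    #include <iostream>
    #include <set>

    class obj {
      int i;
    public:
      obj(int i): i(i) { }
      bool operator==(const obj& o) const { return i == o.i; }
      int get() const { return i; }
    };

    int main()
    {
      obj five(5), seven(7);
      std::set<obj*> s;
      s.insert(&five);
      s.insert(&seven);

      obj x(7), y(9);
      // Each lambda compares the pointed-to object, not the pointer.
      auto it = std::find_if(s.begin(), s.end(), [&x](const obj* p) { return *p == x; });
      auto none = std::find_if(s.begin(), s.end(), [&y](const obj* p) { return *p == y; });

      if (it != s.end())
        std::cout << "found " << (*it)->get() << '\n';  // prints "found 7"
      if (none == s.end())
        std::cout << "9 not found\n";                   // prints "9 not found"
      return 0;
    }
    ```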