
Execute a function on matching pairs in a map


I have some code that does roughly the following: given two maps, for every key that exists in both maps, multiply the two mapped values together, then sum all the products. For example:

std::map<int, double> s1 = {{1, 2.5}, {2, 10.0}, {3, 0.5}};
std::map<int, double> s2 = {{1, 10.0},           {3, 20.0}, {4, 3.33}};

The answer should be 2.5*10.0 + 0.5*20.0 = 35.0, the sum of the products of the values for the matching keys (1 and 3). My current implementation:

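#include <map>

// Sum of s1[k] * s2[k] over every key k present in both maps.
double calcProduct(std::map<int, double> const &s1, std::map<int, double> const &s2)
{
    auto s1_it = s1.begin();
    auto s2_it = s2.begin();

    double result = 0;
    while (s1_it != s1.end() && s2_it != s2.end())
    {
        if (s1_it->first == s2_it->first)
        {
            // Key present in both maps: accumulate the product of the values.
            result += s1_it->second * s2_it->second;
            s1_it++;
            s2_it++;
        }
        else if (s1_it->first < s2_it->first)
        {
            // Jump s1_it to the first key not less than s2's current key.
            s1_it = s1.lower_bound(s2_it->first);
        }
        else
        {
            // Jump s2_it to the first key not less than s1's current key.
            s2_it = s2.lower_bound(s1_it->first);
        }
    }
    return result;
}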

I would like to refactor this. std::set_intersection seems close to what I want, since the documentation has an example using std::back_inserter, but is there a way to make it work on maps and avoid the intermediate container?
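For reference, here is a minimal sketch of the std::set_intersection + std::back_inserter approach applied to two maps (the function name calcProductViaIntersection is hypothetical). Because std::set_intersection copies the matching elements from the first range only, the second map's value has to be looked up again afterwards, which is exactly the intermediate container and extra work the question wants to avoid:

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <map>
#include <utility>
#include <vector>

// Hypothetical name: sums s1[k] * s2[k] using std::set_intersection.
double calcProductViaIntersection(std::map<int, double> const &s1,
                                  std::map<int, double> const &s2)
{
    using Entry = std::pair<const int, double>;  // std::map<int, double>::value_type

    // Compare entries by key only, so entries with equal keys are "equivalent".
    auto byKey = [](Entry const &a, Entry const &b) { return a.first < b.first; };

    // The intermediate container: receives copies of the matching s1 entries.
    std::vector<std::pair<int, double>> common;
    std::set_intersection(s1.begin(), s1.end(), s2.begin(), s2.end(),
                          std::back_inserter(common), byKey);

    double result = 0;
    for (auto const &entry : common)
        result += entry.second * s2.at(entry.first);  // second lookup in s2 per match
    return result;
}
```

With the example maps above this also yields 35.0, but it allocates a vector and pays an O(log n) lookup for every match.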


Solution

  • The code you're using is already very close to the way std::set_intersection would be implemented: both walk the two sorted ranges in step. I can't see any advantage to copying the matches into a new container and iterating over it.

    However, there are a couple of things in your code worth mentioning.

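    If you're going to increment your iterators, don't declare the iterator variables themselves const (for example, with const auto s1_it = s1.begin(); the later s1_it++ would not compile). The const_iterator that begin() returns on a const map is fine to increment; it only prevents modifying the elements.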

    I would expect there to be more misses than hits when looking for equivalent keys, so I would suggest testing the two less-than cases first. Using only < also matches how std::set_intersection decides that two elements are equivalent:

    #include <map>

    double calcProduct( std::map<int, double> const &s1 , std::map<int, double> const &s2 )
    {
        auto s1_it = s1.begin();
        auto s2_it = s2.begin();

        double result = 0;
        while ( s1_it != s1.end() && s2_it != s2.end() )
        {
            if ( s1_it->first < s2_it->first )
            {
                // Miss: s1 is behind, jump to the first key >= s2's current key.
                s1_it = s1.lower_bound( s2_it->first );
            }
            else if ( s2_it->first < s1_it->first )
            {
                // Miss: s2 is behind, jump to the first key >= s1's current key.
                s2_it = s2.lower_bound( s1_it->first );
            }
            else
            {
                // Neither key is less than the other, so they are equivalent: a hit.
                result += s1_it->second * s2_it->second;
                s1_it++;
                s2_it++;
            }
        }
        return result;
    }
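If you want the loop to be reusable, in the spirit of the question title, one option is to factor the walk into a small algorithm that calls a function on each matching pair. This is a sketch under assumptions, not a standard facility: for_each_matching_pair is a hypothetical name, the comparator and callback are generic lambdas (C++14), and it advances one element at a time like std::set_intersection instead of jumping with lower_bound. No intermediate container is created.

```cpp
#include <cassert>
#include <map>

// Hypothetical helper: walks two ranges sorted by comp in step (the same shape
// as std::set_intersection) and calls f(a, b) for each pair of elements whose
// keys are equivalent under comp. Nothing is copied into a container.
template <class It1, class It2, class Comp, class Func>
Func for_each_matching_pair(It1 first1, It1 last1, It2 first2, It2 last2,
                            Comp comp, Func f)
{
    while (first1 != last1 && first2 != last2)
    {
        if (comp(*first1, *first2))
            ++first1;               // element only in the first range
        else if (comp(*first2, *first1))
            ++first2;               // element only in the second range
        else
        {
            f(*first1, *first2);    // equivalent keys: hand both elements to f
            ++first1;
            ++first2;
        }
    }
    return f;
}

// calcProduct expressed with the helper.
double calcProduct(std::map<int, double> const &s1, std::map<int, double> const &s2)
{
    double result = 0;
    for_each_matching_pair(
        s1.begin(), s1.end(), s2.begin(), s2.end(),
        [](auto const &a, auto const &b) { return a.first < b.first; },
        [&result](auto const &a, auto const &b) { result += a.second * b.second; });
    return result;
}
```

Whether single steps or lower_bound jumps are faster depends on the data: incrementing a std::map iterator is amortized O(1), while each lower_bound call searches the tree from the root in O(log n), so the jumps tend to pay off only when long runs of non-matching keys are skipped.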