Equivalent of Python's list sort with key / Schwartzian transform


In Python, given a list, I can sort it by a key function, e.g.:

>>> def get_value(k):
...     print "heavy computation for", k
...     return {"a": 100, "b": 30, "c": 50, "d": 0}[k]
...
>>> items = ['a', 'b', 'c', 'd']
>>> items.sort(key=get_value)
heavy computation for a
heavy computation for b
heavy computation for c
heavy computation for d
>>> items
['d', 'b', 'c', 'a']

As you see, the list was sorted not alphanumerically but by the return value of get_value().

Is there an equivalent in C++? std::sort() only lets me provide a custom comparator (the equivalent of Python's items.sort(cmp=...)), not a key function. If there is none, is there a well-tested, efficient, publicly available implementation that I can drop into my code?

Note that the Python version only calls the key function once per element, not twice per comparison.


Solution

  • You could just roll your own:

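    #include <algorithm>
    #include <iterator>

    template <typename RandomIt, typename KeyFunc>
    void sort_by_key(RandomIt first, RandomIt last, KeyFunc func)
    {
        // Element type of the range (the original alias was named Value but used as ValueType).
        using ValueType = typename std::iterator_traits<RandomIt>::value_type;
        // Compare elements by their keys.
        std::sort(first, last, [=](const ValueType& a, const ValueType& b) {
            return func(a) < func(b);
        });
    }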
    

    Note that this calls func twice per comparison. If KeyFunc is too expensive for that, you'll have to compute each key once and store the keys in a separate vector.
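
A minimal sketch of that separate-vector approach (decorate, sort, undecorate), assuming the elements are cheap to move and the keys support operator<. The names sort_by_cached_key and count_key_calls are hypothetical and introduced here only for illustration; the sketch uses std::stable_sort so that equal keys keep their original order, as Python's list.sort does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <map>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical name; not part of the original answer.
template <typename RandomIt, typename KeyFunc>
void sort_by_cached_key(RandomIt first, RandomIt last, KeyFunc func)
{
    using ValueT = typename std::iterator_traits<RandomIt>::value_type;
    using KeyT = typename std::decay<decltype(func(*first))>::type;
    using Entry = std::pair<KeyT, ValueT>;

    // Decorate: call func once per element and store the key next to it.
    std::vector<Entry> decorated;
    decorated.reserve(static_cast<std::size_t>(last - first));
    for (RandomIt it = first; it != last; ++it) {
        KeyT key = func(*it);  // compute the key before moving the element
        decorated.emplace_back(std::move(key), std::move(*it));
    }

    // Sort on the cached keys only; func is never called here.
    std::stable_sort(decorated.begin(), decorated.end(),
                     [](const Entry& a, const Entry& b) { return a.first < b.first; });

    // Undecorate: move the elements back into the original range.
    for (Entry& entry : decorated) {
        *first = std::move(entry.second);
        ++first;
    }
}

// Usage sketch mirroring the Python session: returns how often the key ran.
int count_key_calls()
{
    const std::map<std::string, int> values = {{"a", 100}, {"b", 30}, {"c", 50}, {"d", 0}};
    std::vector<std::string> items = {"a", "b", "c", "d"};
    int calls = 0;
    sort_by_cached_key(items.begin(), items.end(), [&](const std::string& k) {
        ++calls;  // stands in for the heavy computation
        return values.at(k);
    });
    assert((items == std::vector<std::string>{"d", "b", "c", "a"}));
    return calls;
}
```

The trade-off is extra memory for the temporary vector and two moves per element, in exchange for exactly one key computation per element.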

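    We can even hack together a class that lets us keep using std::sort while computing each key only once. It needs boost::optional, and it relies on std::sort rearranging elements only through move construction and move assignment: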

    template <typename RandomIter, typename KeyFunc>
    void sort_by_key(RandomIter first, RandomIter last, KeyFunc func)
    {
        using KeyT = decltype(func(*first));
        using ValueT = typename std::remove_reference<decltype(*first)>::type;
    
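        // Each Pair is tied to one slot of the original range (iter).
        // Moving Pairs around moves the underlying elements along with the keys.
        struct Pair {
            KeyT key;
            RandomIter iter;
            // Set only in move-constructed temporaries, which hold the element itself.
            boost::optional<ValueT> value;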
    
            Pair(const KeyT& key, const RandomIter& iter)
                : key(key), iter(iter)
            { }
    
            Pair(Pair&& rhs)
                : key(std::move(rhs.key))
                , iter(rhs.iter)
                , value(std::move(*(rhs.iter)))
            { }
    
            Pair& operator=(Pair&& rhs) {
                key = std::move(rhs.key);
                *iter = std::move(rhs.value ? *rhs.value : *rhs.iter);
                value = boost::none;
                return *this;
            }
    
            bool operator<(const Pair& rhs) const {
                return key < rhs.key;
            }
        };
    
        std::vector<Pair> ordering;
        ordering.reserve(last - first);
    
        for (; first != last; ++first) {
            ordering.emplace_back(func(*first), first);
        }
    
        std::sort(ordering.begin(), ordering.end());
    }
    

    Or, if that's too hacky, here's my original solution, which requires us to write our own sort:

    template <typename RandomIt, typename KeyFunc>
    void sort_by_key_2(RandomIt first, RandomIt last, KeyFunc func)
    {
        using KeyT = decltype(func(*first));
        std::vector<std::pair<KeyT, RandomIt> > ordering;
        ordering.reserve(last - first);
    
        for (; first != last; ++first) {
            ordering.emplace_back(func(*first), first);
        }
    
        // Sort ordering by key. Every swap of two entries must also
        // iter_swap the elements they point to.
        quicksort_with_benefits(ordering, 0, ordering.size());
    }
    

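    Although now we have to reimplement quicksort (in a real source file, these two functions should be declared before sort_by_key_2, with partition_with_benefits first, so that the calls resolve):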

    template <typename Key, typename Iter>
    void quicksort_with_benefits(std::vector<std::pair<Key,Iter>>& A, size_t p, size_t q) {
        if (p < q) {
            size_t r = partition_with_benefits(A, p, q);
            quicksort_with_benefits(A, p, r);
            quicksort_with_benefits(A, r+1, q);
        }
    }
    
    template <typename Key, typename Iter>
    size_t partition_with_benefits(std::vector<std::pair<Key,Iter>>& A, size_t p, size_t q) {
        auto key = A[p].first;
        size_t i = p;
        for (size_t j = p+1; j < q; ++j) {
            if (A[j].first < key) {
                ++i;
                std::swap(A[i].first, A[j].first);
                std::iter_swap(A[i].second, A[j].second);
            }
        }
    
        if (i != p) {
            std::swap(A[i].first, A[p].first);
            std::iter_swap(A[i].second, A[p].second);
        }
        return i;
    }
    

    Given a simple example (print is assumed to be a helper that writes the elements on one line), this produces the expected result:

    int main()
    {
        std::vector<int> v = {-2, 10, 4, 12, -1, -25};
    
        std::sort(v.begin(), v.end());
        print(v); // -25 -2 -1 4 10 12
    
        sort_by_key_2(v.begin(), v.end(), [](int i) { return i*i; }); 
        print(v); // -1 -2 4 10 12 -25
    }
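
The example calls print, which the answer does not define. A minimal sketch of such a helper, assuming it simply writes the elements separated by single spaces; format_elements is a hypothetical name introduced here so that the formatting can be checked on its own.

```cpp
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: joins the elements with single spaces.
template <typename T>
std::string format_elements(const std::vector<T>& v)
{
    std::ostringstream out;
    for (std::size_t i = 0; i < v.size(); ++i) {
        if (i != 0) {
            out << ' ';
        }
        out << v[i];
    }
    return out.str();
}

// Assumed definition of the print used in the example's main.
template <typename T>
void print(const std::vector<T>& v)
{
    std::cout << format_elements(v) << '\n';
}
```

With this helper and the necessary headers (<algorithm>, <vector>, <utility>) in place ahead of the sort functions, the example should compile as a single file.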