Tags: c++, overloading, c++20, c++-concepts

C++20 concepts that take function type as parameter: how to express compactly while adhering to DRY


I have a hierarchy of classes for which I want to define a generic binary operation, and I want to use a concept to constrain the overloads. The code works, but the use of the concept is unsatisfactory: it is not concise, and it repeats some of the code from inside the function.

Consider the following problem:

#include <utility>

// We have foo, which consists of some doubles...
struct foo
{
    double x;
    double y;
};

// ...and bar, which consists of some foos.
struct bar
{
    foo u;
    foo v;
};

// There is a hierarchy here: bar is at the top and double at the bottom. Higher classes are
// composed of lower ones.

// Let's introduce a binary operation, which we define over the doubles:
double binop(double lhs, double rhs)
{
    return lhs * rhs; // What it does is not important
}

// I want to extend this binary operation to the higher classes in the hierarchy, defined in terms of
// operations on their lower members. This should be done by distributing the operation over the members
// of the arguments.

// For foo: binop(foo1, foo2) = binop(foo1.x, foo2.x) + binop(foo1.x, foo2.y) + binop(foo1.y, foo2.x) + binop(foo1.y, foo2.y)
// For bar: binop(bar1, bar2) = binop(bar1.u, bar2.u) + binop(bar1.u, bar2.v) + binop(bar1.v, bar2.u) + binop(bar1.v, bar2.v)

// To enable this, a distribute function template is defined for each type:

template<typename Op>
auto distribute(foo arg, Op op)
{
    return op(arg.x) + op(arg.y); // distribute the operation amongst the doubles
}

template<typename Op>
auto distribute(bar arg, Op op)
{
    return op(arg.u) + op(arg.v); // distribute the operation amongst the foos
}

// I need a concept to distinguish types that can be distributed (foo, bar) from ones that can't (double)
template<typename T, typename F>
concept distributive =
    requires(T t, F f)
{
    { distribute(t, f) };
};

// Now I want to define the binary operation generically, for two cases:
// (Case 1) lhs is distributive (rhs may or may not be distributive)
// (Case 2) lhs is not distributive, rhs is distributive

// Case 1
template<typename T, typename U>
    requires (distributive<T, decltype([](const auto& a) { return binop(a, double{}); })>)
auto binop(const T& lhs, const U& rhs)
{
    return distribute(lhs, [rhs](const auto& sub_lhs) { return binop(sub_lhs, rhs); });
}

// Case 2
template<typename T, typename U>
    requires (not distributive<T, decltype([](const auto& a) { return binop(a, double{}); })>
              and distributive<U, decltype([](const auto& a) { return binop(double{}, a); })>)
auto binop(const T& lhs, const U& rhs)
{
    return distribute(rhs, [lhs](const auto& sub_rhs) { return binop(lhs, sub_rhs); });
}

int main()
{
    foo my_foo_a{ 1.0, 3.0 };
    foo my_foo_b{ 5.0, 7.0 };

    double val_a{ binop(my_foo_a, 2.0) }; // 8.0
    double val_b{ binop(2.0, my_foo_b) }; // 24.0
    double val_c{ binop(my_foo_a, my_foo_b) }; // 48.0

    bar my_bar_a{ foo{ 0.0, 1.0 }, foo{ 2.0, 3.0 } };
    bar my_bar_b{ foo{ 4.0, 5.0 }, foo{ 6.0, 7.0 } };

    double val_d{ binop(my_bar_a, 2.0) }; // 12.0
    double val_e{ binop(2.0, my_bar_b) }; // 44.0
    double val_f{ binop(my_bar_a, my_bar_b) }; // 132.0
}
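Because binop on doubles is multiplication here and multiplication distributes over addition, the distributed operation collapses to the product of the sums of all leaf doubles in each argument. Writing $f_i$ for a foo with members $x_i, y_i$, this accounts for the expected values in main:

$$
\begin{aligned}
\operatorname{binop}(f_1, f_2) &= x_1 x_2 + x_1 y_2 + y_1 x_2 + y_1 y_2 = (x_1 + y_1)(x_2 + y_2) \\
\operatorname{binop}(\text{my\_foo\_a}, \text{my\_foo\_b}) &= (1 + 3)(5 + 7) = 48 \\
\operatorname{binop}(\text{my\_bar\_a}, \text{my\_bar\_b}) &= (0 + 1 + 2 + 3)(4 + 5 + 6 + 7) = 6 \cdot 22 = 132
\end{aligned}
$$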

This code compiles and works as expected. The concept distributive was defined to obtain three overload cases that do not overlap:

  • binop(non-distributive, non-distributive)
  • binop(distributive, distributive or non-distributive)
  • binop(non-distributive, distributive)

Without the concept, the two function templates would have identical signatures and would be redefinitions of each other. Note that the concept takes not only the type of interest T, but also the type F of the binary operation to distribute. As far as I'm aware, a concept cannot check whether a function template such as distribute can be called without instantiating it for concrete argument types, hence the use of F.

However, supplying the type of the function effectively requires duplicating the lambda from the binop body into the function template's constraint. This feels like a messy way to express the constraint, and is a far cry from the syntax I initially expected:

// Case 1
auto binop(const distributive auto& lhs, const auto& rhs)
{
    return distribute(lhs, [rhs](const auto& sub_lhs) { return binop(sub_lhs, rhs); });
}

// Case 2
auto binop(const simplex auto& lhs, const distributive auto& rhs)
{
    return distribute(rhs, [lhs](const auto& sub_rhs) { return binop(lhs, sub_rhs); });
}

(Here, simplex is intended as the negation of distributive.) Of course this does not compile, because the concepts require the verbose function-type parameter.

Am I missing a trick with concepts that helps avoid duplicating function logic into the function's constraints? Is there a way to more cleanly express this code using concepts?


Solution

  • If whether a given type is distributive in your example depends only on whether distribute is defined for it, the callable does not change anything.

    Note that
    distributive<T, decltype([](const auto& a) { return binop(a, double{}); })>
    is equivalent to
    distributive<T, decltype([](const auto& a) { return 0.0; })>
    in your example.

    • The result type of binop(a, b) (for any a and b) will always be double, because that is the return type of double binop(double lhs, double rhs), which is the terminating overload of binop.
    • The lambda you specify there is never actually called, because it appears in an unevaluated context, so the only thing that matters is its return type.

    So you could simplify your distributive concept to:

    template<class T>
    concept distributive = requires(T t) {
        { distribute(t, [](auto const&) { return 0.0; }) };
    };
    
    template<class T>
    concept not_distributive = !distributive<T>;
    

    ... at which point you can write your overloads like you wanted:

    double binop(double lhs, double rhs)
    {
        return lhs * rhs;
    }
    
    auto binop(distributive auto const& lhs, auto const& rhs)
    {
        return distribute(lhs, [&](auto const& sub_lhs) { return binop(sub_lhs, rhs); });
    }
    
    auto binop(not_distributive auto const& lhs, distributive auto const& rhs)
    {
        return distribute(rhs, [&](auto const& sub_rhs) { return binop(lhs, sub_rhs); });
    }
    

    godbolt example


    Instead of implementing the three cases separately, you could also consolidate them into one recursive function that unwraps distributive types:

    template<class T>
    concept distributive = requires(T t) {
        { distribute(t, [](auto const&) { return 0.0; }) };
    };
    
    
    template<class T, class Op>
    auto distribute_recursive(T const& value, Op&& op) {
        if constexpr(distributive<T>)
            return distribute(value, [&](auto const& sub) {
                // op is reused for every sub-element, so pass it on as an lvalue
                // rather than forwarding it repeatedly
                return distribute_recursive(sub, op);
            });
        else
            return op(value);
    }
    
    template<class Lhs, class Rhs>
    auto binop(Lhs const& lhs, Rhs const& rhs) {
        return distribute_recursive(lhs, [&](auto const& lhs_value) {
            return distribute_recursive(rhs, [&](auto const& rhs_value){
                return lhs_value * rhs_value;
            });
        });
    }
    

    godbolt example


    If you need operations other than multiplication with additive reduction, you could also abstract the actual operations away:

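    #include <concepts>     // std::same_as
    #include <functional>   // std::multiplies, std::plus, std::minus
    #include <type_traits>  // std::invoke_result_t
    
    template<class Op>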
    auto distribute(foo const& arg, Op&& op)
    {
        return op(arg.x, arg.y);
    }
    
    template<class Op>
    auto distribute(bar const& arg, Op&& op)
    {
        return op(arg.u, arg.v);
    }
    
    struct check_dist_fn_returns_op;
    
    template<class T>
    concept distributive = requires(T t) {
        {
            distribute(
                t,
                [](auto const&, auto const&...) {
                    return static_cast<check_dist_fn_returns_op*>(nullptr);
                }
            )
        } -> std::same_as<check_dist_fn_returns_op*>;
    };
    
    template<class T, class Reduce>
    struct fold_helper {
        T value;
        Reduce& reduce;
    
        template<class U>
        friend fold_helper<std::invoke_result_t<Reduce, T, U>, Reduce> operator|(
            fold_helper<T, Reduce>&& lhs,
            fold_helper<U, Reduce>&& rhs
        ) {
            return {
                lhs.reduce(lhs.value, rhs.value),
                lhs.reduce
            };
        }
    };
    
    template<class T, class Op, class Reduce>
    auto distribute_recursive(T const& value, Op&& op, Reduce&& reduce) {
        if constexpr(distributive<T>) {
            return distribute(value, [&](auto const&... sub) {
                return (
                    fold_helper{distribute_recursive(sub, op, reduce), reduce} | ...
                ).value;
            });
        } else {
            return op(value);
        }
    }
    
    template<class Lhs, class Rhs, class BinaryOp, class Reduce>
    auto binop(Lhs const& lhs, Rhs const& rhs, BinaryOp&& binaryOp, Reduce&& reduce) {
        return distribute_recursive(lhs, [&](auto const& lhs_value) {
            return distribute_recursive(rhs, [&](auto const& rhs_value) {
                return binaryOp(lhs_value, rhs_value);
            }, reduce);
        }, reduce);
    }
    

    This would then allow you to specify the operations you want for each call, e.g.:

    auto result_a = binop(my_foo_a, 2.0, std::multiplies{}, std::plus{});  // 1*2 + 3*2 = 8.0
    auto result_b = binop(my_foo_a, 2.0, std::minus{}, std::multiplies{}); // (1-2) * (3-2) = -1.0
    

    godbolt example