Tags: c++, c++14, variadic-templates, perfect-forwarding

Generically taking linear combinations of indexables/callables


I'm trying to globally scale and add together callable/indexable objects (vectors in the abstract mathematical sense of the word).

That is to say, I'm trying to take linear combinations of objects that define operator[] or operator().

For example, I want to be able to do this:

LinearCombination<std::function<double(double, double)>> A([](double x, double y){
    return 1+x+std::pow(x,2)+std::sin(y);
});
LinearCombination<std::function<double(double, double)>> B([](double x, double y){
    return 1-x+std::cos(y);
});
A *= 2.5;
A += std::move(B); // operator+= only accepts an rvalue LinearCombination
std::cout << A(1.0,2.0) << std::endl;

My attempt

// ZERO ///////////////////////////////////////////////////////////////////////////////////////////

namespace hidden {

    // tag dispatching: from https://stackoverflow.com/a/60248176/827280

    template<int r>
    struct rank : rank<r - 1> {};

    template<>
    struct rank<0> {};

    template<typename T>
    auto zero(rank<2>) -> decltype(static_cast<T>(0)) {
        return static_cast<T>(0);
    }

    template<typename T>
    auto zero(rank<1>) -> decltype(T::zero()) {
        return T::zero();
    }

    template<typename T>
    auto zero(rank<0>)->std::enable_if_t<
        std::is_assignable<std::function<double(double,double)>, T>::value
        , std::function<double(double,double)>> {
        return [](double, double) {
            return 0.0;
        };
    }
}

template<typename T>
auto zero() { return hidden::zero<T>(hidden::rank<10>{}); }

// LINEAR COMBINATION ///////////////////////////////////////////////////////////////////////////////////////////

template<typename V, typename C = double>
struct LinearCombination {
    struct Term {
        C coeff;
        V vector;

        // if V(x...) is defined
        template<typename ...X>
        auto operator()(X&&... x) const -> std::remove_reference_t<decltype(std::declval<V>()(std::forward<X>(x)...))> {
            return vector(std::forward<X>(x)...) * coeff;
        }

        // if V[i] is defined
        template<typename I>
        auto operator[](I i) const -> std::remove_reference_t<decltype(std::declval<V>()[i])> {
            return vector[i] * coeff;
        }

    };
    std::vector<Term> terms;

    LinearCombination() {} // zero

    /*implicit*/ LinearCombination(V&& v) {
        terms.push_back({ static_cast<C>(1), std::move(v) });
    }

    /*implicit*/ LinearCombination(Term&& term) {
        terms.push_back(std::move(term));
    }

    LinearCombination<V, C>& operator+=(LinearCombination<V, C>&& other) {
        terms.reserve(terms.size() + other.terms.size());
        std::move(std::begin(other.terms), std::end(other.terms), std::back_inserter(terms));
        other.terms.clear();
        return *this;
    }

    LinearCombination<V, C>& operator*=(C multiplier) {
        for (auto& term : terms) {
            term.coeff *= multiplier;
        }
        return *this;
    }

    // if V(x...) is defined
    template<typename ...X>
    auto operator()(X&&... x) const
         -> std::remove_reference_t<decltype(std::declval<V>()(std::forward<X>(x)...))> {
        // x... is used once per term, so pass it as lvalues instead of forwarding it repeatedly
        auto result = zeroVector()(x...);  // <--- *** BAD FUNCTION CALL ***

        for (const auto& term : terms) {
            result += term(x...);
        }
        return result;
    }

    // if V[i] is defined
    template<typename I>
    auto operator[](I i) const -> std::remove_reference_t<decltype(std::declval<V>()[i])> {
        auto result = zeroVector()[i];
        for (const auto& term : terms) {
            result += term[i];
        }
        return result;
    }

private:
    static const V& zeroVector() {
        static V z = zero<V>();
        return z;
    }
};

This compiles fine for me, but at runtime the indicated line throws std::bad_function_call. Can you help?


Solution

  • This function:

    template<typename T>
    auto zero(rank<2>) -> decltype(static_cast<T>(0));
    

    wins overload resolution against:

    template<typename T>
    auto zero(rank<0>)->std::enable_if_t<
        std::is_assignable<std::function<double(double,double)>, T>::value
        , std::function<double(double,double)>>;
    

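    This is because rank<2> is a better match for rank<10>{} than rank<0> (rank<2> itself derives from rank<0>, and a derived-to-base conversion to the nearer base ranks higher), the rank<1> overload drops out by SFINAE because std::function has no static zero() member, and also: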

    static_cast<std::function<double(double,double)>>(0)
    

    is a valid expression.

    It is valid because std::function has the following constructor:

    function(std::nullptr_t) noexcept;
    

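    The literal 0 is a null pointer constant, so it converts to std::nullptr_t and makes this constructor viable for the 0 argument. static_cast<T>(0) direct-initializes T, so it considers T's constructors.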

    You end up with an empty std::function<double(double,double)>, and invoking an empty std::function throws std::bad_function_call, which is the exception you see.