Search code examples

Variadic template with list of pairs, want to forward only first elements of pairs

I'm writing a function that calculates the PMF (probability mass function) of the multinomial distribution.

I've successfully written a function that computes multinomial coefficients, and I want to use it to compute the multinomial PMF, but my attempt fails to compile.

My current attempt:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <initializer_list>
#include <numeric>
#include <type_traits>
#include <utility>

// Satisfied by every type T for which std::is_unsigned<T> reports true.
template <typename T>
concept Unsigned = std::is_unsigned<T>::value;

// Satisfied by every type T for which std::is_integral<T> reports true.
template <typename T>
concept Integral = std::is_integral<T>::value;

// Satisfied by every type T for which std::is_floating_point<T> reports true.
template <typename T>
concept Floating = std::is_floating_point<T>::value;

// Computes the binomial coefficient C(N, K) ("N choose K") in type U.
//
// Precondition: N >= K (checked with assert).
// Returns: C(N, K). Overflow of U is not detected.
//
// The result is accumulated as C(base + i, i) for i = 1..steps, so the
// division in every iteration is exact. Because C(N, K) == C(N, N - K), the
// loop runs over the smaller of K and N - K. This keeps the iteration count
// (and the intermediate values) small, and guarantees that the loop counter
// can never wrap around, even when N is the largest value of U.
template <Unsigned U>
constexpr U binom(U N, U K) {
    assert(N >= K);
    const U complement = static_cast<U>(N - K);
    const U steps = (K < complement) ? K : complement;
    const U base = static_cast<U>(N - steps);
    U result {1};
    for (U i = U {1}; i <= steps; i++) {
        // result == C(base + i - 1, i - 1) here; multiplying by (base + i)
        // yields i * C(base + i, i), which is divisible by i.
        result *= i + base;
        result /= i;
    }
    return result;
}

// Base case of the multinomial coefficient: a single group of size K taken
// from N items.
//
// Precondition: K, converted to U, equals N (checked with assert), i.e. the
// group sizes passed to the variadic overload add up to the original N.
// Returns: 1, since there is exactly one way to put all remaining items into
// the last group.
template <Unsigned U, Integral I>
constexpr U multinom (U N, I K) {
    // Compare in U, as the variadic overload does, to avoid a mixed-sign
    // comparison between N and K.
    assert(N == static_cast<U>(K));
    return U {1};
}

// Multinomial coefficient N! / (K1! * K2! * ... * Km!), computed as
// C(N, K1) * multinom(N - K1, K2, ..., Km).
//
// Precondition: K1 <= N (checked with assert); the single-group overload
// finally checks that the last group size equals what is left of N.
// Returns: the coefficient in type U. Overflow of U is not detected.
template <Unsigned U, Integral I, Integral... Is>
constexpr U multinom (U N, I K1, Is... Ks) {
    const U first = static_cast<U>(K1);
    assert(N >= first);
    // Cast the difference back to U: for U narrower than int the subtraction
    // is carried out in int, and passing an int on would no longer satisfy
    // the Unsigned constraint of the recursive call.
    const U remaining = static_cast<U>(N - first);
    return binom(N, first) * multinom(remaining, Ks...);
}

// Approximate floating-point comparison.
//
// Returns true when |f1 - f2| does not exceed the smallest of
//   * the absolute tolerance 1e-4,
//   * 1e-3 times |f1|,
//   * 1e-3 times |f2|.
//
// All arithmetic is done in the common type of F1, F2 and double, so the
// three tolerance candidates share one type (std::min over an initializer
// list requires that) and float inputs are still compared in double.
// Magnitudes are used for the relative tolerances so negative inputs get a
// positive tolerance, and the comparison is inclusive so that two exact zeros
// compare equal. Absolute values are computed by hand instead of std::fabs so
// the function can be evaluated in constant expressions.
template <Floating F1, Floating F2>
constexpr bool almost_equal(F1 f1, F2 f2) {
    using F = std::common_type_t<F1, F2, double>;
    const F a = f1;
    const F b = f2;
    const F magnitude_a = a < F {0} ? -a : a;
    const F magnitude_b = b < F {0} ? -b : b;
    const F difference = a < b ? b - a : a - b;
    const F tolerance = std::min({F {1.0e-4}, magnitude_a * F {1.0e-3}, magnitude_b * F {1.0e-3}});
    return difference <= tolerance;
}

/*template <Unsigned U, Integral I, Floating F>
constexpr F multinom_pmf(U N, std::initializer_list<std::pair<I, F>> args) {
    assert(almost_equal(1.0, std::accumulate(args.begin(), args.end(), 0.0, [](auto& a, auto& b) {return a + b.second;})));
}*/

template <Unsigned U, Integral I, Floating F>
constexpr F multinom_pmf (U N, std::pair<I, F>... args) {
    assert(almost_equal(1.0, std::accumulate(args.begin(), args.end(), 0.0, [](auto& a, auto& b) {return a + b.second;})));
    // ???

// Compile-time checks of multinom against known multinomial coefficients.
int main() {
    static_assert(multinom<std::size_t>(5, 5) == 1);        // 5! / 5!
    static_assert(multinom<std::size_t>(5, 3, 2) == 10);    // 5! / (3! 2!)
    static_assert(multinom<std::size_t>(6, 2, 2, 2) == 90); // 6! / (2! 2! 2!)
}

Desired interface:

multinom_pmf<>(N, pair<I, F>... args) = multinom<>(N, /* first of args */) * std::pow(/* second of args */, /* first of args */) * ...


multinom_pmf<>(5, {3, 0.6}, {2, 0.4}) = multinom<>(5, 3, 2) * std::pow(0.6, 3) * std::pow(0.4, 2) = 0.3456
multinom_pmf<>(6, {2, 0.4}, {2, 0.35}, {2, 0.25}) = multinom<>(6, 2, 2, 2) * pow(0.4, 2) * pow(0.35, 2) * pow(0.25, 2) = 0.11025

I want the function to be able to check that

N == sum of first of args (which would be done inside multinom() call)
1.0 == sum of second of args

How can I improve my attempts? Thanks in advance.


  • Self-answer: my solution ended up dropping std::pair altogether. It's hacky, but it works.

    I still have no idea how to check that the probabilities sum to 1.

    template <Unsigned U, Integral I, Floating F, typename... Ts>
    constexpr F multinom_pmf (U N, I K1, F theta1, Ts... args) {
        F prob_value = std::pow(theta1, K1) * binom(N, static_cast<U>(K1));
        if constexpr(sizeof...(args) > 0)
            return prob_value * multinom_pmf(N - static_cast<U>(K1), args...);
            return prob_value;
    // Checks multinom at compile time and the flat-argument multinom_pmf at
    // run time: 10 * 0.4^3 * 0.6^2 == 0.2304.
    int main() {
        static_assert(multinom<std::size_t>(5, 5) == 1);
        static_assert(multinom<std::size_t>(5, 3, 2) == 10);
        static_assert(multinom<std::size_t>(6, 2, 2, 2) == 90);
        assert(almost_equal(multinom_pmf<std::size_t>(5, 3, 0.4, 2, 0.6), 0.2304));
    }