
Efficiently enumerating multinomials with constant sum - R


Let's say I have an N-sided die with non-uniform probabilities for each side, and I throw it M times. Instead of observing the individual outcomes, we observe only their sum.

I need to code the likelihood, which means summing the multinomial likelihood components over all outcomes consistent with the observed sum.

If N = 3, M = 2 and the sum is 4, then it is clear that I have to sum over the two cases where one throw is 1 and the other is 3, plus the case where both are 2.

I could also enumerate all possibilities, calculate the sum and restrict the calculation to the combinations with the sum I'm interested in, but clearly that becomes intractable quite quickly with increasing N and M.
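
For reference, the brute-force baseline described above can be sketched in base R as follows. The function name `brute_force_sum` is a hypothetical helper, not part of any package; it builds all N^M sequences before filtering, which is exactly why it stops being practical as N and M grow.

```r
# Hypothetical helper: enumerate every length-m sequence over faces 1..n,
# then keep only the sequences whose faces add up to `target`.
brute_force_sum <- function(n, m, target) {
  # expand.grid() builds all n^m ordered sequences, one per row.
  grid <- as.matrix(expand.grid(rep(list(seq_len(n)), m)))
  dimnames(grid) <- NULL
  # Keep the rows with the requested sum; drop = FALSE keeps a matrix.
  grid[rowSums(grid) == target, , drop = FALSE]
}

res <- brute_force_sum(3, 2, 4)
print(res)
```

For N = 3 and M = 2 this scans only 9 rows, but the grid has N^M rows in general, so memory use grows exponentially in M.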

So I am looking for an efficient approach to select the constant-sum combinations in R.


Solution

  • One option is to use RcppAlgos::compositionsGeneral(), which employs 'efficient algorithms for partitioning numbers under various constraints'.

    library(RcppAlgos)
    
    compositionsGeneral(3, 2, repetition = TRUE, target = 4)
    
         [,1] [,2]
    [1,]    1    3
    [2,]    2    2
    [3,]    3    1
    

    As @ThomasIsCoding has pointed out, this approach fails for some inputs with the following error:

    compositionsGeneral(3, 6, repetition = TRUE, target = 10)
    
    Error: Currently, there is no composition algorithm for this case.
     Use permuteCount, permuteIter, permuteGeneral, permuteSample, or
     permuteRank instead.
    

    To handle this, we can catch the error and fall back on permuteGeneral() with a sum constraint:

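    # Enumerate all length-m sequences drawn from 1:v (with repetition) that sum to x.
    comps <- function(v, m, x) {
      tryCatch(
        # Fast path: dedicated composition algorithm.
        compositionsGeneral(v, m, repetition = TRUE, target = x),
        # Fallback on any error (e.g. no composition algorithm for this case):
        # constrained permutations whose sum equals x.
        error = function(e) {
          permuteGeneral(v, m,
                         repetition = TRUE,
                         constraintFun = "sum",
                         comparisonFun = "==",
                         limitConstraints = x)
        }
      )
    }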
    
    
    comps(3, 6, 10)
    
          [,1] [,2] [,3] [,4] [,5] [,6]
     [1,]    1    1    1    1    3    3
     [2,]    1    1    1    3    1    3
     [3,]    1    1    1    3    3    1
     [4,]    1    1    3    1    1    3
     [5,]    1    1    3    1    3    1
    ...
    [85,]    2    2    1    1    2    2
    [86,]    2    2    1    2    1    2
    [87,]    2    2    1    2    2    1
    [88,]    2    2    2    1    1    2
    [89,]    2    2    2    1    2    1
    [90,]    2    2    2    2    1    1
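
Once the constant-sum rows are available, the likelihood from the question can be sketched as below. Each row is an ordered sequence of throws, so its probability is just the product of the per-face probabilities and no multinomial coefficient is needed. The function name `sum_likelihood` and the probability vector `p` are illustrative assumptions, and the matrix is typed in by hand from the first output above so that the block runs without RcppAlgos.

```r
# Hypothetical helper: probability of observing the given sum, computed as the
# sum over rows of the product of the face probabilities in each row.
sum_likelihood <- function(rows, p) {
  # p[r] looks up the probability of every face in one ordered sequence r.
  sum(apply(rows, 1, function(r) prod(p[r])))
}

# Rows for N = 3, M = 2, sum = 4 (copied from the compositionsGeneral() output).
rows <- rbind(c(1, 3), c(2, 2), c(3, 1))

# Assumed face probabilities for faces 1, 2, 3 (illustrative values only).
p <- c(0.2, 0.3, 0.5)

lik <- sum_likelihood(rows, p)
print(lik)  # 0.2 * 0.5 + 0.3 * 0.3 + 0.5 * 0.2 = 0.29
```

In practice `rows` would be the matrix returned by `comps()`.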
    

    Note that the documentation includes the following about calculating permutations with constraints:

    Finding all combinations/permutations with constraints is optimized by organizing them in such a way that when constraintFun is applied, a partially monotonic sequence is produced. Combinations/permutations are added successively, until a particular combination exceeds the given constraint value for a given constraint/comparison function combo. After this point, we can safely skip several combinations knowing that they will exceed the given constraint value.
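
The general idea of skipping candidates that cannot meet the constraint can be illustrated with a small base R sketch. This is not RcppAlgos's implementation, only a simple recursive enumeration with bound-based pruning; the function name `enumerate_pruned` is hypothetical.

```r
# Hypothetical helper: enumerate all length-m sequences over faces 1..n that
# sum to `target`, abandoning a partial sequence as soon as the target can no
# longer be reached with the remaining throws.
enumerate_pruned <- function(n, m, target) {
  out <- list()
  recurse <- function(prefix, remaining) {
    slots <- m - length(prefix)
    if (slots == 0L) {
      # Complete sequence: keep it only if it hits the target exactly.
      if (remaining == 0) out[[length(out) + 1L]] <<- prefix
      return(invisible(NULL))
    }
    # Prune: the remaining throws contribute at least `slots` (all ones)
    # and at most `slots * n` (all n's).
    if (remaining < slots || remaining > slots * n) return(invisible(NULL))
    for (face in seq_len(n)) recurse(c(prefix, face), remaining - face)
  }
  recurse(integer(0), target)
  if (length(out) == 0L) return(matrix(integer(0), nrow = 0, ncol = m))
  do.call(rbind, out)
}

pruned <- enumerate_pruned(3, 6, 10)
print(nrow(pruned))  # 90, matching the row count of the comps(3, 6, 10) output
```

Unlike filtering a full grid, this never extends a prefix whose remaining sum is already out of reach, although the number of valid sequences itself can still be large.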