Tags: python, numpy, random, cython

Cython multinomial distribution


What is the most efficient way to draw from a multinomial distribution (say, with n=1 trial) in Cython?

For example, I have three probabilities p0=0.1, p1=0.2, p2=0.7 (which sum to 1) and want x to be 0, 1, or 2 with probability p0, p1, p2, respectively.

I tried

cimport numpy as np
import numpy as np

# the data types
cdef np.ndarray[double, ndim=1] p
cdef int x

rng = np.random.default_rng() # the random number generator from numpy
p = np.array([0.1,0.2,0.7]) # the probabilities
x = <int>(rng.multinomial(1, p, size=1).argmax(axis=-1)[0])

But this is very slow because it goes through a lot of Python code. Is there a faster way that still uses a good random number generator?

(Side note: I use multinomial rather than NumPy's choice because multinomial already handles the case where p sums only approximately to 1 due to rounding errors.)
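
To make the rounding concern concrete, here is a minimal plain-Python sketch of single-draw sampling from unnormalized weights. It assumes non-negative weights with a positive sum; the name sample_index is hypothetical, and this is an illustration only, not how NumPy implements multinomial. Because the draw is scaled by the actual sum and the last category takes whatever remains, the weights do not have to sum to exactly 1.

```python
import random

def sample_index(weights, rng=random):
    """Draw one index i with probability weights[i] / sum(weights).

    Assumes every weight is non-negative (not checked here).
    """
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    u = rng.random() * total          # uniform in [0, total)
    acc = 0.0
    for i, w in enumerate(weights[:-1]):
        acc += w                      # running cumulative weight
        if u < acc:
            return i
    # The last category takes whatever remains, so a small rounding
    # error in the weights cannot leave the draw unassigned.
    return len(weights) - 1

x = sample_index([0.1, 0.2, 0.7])     # x is 0, 1, or 2
```

The loop uses only additions and comparisons, so it should carry over to a typed Cython function without much change.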


Solution

  • My solution (1) implements a Bernoulli trial and (2) nests those trials. The example below does this for four outcomes. As suggested by @peter-o, the arguments are integer weights rather than probabilities.

    I am open to better suggestions / improvements.

    from libc.stdlib cimport rand, srand, RAND_MAX
    
    # Bernoulli trial on two non-negative integer weights (l_0 + l_1 > 0):
    #   returns 0 with probability l_0/(l_0+l_1)
    #   returns 1 with probability l_1/(l_0+l_1)
    cpdef int multinomial2(int l_0, int l_1):
        # u is uniform in [0, 1); dividing by RAND_MAX + 1.0 excludes 1.0
        cdef double u = rand() / (RAND_MAX + 1.0)
        # u*(l_0+l_1) < l_0 is equivalent to u < l_0/(l_0+l_1), but it
        # avoids dividing by l_0, so a zero weight is safe
        if u * (l_0 + l_1) < l_0:
            return 0
        else:
            return 1
    
    # nest Bernoulli trials: first pick the pair {0,1} or {2,3}, then pick within it
    cpdef int multinomial4(int l_0, int l_1, int l_2, int l_3):
        if(multinomial2(l_0+l_1,l_2+l_3)==0):
            if(multinomial2(l_0,l_1)==0):
                return 0
            else:
                return 1
        else:
            if(multinomial2(l_2,l_3)==0):
                return 2
            else:
                return 3
    

    Add-on: A useful helper function to seed the random number generator:

    # seed the rng
    cpdef void rand_seed(int seed):
        srand(seed)
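
As a sanity check of the nesting idea, here is a plain-Python model of the two functions above. It is a sketch under assumptions: the standard random module stands in for C rand(), and the names bernoulli, nested4, and empirical_frequencies are hypothetical. Nesting works because P(0) = (l_0+l_1)/T * l_0/(l_0+l_1) = l_0/T, where T is the sum of all four weights, and likewise for the other outcomes.

```python
import random

def bernoulli(l_0, l_1, rng):
    """Return 0 with probability l_0 / (l_0 + l_1), else 1."""
    u = rng.random()                       # uniform in [0, 1)
    return 0 if u * (l_0 + l_1) < l_0 else 1

def nested4(l_0, l_1, l_2, l_3, rng):
    """Pick 0..3 with probability proportional to the four weights."""
    if bernoulli(l_0 + l_1, l_2 + l_3, rng) == 0:
        return bernoulli(l_0, l_1, rng)    # outcome 0 or 1
    return 2 + bernoulli(l_2, l_3, rng)    # outcome 2 or 3

def empirical_frequencies(weights, draws, seed):
    """Relative frequency of each outcome over `draws` seeded samples."""
    rng = random.Random(seed)              # seeded, like rand_seed above
    counts = [0, 0, 0, 0]
    for _ in range(draws):
        counts[nested4(*weights, rng)] += 1
    return [c / draws for c in counts]

freqs = empirical_frequencies((1, 2, 3, 4), 100_000, seed=42)
print(freqs)  # should be close to [0.1, 0.2, 0.3, 0.4]
```

A fixed seed makes the check reproducible, which is the same reason for having the rand_seed helper in the Cython version.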