
Python: Speed Up Code for a Mixed Distribution


I have the following function that returns samples from a "mixed distribution" (a mixture of normal distributions):

# M: [float]*k, center (mean) of each component distribution
# S: [float]*k, standard deviation of each component
# P: [float]*k, probability of each component (sums to 1.0)
# rng: random number generator (np.random.Generator)
# n: number of samples to return
# returns: [float]*n

def mixed_normal(rng, n, M, S, P):
    #See https://en.wikipedia.org/wiki/Mixture_model
    idx = np.random.choice(len(M), p=P, replace=True, size=n)
    return np.fromiter((rng.normal(M[i], S[i]) for i in idx),dtype=np.float64)

The function is called like this:

rng = np.random.default_rng()
def mixed_normal_3(rng, n):
    data = [(-5, 0, 5), (1, 1, 1), (1/3, 1/3, 1/3)]
    return mixed_normal(rng, n, *data)

with n=10**6.
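For reference, this is the mixture that mixed_normal is meant to sample from, written with the question's M, S and P for k components (the notation below is an illustration, not taken from the original post). Sampling is a two-step process: pick a component index with probability P_i, then draw from that component's normal distribution, which is exactly what the idx array and the rng.normal call do.

```latex
p(x) = \sum_{i=0}^{k-1} P_i \, \mathcal{N}\!\left(x \mid M_i, S_i^2\right),
\qquad \sum_{i=0}^{k-1} P_i = 1

I \sim \mathrm{Categorical}(P_0, \dots, P_{k-1}),
\qquad X \mid (I = i) \sim \mathcal{N}\!\left(M_i, S_i^2\right)
```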

However, the implementation is too slow: it currently takes around 350 s on my machine, and I need to get it down to approximately 30 s.

I am considering replacing the Python-level loop (the generator expression) in

return np.fromiter((rng.normal(M[i], S[i]) for i in idx),dtype=np.float64)

with a single vectorized NumPy call,

but I cannot come up with a working solution.
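Most of the cost in the loop is per-call overhead: each rng.normal(M[i], S[i]) call goes through the Python interpreter and NumPy's argument handling to produce a single number. A rough benchmark sketch, assuming a standard normal and a smaller n chosen only for illustration, makes the difference between n scalar calls and one call for n values visible; absolute timings will depend on the machine.

```python
import time

import numpy as np

# Hypothetical benchmark (not from the original post): n scalar draws
# through a Python generator versus one vectorized draw of n values.
rng = np.random.default_rng(0)
n = 100_000  # smaller than 10**6 so the looped version finishes quickly

start = time.perf_counter()
# One Python-level call to rng.normal per sample, as in the question.
looped = np.fromiter((rng.normal(0.0, 1.0) for _ in range(n)),
                     dtype=np.float64, count=n)
t_loop = time.perf_counter() - start

start = time.perf_counter()
# A single call; the per-sample loop runs in compiled code.
vectorized = rng.normal(0.0, 1.0, size=n)
t_vec = time.perf_counter() - start

print(f"loop: {t_loop:.3f} s, vectorized: {t_vec:.4f} s")
```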

Minimal working example

import numpy as np

rng = np.random.default_rng()

def mixed_normal(rng, n, M, S, P):
    #See https://en.wikipedia.org/wiki/Mixture_model
    idx = np.random.choice(len(M), p=P, replace=True, size=n)
    # Needs to be optimized
    return np.fromiter((rng.normal(M[i], S[i]) for i in idx), dtype=np.float64)

def mixed_normal_3(rng, n):
    data = [(-5, 0, 5), (1, 1, 1), (1/3, 1/3, 1/3)]
    return mixed_normal(rng, n, *data)

# Expected: an array of 10**6 floats
print(mixed_normal_3(rng, 10**6))

Solution

  • I fixed it by 'pre-slicing' the lists M and S, i.e. indexing them with the whole idx array at once (NumPy fancy indexing):

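    def mixed_normal(rng, n, M, S, P):
        # Draw the component index of each sample from the passed-in rng.
        idx = rng.choice(len(M), p=P, replace=True, size=n)
        # Expand means and std devs to length n via fancy indexing,
        # then draw all n samples in one vectorized call.
        return rng.normal(np.asarray(M)[idx], np.asarray(S)[idx], n)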
    

    The lists M and S are expanded to length n by picking their elements at the random indices in idx (which itself has length n). With the means from the question:

    M = np.array([-5, 0, 5])
    idx = np.array([0, 0, 1, 1, 0, 0, 1, 1, 2])
    M[idx]  # -> [-5, -5, 0, 0, -5, -5, 0, 0, 5]
    

    These expanded arrays are then passed to rng.normal as loc and scale, which draws all n samples in a single vectorized call.

    This reduced the execution time from 350 s to 40 s for my test cases.
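A minimal sketch of the same idea as a standalone function, assuming the same M, S, P inputs as in the question. The name mixed_normal_vectorized is hypothetical, and it swaps rng.normal(loc, scale) for the equivalent shift-and-scale of rng.standard_normal. It also draws the indices from the passed-in rng rather than the global np.random state, so seeding rng makes the whole result reproducible. The sample mean and variance are compared against the mixture's theoretical moments as a sanity check.

```python
import numpy as np

def mixed_normal_vectorized(rng, n, M, S, P):
    """Draw n samples from a mixture of normals without a Python-level loop.

    Hypothetical helper: equivalent in distribution to the loop version.
    """
    M = np.asarray(M, dtype=np.float64)
    S = np.asarray(S, dtype=np.float64)
    # One component index per sample, drawn with probabilities P.
    idx = rng.choice(len(M), size=n, p=P)
    # Shift and scale standard normals: x = M[idx] + S[idx] * z, z ~ N(0, 1).
    return M[idx] + S[idx] * rng.standard_normal(n)

rng = np.random.default_rng(12345)
M, S, P = (-5, 0, 5), (1, 1, 1), (1/3, 1/3, 1/3)
x = mixed_normal_vectorized(rng, 10**6, M, S, P)

# Theoretical moments of the mixture:
#   mean = sum(P * M)
#   variance = sum(P * (S**2 + M**2)) - mean**2
mean = np.dot(P, M)
var = np.dot(P, np.square(S) + np.square(M)) - mean**2
print(x.shape, x.mean(), mean, x.var(), var)
```

For these parameters the theoretical mean is 0 and the variance is (1/3)(26 + 1 + 26) = 53/3 ≈ 17.67, so the printed sample values should land close to those.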