python random probability normal-distribution random-seed

Two-step probability draw - Combine a probability weighting function and a draw from a truncated normal distribution


I want to draw a random value (r) in Python according to a specific probability function:

The value should be a fixed constant with a probability (p) and with a probability (1-p) the value is randomly drawn from a normal distribution in the interval (a,b).

From my point of view, the function consists of two draws:

First draw: a probability weighting function with probabilities p and (1-p).

Second draw: a draw from a truncated normal distribution in the interval (a,b).

For example: constant = 10, p = 0.3, 1 - p = 0.7, a = 5, b = 15

My idea was:

r = random.choices([constant,random.normal()], weights=(0.3, 0.7))

However, it does not seem to work and I don't know how to include the interval (a,b).


Solution

  • Your idea is on target, but you're not handling the normal distribution correctly. First, if you're using import random, your options are random.gauss(mu, sigma) or random.normalvariate(mu, sigma); there is no random.normal() function. Second, the normal distribution has an infinite range. You specify mu, the center of the distribution, and sigma, a measure of the spread, such that ~95% of the results fall in the range mu ± 2*sigma. If you truly need to restrict the range to (a, b), you'll need to write your own function with a loop that rejects values outside the limits and draws again.

    Here's a working example which shows your idea:

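    import random
    
    constant = -42   # fixed value, chosen with probability 0.3
    mu = 10          # center of the normal distribution
    sigma = 2        # spread of the normal distribution
    for _ in range(10):
        # random.choices returns a list holding the single selected item
        r = random.choices((constant, random.gauss(mu, sigma)), weights=(0.3, 0.7))
        print(r)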
    

    With those values of mu and sigma, the range [5, 15] corresponds to mu ± 2.5*sigma, so outcomes outside it will be rare (roughly 1.2% of the normal draws) but can definitely still occur.

    This produces outcomes such as:

    [7.674159248587632]
    [-42]
    [7.818652194185853]
    [-42]
    [7.418414458386396]
    [11.855000252151326]
    [12.398753049340957]
    [9.663097201849096]
    [-42]
    [10.385663464672415]
    

    random.choices returns a list, which is why each outcome is printed in brackets. If you don't like the brackets, print(r[0]).
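
The rejection loop mentioned in the answer could be sketched roughly as follows. This is only one possible approach, not the answer's own code: the function name two_step_draw and its parameter names are made up for illustration, and it uses random.random() for the first draw instead of random.choices. It assumes that (a, b) covers a reasonable share of the distribution, because otherwise the loop may need many retries.

```python
import random

def two_step_draw(constant, p, mu, sigma, a, b):
    """Return `constant` with probability p; otherwise return a
    normal(mu, sigma) draw restricted to the open interval (a, b).

    Hypothetical helper for illustration only.
    """
    # First draw: decide between the constant and the normal branch.
    if random.random() < p:
        return constant
    # Second draw: rejection sampling. Redraw until the value is inside (a, b).
    while True:
        x = random.gauss(mu, sigma)
        if a < x < b:
            return x

# Example with a constant of -42 and the question's p, a and b.
for _ in range(10):
    print(two_step_draw(constant=-42, p=0.3, mu=10, sigma=2, a=5, b=15))
```

Because the function returns a plain number rather than a list, no indexing is needed. With mu = 10 and sigma = 2, only a small fraction of normal draws are rejected, so the loop usually finishes on the first try.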