PyMC3 variable dependent on result of another


I am implementing an MCMC simulation for the first time, and I have a variable that is defined based on the result of a previous variable. For instance, if my Bernoulli variable is 0, a different value should be fed into a deterministic variable than if it is 1.

with pm.Model() as model:
    x = pm.Bernoulli('x', .5)
    if x == 1:
        y = 1 
    elif x == 0:
        y = 2
    z = pm.Deterministic('z', y * 1000)

My issue is that neither of these if statements is ever entered, because x is not an integer but a distribution (a symbolic random variable). Is there a way to get the sampled values of x? Or am I just thinking about this wrong?


Solution

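  • You are right: inside the model, x is a symbolic Theano variable rather than a number, so a Python if cannot branch on its sampled value. Use Theano's switch function instead, which PyMC3 exposes as pm.math.switch.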

    import pymc3 as pm

    with pm.Model() as model:
        x = pm.Bernoulli('x', .5)
        # y is 1 where x is 1 (true) and 2 where x is 0 (false)
        y = pm.math.switch(x, 1, 2)
        z = pm.Deterministic('z', y * 1000)
    

    or, a little more explicitly:

    with pm.Model() as model:
        x = pm.Bernoulli('x', .5)
        # compare x to 1 explicitly: 1 when x == 1, otherwise 2
        y = pm.math.switch(pm.math.eq(x, 1), 1, 2)
        z = pm.Deterministic('z', y * 1000)
    

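    switch evaluates the first argument elementwise: where it is true, it returns the second argument, otherwise the third. Because the result is a symbolic expression rather than a Python branch, z follows whatever value x takes at each step of the sampler.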
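To make the semantics concrete, here is a minimal sketch that uses NumPy's np.where as a stand-in for switch. This is only an analogue for illustration (it is not the PyMC3/Theano API), and x_draws is a hypothetical array of Bernoulli draws rather than real sampler output.

```python
import numpy as np

# Hypothetical draws of the Bernoulli variable x (illustrative values only).
x_draws = np.array([0, 1, 1, 0, 1])

# Analogue of pm.math.switch(pm.math.eq(x, 1), 1, 2):
# where the condition holds take the second argument, otherwise the third.
y = np.where(x_draws == 1, 1, 2)
z = y * 1000

# Analogue of pm.math.switch(x, 1, 2): a 0/1 variable can serve directly
# as the condition, because 1 counts as true and 0 as false.
y_truthy = np.where(x_draws, 1, 2)

print(y)  # [2 1 1 2 1]
print(z)  # [2000 1000 1000 2000 1000]
```

Both forms give the same result here because x only takes the values 0 and 1; the explicit comparison is the safer choice once x can take other values.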

    You can also use more than one switch if you have more than two conditions.

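    with pm.Model() as model:
        x = pm.DiscreteUniform('x', 0, 2)
        # first switch: 1 where x == 1, otherwise 0
        y_ = pm.math.switch(pm.math.eq(x, 1), 1, 0)
        # second switch: 2 where x == 2, otherwise keep the first result
        y = pm.math.switch(pm.math.eq(x, 2), 2, y_)
        z = pm.Deterministic('z', y * 1000)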
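The chained switches can be traced the same way. The sketch below again uses np.where only as a NumPy analogue of switch; nested_switch and the draws are hypothetical names and values chosen for illustration. It also compares the result against a lookup table indexed by x, a swapped-in NumPy alternative that becomes handy when there are many conditions.

```python
import numpy as np

def nested_switch(x):
    """NumPy analogue of the two chained switches above (hypothetical helper)."""
    y_ = np.where(x == 1, 1, 0)  # first switch: 1 where x == 1, otherwise 0
    y = np.where(x == 2, 2, y_)  # second switch: 2 where x == 2, otherwise keep y_
    return y

# Hypothetical draws of DiscreteUniform(0, 2).
x_draws = np.array([0, 1, 2, 2, 0])
y = nested_switch(x_draws)
z = y * 1000

# The same mapping written as a lookup table: entry i holds the value for x == i.
table = np.array([0, 1, 2])
y_lookup = table[x_draws]

print(y)  # [0 1 2 2 0]
print(z)  # [   0 1000 2000 2000    0]
```

Each additional condition adds one more switch wrapped around the previous result, so the last switch applied wins wherever conditions overlap.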