python, pymc, pymc3, mcmc

How to correctly define a mixture of Beta distributions in PyMC3


I am trying to fit data with a mixture of two Beta distributions, where the weight of each component is unknown, using PyMC3's Mixture class. Here is the code:

model=pm.Model()
with model:
    alpha1=pm.Uniform("alpha1",lower=0,upper=20)
    beta1=pm.Uniform("beta1",lower=0,upper=20)
    alpha2=pm.Uniform("alpha2",lower=0,upper=20)
    beta2=pm.Uniform("beta2",lower=0,upper=20)
    w=pm.Uniform("w",lower=0,upper=1)
    b1=pm.Beta("B1",alpha=alpha1,beta=beta1)
    b2=pm.Beta("B2",alpha=alpha2,beta=beta2)
    mix=pm.Mixture("mix",w=[1.0,w],comp_dists=[b1,b2])

After running this code, I get the following error: AttributeError: 'list' object has no attribute 'mean'. Any suggestions?


Solution

  • PyMC3 comes with a pymc3.tests module, which contains useful examples. By searching that directory for the word "mixture", I came upon this example:

    Mixture('x_obs', w,
            [Normal.dist(mu[0], tau=tau[0]), Normal.dist(mu[1], tau=tau[1])],
            observed=self.norm_x)
    

    Notice that each component is built by calling the dist classmethod. Googling "pymc3 dist classmethod" leads to a PyMC3 documentation page that explains:

    ... each Distribution has a dist class method that returns a stripped-down distribution object that can be used outside of a PyMC model.

    Beyond this, I'm not entirely sure why the stripped-down distribution is what's required here, but it seems to work:

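    import numpy as np
    import pymc3 as pm

    # Hypothetical example data: replace with your own observations in (0, 1)
    data = np.concatenate([np.random.beta(2, 8, size=200),
                           np.random.beta(8, 2, size=300)])

    model = pm.Model()
    with model:
        alpha1 = pm.Uniform("alpha1", lower=0, upper=20)
        beta1 = pm.Uniform("beta1", lower=0, upper=20)
        alpha2 = pm.Uniform("alpha2", lower=0, upper=20)
        beta2 = pm.Uniform("beta2", lower=0, upper=20)
        w = pm.Uniform("w", lower=0, upper=1)
        # Components are unnamed distribution objects created with .dist()
        b1 = pm.Beta.dist(alpha=alpha1, beta=beta1)
        b2 = pm.Beta.dist(alpha=alpha2, beta=beta2)
        # Mixture weights must sum to 1, so use [w, 1 - w] rather than [1.0, w]
        mix = pm.Mixture("mix", w=[w, 1 - w], comp_dists=[b1, b2],
                         observed=data)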
    

    Note that when using the dist classmethod, the name string is omitted.
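For intuition about what a two-component Beta mixture evaluates, here is a minimal NumPy sketch (not PyMC3's implementation) of the mixture log-density, log p(x) = logsumexp over k of (log w_k + log Beta(x | a_k, b_k)). The helper names beta_logpdf, mixture_logpdf and total_mass and the (alpha, beta) values are hypothetical, chosen only for illustration. Integrating the density numerically shows why the weights need to sum to 1: with [w, 1 - w] the total mass is 1, while a vector such as [1.0, w] gives 1 + w, so it is not a valid probability density.

```python
import math

import numpy as np


def beta_logpdf(x, a, b):
    """Log-density of Beta(a, b) at points x in (0, 1), built from log-gamma terms."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return log_norm + (a - 1) * np.log(x) + (b - 1) * np.log1p(-x)


def mixture_logpdf(x, weights, params):
    """Mixture log-density: a numerically stable logsumexp over components."""
    comp = np.stack([np.log(w_k) + beta_logpdf(x, a, b)
                     for w_k, (a, b) in zip(weights, params)])
    m = comp.max(axis=0)
    return m + np.log(np.exp(comp - m).sum(axis=0))


def total_mass(weights, params, n=200_000):
    """Approximate the integral of the mixture density over (0, 1) with the midpoint rule."""
    x = (np.arange(n) + 0.5) / n
    return np.exp(mixture_logpdf(x, weights, params)).mean()


params = [(2.0, 8.0), (8.0, 2.0)]  # hypothetical (alpha, beta) for each component
w = 0.3
print(total_mass([w, 1 - w], params))  # close to 1.0: a proper density
print(total_mass([1.0, w], params))    # close to 1.3 = 1 + w: not normalized
```

The logsumexp trick (subtracting the per-point maximum before exponentiating) avoids underflow when a component's density is tiny at some x.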