Tags: python, machine-learning, bayesian, pymc3, unsupervised-learning

How to extract unsupervised clusters from a Dirichlet Process in PyMC3?


I just finished the Bayesian Analysis with Python book by Osvaldo Martin (a great book for understanding Bayesian concepts, and some fancy NumPy indexing).

I really want to extend my understanding to Bayesian mixture models for unsupervised clustering of samples. All of my Google searches have led me to Austin Rochford's tutorial, which is really informative. I understand what is happening, but I am unclear on how this can be adapted to clustering (especially using multiple attributes for the cluster assignments, but that is a different topic).

I understand how to assign the priors for the Dirichlet distribution, but I can't figure out how to get the clusters in PyMC3. It looks like most of the mus converge to the centroids (i.e., the means of the distributions I sampled from), but they remain separate components. I thought about applying a cutoff to the weights (w in the model), but that doesn't work the way I imagined, because several components with slightly different mean parameters (mu) converge to the same centroid.

How can I extract the clusters (centroids) from this PyMC3 model? I gave it a maximum of 15 components, which I want to converge to 3. The mus seem to be in the right locations, but the weights are messed up because they are being distributed among the other components, so I can't use a weight threshold (unless I merge them, but I don't think that's how it is normally done).

import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
import seaborn as sns
import pandas as pd
import theano.tensor as tt
%matplotlib inline

# Clip at 15 components
K = 15

# Create mixture population
centroids = [0, 10, 50]
weights = [(2/5),(2/5),(1/5)]

mix_3 = np.concatenate([np.random.normal(loc=centroids[0], size=int(150*weights[0])), # 60 samples
                        np.random.normal(loc=centroids[1], size=int(150*weights[1])), # 60 samples
                        np.random.normal(loc=centroids[2], size=int(150*weights[2]))])# 30 samples
n = mix_3.size


# Create and fit model
with pm.Model() as Mod_dir:
    alpha = pm.Gamma('alpha', 1., 1.)

    beta = pm.Beta('beta', 1., alpha, shape=K)

    w = pm.Deterministic('w', beta * tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]]))

    component = pm.Categorical('component', w, shape=n)

    tau = pm.Gamma("tau", 1.0, 1.0, shape=K)

    mu = pm.Normal('mu', 0, tau=tau, shape=K)

    obs = pm.Normal('obs',
                    mu[component], 
                    tau=tau[component],
                    observed=mix_3)

    # Only free continuous variables here: w is Deterministic and obs is observed
    step1 = pm.Metropolis(vars=[alpha, beta, tau, mu])
#     step2 = pm.CategoricalGibbsMetropolis(vars=[component])
    step2 = pm.ElemwiseCategorical([component], np.arange(K)) # Much, much faster than the above

    tr = pm.sample(10000, step=[step1, step2], njobs=multiprocessing.cpu_count())

# burn-in = 1000, then thin by keeping every 5th draw (slice indices must be ints)
pm.traceplot(tr[1000::5])


Similar questions:

https://stats.stackexchange.com/questions/120209/pymc3-dirichlet-distribution (regression, not clustering)

https://stats.stackexchange.com/questions/108251/image-clustering-and-dirichlet-process (theory of the Dirichlet process)

https://stats.stackexchange.com/questions/116311/draw-a-multinomial-distribution-from-a-dirichlet-distribution (explains the DP)

Dirichlet process in PyMC 3 (directs me to Austin Rochford's tutorial above)


Solution

    Using a couple of new-ish additions to PyMC3 will help make this clear. I think I updated the Dirichlet process example after they were added, but it seems to have been reverted to the old version during a documentation cleanup; I will fix that soon.

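    The first difficulty is that the data you have generated are much more dispersed than the priors on the component means can accommodate: with mu ~ Normal(0, tau=tau) and tau ~ Gamma(1, 1), most of the prior mass for each mean lies within a few units of zero, while your centroids sit at 0, 10 and 50. If you standardize your data, the samples should mix much more quickly.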
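To put a number on this, here is a minimal NumPy sketch (not part of the original answer; the helper names are hypothetical) that simulates the question's prior, tau ~ Gamma(1, 1) in the rate parameterization used by `pm.Gamma` and mu | tau ~ Normal(0, tau=tau), and compares it with the closed form. Integrating a zero-mean normal over a Gamma(1, 1) precision gives a Student-t with 2 degrees of freedom, for which P(|T| > c) = 1 - c / sqrt(2 + c^2).

```python
import numpy as np


def prior_tail_prob_mc(threshold, n_draws=1_000_000, seed=0):
    """Monte Carlo estimate of P(|mu| > threshold) under the question's prior."""
    rng = np.random.default_rng(seed)
    # pm.Gamma(alpha=1, beta=1) uses a rate; NumPy's gamma uses scale = 1 / rate
    tau = rng.gamma(shape=1.0, scale=1.0, size=n_draws)
    # pm.Normal('mu', 0, tau=tau): the standard deviation is 1 / sqrt(tau)
    mu = rng.normal(0.0, 1.0 / np.sqrt(tau))
    return np.mean(np.abs(mu) > threshold)


def prior_tail_prob_exact(threshold):
    """P(|T| > threshold) for a Student-t with 2 degrees of freedom, scale 1."""
    return 1.0 - threshold / np.sqrt(2.0 + threshold ** 2)


for c in (3.0, 10.0, 50.0):
    print(f"P(|mu| > {c:>4}): MC {prior_tail_prob_mc(c):.5f}, "
          f"exact {prior_tail_prob_exact(c):.5f}")
```

From the closed form, roughly 90% of each component mean's prior mass lies within ±3, and only about 0.04% lies beyond ±50, so a component centered at 50 is strongly disfavored a priori.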

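    The second is that PyMC3 now supports mixture distributions in which the indicator variable component has been marginalized out (summed over), so each observation's likelihood is the weighted sum of the component densities. These marginal mixture distributions help accelerate mixing and, because no discrete variables remain in the model, allow you to use NUTS (initialized with ADVI).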
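As an illustration of what marginalizing out the indicator means, here is a plain NumPy sketch (my own, not PyMC3's implementation of `NormalMixture`; function names are hypothetical). The log-likelihood of each point is log sum_k w_k N(x | mu_k, tau_k), computed stably with a log-sum-exp, and for a tiny data set it equals the brute-force sum over every possible assignment of the discrete `component` vector.

```python
import itertools

import numpy as np


def normal_logpdf(x, mu, tau):
    """Log density of Normal(mu, precision=tau), matching pm.Normal(..., tau=tau)."""
    return 0.5 * np.log(tau / (2.0 * np.pi)) - 0.5 * tau * (x - mu) ** 2


def marginal_loglik(x, w, mu, tau):
    """Sum over points of log sum_k w_k N(x_i | mu_k, tau_k), via log-sum-exp."""
    x = np.asarray(x, dtype=float)[:, None]              # (N, 1)
    log_terms = np.log(w) + normal_logpdf(x, mu, tau)    # (N, K) by broadcasting
    m = log_terms.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.exp(log_terms - m).sum(axis=1))))


def brute_force_loglik(x, w, mu, tau):
    """Same quantity, summing the joint p(x, component) over all K**N assignments."""
    x = np.asarray(x, dtype=float)
    total = 0.0
    for z in itertools.product(range(len(w)), repeat=len(x)):
        z = np.array(z)
        total += np.exp(np.sum(np.log(w[z]) + normal_logpdf(x, mu[z], tau[z])))
    return float(np.log(total))


w = np.array([0.5, 0.3, 0.2])
mu = np.array([-1.0, 0.0, 2.0])
tau = np.array([4.0, 1.0, 9.0])
x = [-0.8, 0.3, 1.7, 2.2]
print(marginal_loglik(x, w, mu, tau), brute_force_loglik(x, w, mu, tau))
```

The marginal version costs O(N K) instead of O(K^N) and contains no discrete variable, which is what makes a gradient-based sampler such as NUTS applicable.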

    Finally, with these truncated versions of infinite models, it is often useful to increase the number of potential components when you run into computational problems. I have found that K = 30 works better for this model than K = 15.
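To see why the truncation level can matter, here is a NumPy sketch of the same stick-breaking construction used for `w` in both models (helper names are hypothetical). Whatever stick remains after K breaks, prod_k (1 - beta_k), is the prior mass the truncation discards. Because beta_k ~ Beta(1, alpha) independently and E[1 - beta_k] = alpha / (1 + alpha), the expected leftover is (alpha / (1 + alpha))**K, which grows quickly with alpha (and alpha itself has a Gamma(1, 1) prior here).

```python
import numpy as np


def stick_breaking(beta):
    """NumPy version of beta * concatenate([[1], cumprod(1 - beta)[:-1]])."""
    beta = np.asarray(beta, dtype=float)
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
    return beta * remaining


def expected_leftover(alpha, K):
    """E[prod_k (1 - beta_k)] for K independent Beta(1, alpha) breaks."""
    return (alpha / (1.0 + alpha)) ** K


rng = np.random.default_rng(1)
beta = rng.beta(1.0, 2.0, size=15)
w = stick_breaking(beta)
print(w.sum() + np.prod(1.0 - beta))  # weights plus leftover stick: 1 up to rounding

for alpha in (1.0, 5.0):
    for K in (15, 30):
        print(f"alpha={alpha}, K={K}: expected leftover mass {expected_leftover(alpha, K):.2e}")
```

For alpha = 5, the expected discarded mass is about 6.5% with K = 15 but only about 0.4% with K = 30.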

    The following code implements these changes and shows how the "active" component means can be extracted.

    from matplotlib import pyplot as plt
    import numpy as np
    import pymc3 as pm
    import seaborn as sns
    from theano import tensor as T
    
    blue = sns.color_palette()[0]
    
    np.random.seed(462233) # from random.org
    
    N = 150
    
    CENTROIDS = np.array([0, 10, 50])
    WEIGHTS = np.array([0.4, 0.4, 0.2])
    
    x = np.random.normal(CENTROIDS[np.random.choice(3, size=N, p=WEIGHTS)], size=N)
    x_std = (x - x.mean()) / x.std()
    
    fig, ax = plt.subplots(figsize=(8, 6))
    
    ax.hist(x_std, bins=30);
    

    [Figure: Standardized data]

    K = 30
    
    with pm.Model() as model:
        # Stick-breaking construction of the truncated Dirichlet process weights
        alpha = pm.Gamma('alpha', 1., 1.)
        beta = pm.Beta('beta', 1., alpha, shape=K)
        w = pm.Deterministic('w', beta * T.concatenate([[1], T.extra_ops.cumprod(1 - beta)[:-1]]))
    
        # Component precisions and means
        tau = pm.Gamma('tau', 1., 1., shape=K)
        lambda_ = pm.Uniform('lambda', 0, 5, shape=K)
        mu = pm.Normal('mu', 0, tau=lambda_ * tau, shape=K)
        # Marginal mixture likelihood: the component indicator is summed out
        obs = pm.NormalMixture('obs', w, mu, tau=lambda_ * tau,
                               observed=x_std)
    
    with model:
        trace = pm.sample(2000, n_init=100000)
    
    fig, ax = plt.subplots(figsize=(8, 6))
    
    ax.bar(np.arange(K) - 0.4, trace['w'].mean(axis=0));
    

    We see that three components appear to be used, and that their weights are reasonably close to the true values.

    [Figure: Mixture weights]

    Finally, we see that the posterior expected means of these three components match the true (standardized) means fairly well.

    trace['mu'].mean(axis=0)[:3]
    

    array([-0.73763891, -0.17284594, 2.10423978])

    (CENTROIDS - x.mean()) / x.std()
    

    array([-0.73017789, -0.16765707, 2.0824262 ])
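To answer the original question more directly, here is a minimal NumPy sketch (hypothetical helper names, not a PyMC3 API). It (1) keeps the components whose posterior mean weight exceeds a cutoff, rather than hard-coding the first three, and (2) assigns each data point to the component with the highest posterior-averaged responsibility. It assumes the trace arrays have shape (draws, K), that the per-component precision in the answer's model is `trace['lambda'] * trace['tau']`, and that label switching between draws is negligible (the same assumption behind `trace['mu'].mean(axis=0)`); the 0.01 cutoff is an arbitrary choice.

```python
import numpy as np


def extract_active_components(w_samples, mu_samples, threshold=0.01):
    """Indices, mean weights and mean locations of components whose
    posterior mean weight exceeds `threshold` (an arbitrary cutoff)."""
    w_mean = w_samples.mean(axis=0)
    active = np.flatnonzero(w_mean > threshold)
    return active, w_mean[active], mu_samples[:, active].mean(axis=0)


def assign_clusters(x, w_samples, mu_samples, prec_samples):
    """Index of the component with the highest posterior-averaged
    responsibility for each point in x."""
    x = np.asarray(x, dtype=float)[None, :, None]        # (1, N, 1)
    w = w_samples[:, None, :]                             # (S, 1, K)
    mu = mu_samples[:, None, :]
    prec = prec_samples[:, None, :]
    # log w_k + log Normal(x | mu_k, precision=prec_k), shape (S, N, K)
    log_r = np.log(w) + 0.5 * np.log(prec / (2.0 * np.pi)) - 0.5 * prec * (x - mu) ** 2
    log_r -= log_r.max(axis=2, keepdims=True)             # stabilize before exp
    r = np.exp(log_r)
    r /= r.sum(axis=2, keepdims=True)                     # responsibilities per draw
    return r.mean(axis=0).argmax(axis=1)                  # (N,)


# Made-up stand-ins for trace['w'], trace['mu'] and trace['lambda'] * trace['tau']
S = 4
w_s = np.tile([0.5, 0.3, 0.1999, 0.0001], (S, 1))
mu_s = np.tile([-1.0, 0.0, 2.0, 5.0], (S, 1))
prec_s = np.full((S, 4), 25.0)
x_demo = np.array([-1.05, 0.02, 1.9, -0.95])

active, w_active, mu_active = extract_active_components(w_s, mu_s)
labels = assign_clusters(x_demo, w_s, mu_s, prec_s)
print(active, w_active, mu_active, labels)
```

With the answer's variables, this would be `extract_active_components(trace['w'], trace['mu'])` and `assign_clusters(x_std, trace['w'], trace['mu'], trace['lambda'] * trace['tau'])`. The active means can be mapped back to the original scale with `mu_active * x.std() + x.mean()`, which inverts the standardization applied to `x`.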