python-3.x, histogram, pymc, mcmc

How to plot a probability distribution with `pymc.MCMC` in Python


I know that I can use:

    S = pymc.MCMC(model1)
    from pymc import Matplot as mcplt
    mcplt.plot(S)

and that will give me a figure with three plots, but all I want is a single plot of the histogram. I then want to normalise the histogram and plot a smooth curve of the distribution rather than the histogram bars. Could anyone help me code this so that I get a final plot of the distribution?


Solution

  • If you have matplotlib installed for plotting and scipy for computing a kernel density estimate (KDE; many KDE functions exist), you could do something like the following (based on this example, where 'late_mean' is the name of one of the sampled parameters in that case):

    from pymc.examples import disaster_model
    from pymc import MCMC
    import numpy as np
    
    M = MCMC(disaster_model) # you could substitute your own model
    
    # perform sampling of model
    M.sample(iter=10000, burn=1000, thin=10)
    
    # get numpy array containing the MCMC chain of the parameter you want: 'late_mean' in this case
    chain = M.trace('late_mean')[:]
    
    # import matplotlib plotting functions
    from matplotlib import pyplot as pl
    
    # plot histogram (using 15 bins, but you can choose whatever you want) - density=True returns a normalised histogram
    pl.hist(chain, bins=15, histtype='stepfilled', density=True)
    ax = pl.gca() # get current axis
    
    # import scipy gaussian KDE function
    from scipy.stats import gaussian_kde
    
    # create KDE of samples
    kde = gaussian_kde(chain)
    
    # calculate KDE at a range of points (in this case based on the current plot, but you could choose a range based on the chain)
    vals = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 100)
    
    # overplot KDE
    pl.plot(vals, kde(vals), 'b')
    
    pl.xlabel('late mean')
    pl.ylabel('PDF')
    
    # save the plot (do this before pl.show(), otherwise the saved figure may be blank)
    pl.savefig('hist.png', dpi=200)
    
    # show the plot
    pl.show()
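
If only the smooth curve is wanted, without the histogram bars, the same idea can be reduced to a minimal sketch like the one below. It is not tied to pymc: synthetic gamma-distributed samples stand in for the chain (an assumption, so that it runs on its own), and `plot_density`, `fake_chain`, and the 10% padding of the plotting range are illustrative names and choices rather than anything from the pymc API. The grid is built from the samples themselves instead of from the axis limits of an existing plot.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch runs without a display
from matplotlib import pyplot as pl
from scipy.stats import gaussian_kde


def plot_density(samples, filename=None, npoints=200):
    """Plot a KDE curve of 1-D samples and return the grid and density values."""
    samples = np.asarray(samples, dtype=float).ravel()

    # Gaussian KDE with scipy's default bandwidth rule
    kde = gaussian_kde(samples)

    # evaluate on a grid spanning the samples, padded by 10% of their range
    pad = 0.1 * (samples.max() - samples.min())
    vals = np.linspace(samples.min() - pad, samples.max() + pad, npoints)
    dens = kde(vals)

    fig, ax = pl.subplots()
    ax.plot(vals, dens, 'b')
    ax.set_xlabel('parameter value')
    ax.set_ylabel('PDF')
    if filename is not None:
        fig.savefig(filename, dpi=200)
    pl.close(fig)
    return vals, dens


# hypothetical stand-in for M.trace('late_mean')[:]
rng = np.random.default_rng(0)
fake_chain = rng.gamma(shape=5.0, scale=0.2, size=900)

vals, dens = plot_density(fake_chain)
```

To use it with a real chain, pass the array returned by `M.trace('late_mean')[:]` in place of `fake_chain`, and give a `filename` such as `'kde.png'` to write the figure to disk.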