
PyMC3 sample() function does not accept the "start" value to generate a trace


I am new to PyMC3 and Bayesian inference methods. I have a simple script that tries to infer the value of a decay constant (true value 1) from artificial data generated with a truncated exponential distribution:

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt 
import pymc3 as pm
import arviz as az


T = stats.truncexpon(b = 10.)
t = T.rvs(1000)

#Bayesian Inference

with pm.Model() as model: 
    #Define Priors
    lam = pm.Gamma('$\lambda$', alpha=1, beta=1)

    #Define Likelihood
    time = pm.Exponential('time', lam = lam, observed = t)

    #Inference
    trace = pm.sample(20, start = {'lam': 10.},
            step=pm.Metropolis(), chains=1, cores=1,
            progressbar = True)


az.plot_trace(trace)
plt.show()

This code produces the following trace:

[Image: trace plot of $\lambda$ produced by az.plot_trace]

I am confused as to why the sampler does not use the starting value of 10; I expected the trace above to start at 10. I am running the code with Python 3.7.

Thank you.


Solution

  • A few things are going on:

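    • When the sampler starts, it first runs a tuning phase. Samples drawn during this phase are discarded by default; this is controlled by the discard_tuned_samples argument, and the number of tuning steps is set by tune. The 20 draws in the trace are therefore taken only after tuning, by which point the chain has usually moved away from the starting value.
    • The keys of the start dictionary must match the name given to the random variable ('$\lambda$'), not the name of the Python variable (lam).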

    Incorporating both points, one can try:

    trace = pm.sample(20, start = {'$\lambda$': 10.},
                step=pm.Metropolis(), chains=1, cores=1,
                discard_tuned_samples=False)
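The effect of the tuning phase can be sketched without PyMC3. Below is a minimal pure-Python random-walk Metropolis sampler for the same Gamma(1, 1) prior and Exponential likelihood. Several parts of it are assumptions made for illustration:

- Its `tune` and `discard_tuned_samples` parameters are named after the `pm.sample` arguments.
- The `metropolis` and `log_posterior` helpers are hypothetical.
- The fixed proposal scale of 0.5 is assumed. PyMC3's Metropolis step adapts its scale during tuning.
- The simulated data stand in for the question's `t`.

```python
import math
import random


def log_posterior(lam, n, total, alpha=1.0, beta=1.0):
    """Unnormalized log posterior of a rate `lam` with a Gamma(alpha, beta)
    prior (shape, rate) and an Exponential(lam) likelihood for n observations
    that sum to `total`."""
    if lam <= 0:
        return -math.inf  # outside the support of the prior
    return (alpha - 1 + n) * math.log(lam) - (beta + total) * lam


def metropolis(logp, start, draws, tune, discard_tuned_samples=True,
               step_size=0.5, seed=None):
    """Hypothetical random-walk Metropolis sampler with a tuning phase.

    Unlike PyMC3, the proposal scale stays fixed during tuning; the tuning
    steps are simply run first and, by default, dropped from the output.
    """
    rng = random.Random(seed)
    current, current_logp = start, logp(start)
    trace = []
    for _ in range(tune + draws):
        proposal = rng.gauss(current, step_size)
        proposal_logp = logp(proposal)
        # Accept with probability min(1, p(proposal) / p(current));
        # 1 - random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < proposal_logp - current_logp:
            current, current_logp = proposal, proposal_logp
        # What gets recorded is the state *after* the step, not the start.
        trace.append(current)
    return trace[tune:] if discard_tuned_samples else trace


# Simulated stand-in for the question's data (true decay constant = 1).
data_rng = random.Random(0)
data = [data_rng.expovariate(1.0) for _ in range(1000)]
n, total = len(data), sum(data)


def logp(lam):
    return log_posterior(lam, n, total)


kept = metropolis(logp, start=10.0, draws=20, tune=500, seed=1)
full = metropolis(logp, start=10.0, draws=20, tune=500, seed=1,
                  discard_tuned_samples=False)
print(len(kept), len(full))  # 20 520
print(full[-20:] == kept)    # True: same seed, same chain, tuning steps kept
print(full[:3])              # first tuning draws, still close to the start of 10
```

With the default `discard_tuned_samples=True`, nothing from the 500 tuning steps survives in this sketch. This mirrors why the question's trace never shows the value 10.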
    

    However, there is one more possible issue:

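    • The starting value is not guaranteed to appear as the first draw. The first recorded sample is the result of one Metropolis step from the start point, so it equals the starting value only if that first proposal is rejected, which is down to chance.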
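A single step is enough to check that this is down to chance. The sketch below is not PyMC3 code:

- `first_draw` is a hypothetical helper that performs one random-walk Metropolis step from the start value.
- `logp` assumes, for illustration, 1000 observations summing to 1000 under the question's Gamma(1, 1) prior.

From a start of 10, a proposal below 10 raises the posterior density and is always accepted. A proposal above 10 is almost always rejected. So only about half of the seeds should return the start value as the first draw.

```python
import math
import random


def logp(lam, n=1000, total=1000.0):
    """Unnormalized log posterior: Gamma(1, 1) prior on the rate, Exponential
    likelihood for n observations summing to `total` (illustrative values)."""
    if lam <= 0:
        return -math.inf
    return n * math.log(lam) - (1.0 + total) * lam


def first_draw(start, seed, step_size=0.5):
    """Hypothetical helper: one random-walk Metropolis step from `start`,
    i.e. the value that would be recorded as the first draw."""
    rng = random.Random(seed)
    proposal = rng.gauss(start, step_size)
    # Accept with probability min(1, p(proposal) / p(start)).
    if math.log(1.0 - rng.random()) < logp(proposal) - logp(start):
        return proposal
    return start  # proposal rejected: the chain stays put


seeds = range(1000)
hits = sum(first_draw(10.0, seed) == 10.0 for seed in seeds)
print(f"first draw equals the start for {hits} of {len(seeds)} seeds")
```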

    By fixing the random seed, though, we can get a glimpse of it:

    trace = pm.sample(20, start = {'$\lambda$': 10.},
                step=pm.Metropolis(), chains=1, cores=1,
                discard_tuned_samples=False, random_seed=1)
    
    ...
    
    trace.get_values(varname='$\lambda$')[:10]
    
    # array([10.        ,  5.42397358,  3.19841997,  1.09383329,  1.09383329,
    #         1.09383329,  1.09383329,  1.09383329,  1.09383329,  1.09383329])
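The last values in that output sit at about 1.09, close to the true decay constant of 1. Because the Gamma prior is conjugate to the Exponential likelihood, the posterior the chain should be exploring can also be computed in closed form: Gamma(alpha + n, beta + sum(t)), using PyMC3's shape/rate parameterization of Gamma.

The sketch below is a sanity check under two assumptions:

- It regenerates stand-in data with the standard library instead of scipy, via the hypothetical `truncexpon_draw` helper.
- Like the question's model, it ignores the truncation at 10. That effect is negligible here, since an Exponential(1) draw exceeds 10 with probability e^-10.

```python
import math
import random


def truncexpon_draw(rng, b=10.0):
    """Exponential(1) draw conditioned on being below b (rejection sampling),
    a stdlib stand-in for scipy.stats.truncexpon(b=10.).rvs()."""
    while True:
        x = rng.expovariate(1.0)
        if x < b:
            return x


def gamma_exponential_posterior(data, alpha=1.0, beta=1.0):
    """Conjugate update for a Gamma(alpha, beta) prior (shape, rate) on the
    rate of an Exponential likelihood; returns the posterior (shape, rate)."""
    return alpha + len(data), beta + sum(data)


rng = random.Random(0)
t = [truncexpon_draw(rng) for _ in range(1000)]
shape, rate = gamma_exponential_posterior(t)
post_mean = shape / rate
post_sd = math.sqrt(shape) / rate
print(f"posterior mean {post_mean:.3f}, sd {post_sd:.3f}")
```

With 1000 observations the posterior standard deviation is only about 0.03, so a converged chain should stay close to the posterior mean. The exact numbers depend on the simulated data.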