Tags: linear-regression, bayesian, pymc, pymc3

Bayesian Lasso using PyMC3


I'm trying to reproduce the results of this tutorial (see the LASSO regression section) in PyMC3. As noted in this reddit thread, the mixing for the first two coefficients was poor because the variables are correlated.

I tried implementing it in PyMC3, but it didn't work as expected with Hamiltonian samplers. I could only get it working with the Metropolis sampler, which gives the same result as PyMC2.

I don't know whether it's related to the Laplace prior being peaked at zero (its derivative is discontinuous there), but the model works perfectly well with Gaussian priors. I tried with and without MAP initialization, and the result is always the same.
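To make the "peaked at zero" point concrete, here is a small NumPy/SciPy sketch (independent of PyMC) showing the kink in the Laplace log-density: the numerical gradient jumps from +1/b to -1/b across zero, whereas a Gaussian log-density is smooth everywhere. This discontinuity is what can trip up gradient-based samplers such as NUTS.

```python
import numpy as np
from scipy.stats import laplace

# Laplace log-density is log(1/(2b)) - |x|/b, so its gradient is
# +1/b for x < 0 and -1/b for x > 0: discontinuous at 0.
b = 1 / np.sqrt(2)
eps = 1e-6

def grad_logpdf(x):
    # Central-difference numerical gradient of the Laplace log-density.
    return (laplace.logpdf(x + eps, scale=b)
            - laplace.logpdf(x - eps, scale=b)) / (2 * eps)

print(grad_logpdf(-0.01))  # ~ +1/b ≈ +1.414
print(grad_logpdf(+0.01))  # ~ -1/b ≈ -1.414
```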

Here is my code:

from pymc3 import *
from numpy import sqrt
from scipy.stats import norm
import matplotlib.pyplot as plt

# Same model as the tutorial 
n = 1000

x1 = norm.rvs(0, 1, size=n)
x2 = -x1 + norm.rvs(0, 10**-3, size=n)
x3 = norm.rvs(0, 1, size=n)

y = 10 * x1 + 10 * x2 + 0.1 * x3

with Model() as model:
    # Laplacian prior only works with Metropolis sampler
    coef1 = Laplace('x1', 0, b=1/sqrt(2))
    coef2 = Laplace('x2', 0, b=1/sqrt(2))
    coef3 = Laplace('x3', 0, b=1/sqrt(2))

    # Gaussian prior works with NUTS sampler
    #coef1 = Normal('x1', mu = 0, sd = 1)
    #coef2 = Normal('x2', mu = 0, sd = 1)
    #coef3 = Normal('x3', mu = 0, sd = 1)

    likelihood = Normal('y', mu=coef1 * x1 + coef2 * x2 + coef3 * x3, tau=1, observed=y)

    #step = Metropolis()  # Works just like PyMC2
    start = find_MAP()  # Doesn't help
    step = NUTS(state=start)  # Doesn't work
    trace = sample(10000, step, start=start, progressbar=True)

plt.figure(figsize=(7, 7))
traceplot(trace)
plt.tight_layout()

autocorrplot(trace)
summary(trace)

Here is the error I get:

PositiveDefiniteError: Simple check failed. Diagonal contains negatives

Am I doing something wrong, or is the NUTS sampler not supposed to work in cases like this?


Solution

  • Whyking from the reddit thread suggested passing the MAP estimate to NUTS as the scaling rather than the state, and that worked wonders.

    Here is a notebook with the results and the updated code.
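Concretely, the change amounts to swapping the `state` keyword for `scaling` in the sampling block. A minimal sketch (to be run inside the `with model:` context from the question, assuming the `NUTS` signature of PyMC3 at the time, which accepted a `scaling` argument):

```python
with model:
    start = find_MAP()
    # Pass the MAP point as the scaling (a guess for the mass matrix),
    # not as the sampler state:
    step = NUTS(scaling=start)
    trace = sample(10000, step, start=start, progressbar=True)
```

Using the MAP as scaling gives NUTS a sensible curvature estimate at a high-density point, which avoids the `PositiveDefiniteError` seen above.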