
Why is Pymc3 ADVI worse than MCMC in this logistic regression example?


I am aware of the mathematical differences between ADVI and MCMC, but I am trying to understand the practical implications of using one or the other. I am running a very simple logistic regression example on data I created in this way:

import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np

def logistic(x, b, noise=None):
    L = x.T.dot(b)
    if noise is not None:
        L = L+noise
    return 1/(1+np.exp(-L))

x1 = np.linspace(-10., 10, 10000)
x2 = np.linspace(0., 20, 10000)
bias = np.ones(len(x1))
X = np.vstack([x1,x2,bias]) # Add intercept
B =  [-10., 2., 1.] # Sigmoid params for X + intercept

# Noisy mean (note: with scale=0. no noise is actually added)
pnoisy = logistic(X, B, noise=np.random.normal(loc=0., scale=0., size=len(x1)))
# dichotomize pnoisy -- sample 0/1 with probability pnoisy
y = np.random.binomial(1., pnoisy)

And then I run ADVI like this:

with pm.Model() as model: 
    # Define priors
    intercept = pm.Normal('Intercept', 0, sd=10)
    x1_coef = pm.Normal('x1', 0, sd=10)
    x2_coef = pm.Normal('x2', 0, sd=10)

    # Define likelihood
    likelihood = pm.Bernoulli(
        'y',
        pm.math.sigmoid(intercept + x1_coef * X[0] + x2_coef * X[1]),
        observed=y,
    )
    approx = pm.fit(90000, method='advi')

Unfortunately, no matter how much I increase the sampling, ADVI does not seem to be able to recover the original betas I defined ([-10., 2., 1.]), while MCMC works fine (as shown below).

[image]

Thanks for the help!


Solution

  • This is an interesting question! The default 'advi' in PyMC3 is mean field variational inference, which does not do a great job capturing correlations. It turns out that the model you set up has an interesting correlation structure, which can be seen with this:

    import arviz as az
    
    az.plot_pair(trace, figsize=(5, 5))
    

    [image: correlated samples]
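
A plausible source of that correlation, assuming the data-generation code from the question is used unchanged: x2 is exactly x1 + 10, so the two predictors are perfectly collinear and the likelihood only pins down the combinations (x1 coefficient + x2 coefficient) and (intercept + 10 * x2 coefficient). The sketch below checks this with NumPy only; the alternative parameter vector (-8, 0, 21) is a hypothetical value picked for illustration.

```python
import numpy as np

# Same predictors as in the question
x1 = np.linspace(-10., 10, 10000)
x2 = np.linspace(0., 20, 10000)

# x2 is x1 shifted by 10, so the predictors are perfectly collinear
is_shifted_copy = np.allclose(x2, x1 + 10)
print("x2 == x1 + 10:", is_shifted_copy)

def logit(b_x1, b_x2, b_intercept):
    """Linear predictor that goes into the sigmoid."""
    return b_x1 * x1 + b_x2 * x2 + b_intercept

# Betas used to generate the data in the question
true_logit = logit(-10., 2., 1.)

# Hypothetical alternative with the same (b_x1 + b_x2) and (b_intercept + 10 * b_x2)
alt_logit = logit(-8., 0., 21.)

# Both parameter vectors give the same probabilities for every observation
same_predictions = np.allclose(true_logit, alt_logit)
print("identical linear predictor:", same_predictions)
```

Since very different coefficient values explain the data equally well, the posterior is stretched along those directions, and only the priors keep it bounded. A factorized approximation has a hard time representing that shape.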

    PyMC3 has a built-in convergence checker; running the optimization for too long or too short can lead to funny results:

    from pymc3.variational.callbacks import CheckParametersConvergence
    
    with model:
        fit = pm.fit(100_000, method='advi', callbacks=[CheckParametersConvergence()])
    
    draws = fit.sample(2_000)
    

    This stops after about 60,000 iterations for me. Now we can inspect the correlations and see that, as expected, ADVI fit axis-aligned Gaussians:

    az.plot_pair(draws, figsize=(5, 5))
    

    [image: another correlation image]

    Finally, we can compare the fit from NUTS and (mean field) ADVI:

    az.plot_forest([draws, trace])
    

    [image: forest plot]

    Note that ADVI underestimates the variance but is fairly close on the mean of each parameter. Also, you can set method='fullrank_advi' to capture the correlations you are seeing a little better.
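
To see why a mean-field fit tends to underestimate variance when parameters are correlated, here is a small worked example. It is a toy illustration, not PyMC3 code, and it assumes the target is exactly a 2D Gaussian with unit variances and correlation rho = 0.9 (a made-up value). For a Gaussian target, the factorized Gaussian that minimizes KL(q || p), which is the objective ADVI optimizes, has per-parameter variance equal to 1 over the corresponding diagonal entry of the precision matrix.

```python
import numpy as np

rho = 0.9  # assumed correlation between the two parameters (illustrative)

# Covariance of the "true" posterior: unit variances, correlation rho
Sigma = np.array([[1.0, rho],
                  [rho, 1.0]])

# Precision matrix (inverse covariance)
Lambda = np.linalg.inv(Sigma)

# True marginal variances are the diagonal of the covariance
true_var = np.diag(Sigma)

# Mean-field optimum under KL(q || p) for a Gaussian target:
# each factor has variance 1 / Lambda_ii, which equals 1 - rho**2 here
mf_var = 1.0 / np.diag(Lambda)

print("true marginal variances:", true_var)
print("mean-field variances:   ", mf_var)
```

The stronger the correlation, the narrower the mean-field marginals become relative to the true ones, while the means stay in the right place. That matches the pattern in the forest plot.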

    (note: arviz is soon to be the plotting library for PyMC3)