
I am aware of the mathematical differences between ADVI and MCMC, but I am trying to understand the practical implications of using one or the other. I am running a very simple logistic regression example on data I created as follows:

import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np

def logistic(x, b, noise=None):
    L = x.T.dot(b)
    if noise is not None:
        L = L+noise
    return 1/(1+np.exp(-L))

x1 = np.linspace(-10., 10, 10000)
x2 = np.linspace(0., 20, 10000)
bias = np.ones(len(x1))
X = np.vstack([x1,x2,bias]) # Add intercept
B =  [-10., 2., 1.] # Sigmoid params for X + intercept

# Mean with optional noise (scale=0. here, so effectively noise-free)
pnoisy = logistic(X, B, noise=np.random.normal(loc=0., scale=0., size=len(x1)))
# dichotomize pnoisy -- sample 0/1 with probability pnoisy
y = np.random.binomial(1., pnoisy)

And then I run ADVI like this:

with pm.Model() as model: 
    # Define priors
    intercept = pm.Normal('Intercept', 0, sd=10)
    x1_coef = pm.Normal('x1', 0, sd=10)
    x2_coef = pm.Normal('x2', 0, sd=10)

    # Define likelihood
    likelihood = pm.Bernoulli(
        'y',
        p=pm.math.sigmoid(intercept + x1_coef * X[0] + x2_coef * X[1]),
        observed=y)
    approx = pm.fit(90000, method='advi')
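
To check the fitted betas, one can then draw from the approximation and summarize it (a minimal sketch, assuming the model context above; variable names are illustrative):

# Draw posterior samples from the fitted mean-field approximation
advi_trace = approx.sample(2000)
print(pm.summary(advi_trace))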

Unfortunately, no matter how much I increase the number of fitting iterations, ADVI does not seem to be able to recover the original betas I defined, [-10., 2., 1.], while MCMC works fine (as shown below).

[figure: MCMC posterior results recovering the original coefficients]
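
The MCMC comparison above can be run with NUTS in the same model context, for example (a minimal sketch; the draw and tune counts are illustrative):

with model:
    # Sample the posterior with NUTS to compare against the ADVI fit
    trace = pm.sample(2000, tune=1000)
print(pm.summary(trace))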

Thanks for the help!

Alberto

1 Answer


This is an interesting question! The default 'advi' in PyMC3 is mean-field variational inference, which does not do a great job of capturing correlations. It turns out that the model you set up has an interesting correlation structure, which can be seen by pair-plotting the NUTS trace (trace below) of the same model:

import arviz as az

az.plot_pair(trace, figsize=(5, 5))

[figure: pair plot of the NUTS trace showing strongly correlated posterior samples]

PyMC3 has a built-in convergence checker; running the optimization for too long or too short can lead to funny results:

from pymc3.variational.callbacks import CheckParametersConvergence

with model:
    fit = pm.fit(100_000, method='advi', callbacks=[CheckParametersConvergence()])

draws = fit.sample(2_000)

This stops after about 60,000 iterations for me. Now we can inspect the correlations and see that, as expected, ADVI has fit axis-aligned Gaussians:

az.plot_pair(draws, figsize=(5, 5))

[figure: pair plot of the ADVI draws showing axis-aligned, uncorrelated Gaussians]

Finally, we can compare the fit from NUTS and (mean field) ADVI:

az.plot_forest([draws, trace])

[figure: forest plot comparing the ADVI draws and the NUTS trace]

Note that ADVI underestimates the variance, but it is fairly close on the mean of each parameter. You can also set method='fullrank_advi' to capture the correlations you are seeing a little better, as in the sketch below.
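
For example, reusing the model and imports above (a minimal sketch; the iteration count is illustrative):

with model:
    # Full-rank ADVI fits a dense covariance matrix, so it can capture
    # the correlations between coefficients that mean-field ADVI misses
    fullrank_fit = pm.fit(100_000, method='fullrank_advi',
                          callbacks=[CheckParametersConvergence()])

fullrank_draws = fullrank_fit.sample(2_000)
az.plot_pair(fullrank_draws, figsize=(5, 5))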

(note: arviz is soon to be the plotting library for PyMC3)

colcarroll
  • Given how widespread correlated features are, isn't the mvnormal with diagonal covariance approximation.....really bad in general? – Russell Richie Sep 25 '19 at 03:38
  • 1
    totally. you'll find that a lot of the literature on variational inference focuses on this (legitimate!) worry. however, it turns a sampling problem into an optimization problem, which can handle tons of data and goes much faster. So if you don't expect to see correlations, it could be the only feasible approach. – colcarroll Sep 25 '19 at 13:36
  • Right. Anyway, thank you SO MUCH for your answer -- I was seeing poor posterior predictive performance based on ADVI, and I think it may come down to the fact that I have a lot of correlated features, just like OP. I'll try MCMC, and see if that works better. – Russell Richie Sep 25 '19 at 14:01
  • btw, is this a problem for most variational inference algorithms, or just advi? – Russell Richie Oct 12 '19 at 19:06
  • 1
    It depends on the "flavor" of ADVI you use. Mean field uses a diagonal covariance matrix, while full rank fits a dense covariance matrix, which comes with its own problems. See https://nbviewer.jupyter.org/gist/ColCarroll/d673a3af7169bd713bcbdb9445d4a543 for some comparisons of NUTS, mean field, and full-rank ADVI. – colcarroll Oct 16 '19 at 00:40
  • A lot of variational inference uses a mean field assumption, because in large models working with a full-rank covariance matrix is intractable. If you expect correlations, you could roll your own variational solution (assuming it's tractable) or look at, for example, [pyro](http://pyro.ai/), which allows a more flexible factorisation (see the sketch below). – alan ocallaghan Dec 18 '19 at 21:34
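
For illustration, a minimal sketch of the same logistic model with a full-rank guide in Pyro (the names, priors, and optimizer settings here are illustrative, not taken from the thread):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.optim import Adam

def logreg_model(x, y):
    # Normal(0, 10) priors on the three coefficients (x1, x2, intercept)
    b = pyro.sample("b", dist.Normal(torch.zeros(3), 10.0).to_event(1))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.Bernoulli(logits=x @ b), obs=y)

# A full-rank multivariate normal guide, analogous to method='fullrank_advi'
guide = AutoMultivariateNormal(logreg_model)
svi = SVI(logreg_model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

x_t = torch.tensor(X.T, dtype=torch.float32)  # X, y from the question above
y_t = torch.tensor(y, dtype=torch.float32)
for _ in range(5000):
    svi.step(x_t, y_t)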