A rapid look at variational inference with KL divergence

I’ve been spending some time reading up on variational autoencoders (VAEs), a class of machine learning models that draws some interesting parallels with Bayesian inference. In order to understand why VAEs work (aside from just saying neural networks and refusing to think further), I had to get at least a surface-level understanding of variational methods.

This information is rephrased from the notes for Stanford’s course on Probabilistic Graphical Models.

Variational inference seeks to cast (often intractable) inference problems as optimization problems. There are some differences between how various fields refer to inference, but in this case inference means extracting information from the probabilistic model. This information may be the probability of a given variable after we sum everything else out (marginal inference) or the most likely assignment to the variables in the model (maximum a posteriori, or MAP, inference).
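To make these two queries concrete, here is a minimal sketch (my own toy example, not from the course notes) that runs both on a small two-variable joint distribution stored as a table; the numbers are made up purely for illustration.

    import numpy as np

    # A made-up joint distribution p(x1, x2) over two binary variables.
    # Rows index x1 in {0, 1}; columns index x2 in {0, 1}; entries sum to 1.
    p = np.array([[0.30, 0.10],
                  [0.15, 0.45]])

    # Marginal inference: sum x2 out to get p(x1).
    p_x1 = p.sum(axis=1)
    print(p_x1)  # [0.4, 0.6]

    # Maximum a posteriori: the single most likely joint assignment.
    x1_map, x2_map = np.unravel_index(np.argmax(p), p.shape)
    print(x1_map, x2_map)  # 1 1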

Performing variational inference requires an approximating family of distributions, Q, and an optimization objective, J(q). One type of variational inference, called “mean field,” uses the product of many univariate distributions as the approximating family. However, there are many other options (for example, VAEs use distributions parameterized by neural networks as their approximating families).
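As a concrete (and purely illustrative) sketch of a mean-field family, the snippet below builds q(x1, x2, x3) as a product of three independent Bernoulli factors; the parameter values are arbitrary placeholders, not anything from the course.

    import itertools

    # Mean-field q over three binary variables: q(x) = q1(x1) * q2(x2) * q3(x3).
    # Each factor is a univariate Bernoulli; these parameters are arbitrary.
    factor_params = [0.3, 0.7, 0.5]  # p(x_i = 1) for each factor

    def q(assignment):
        """Probability of a joint assignment under the factorized q."""
        prob = 1.0
        for x_i, p_i in zip(assignment, factor_params):
            prob *= p_i if x_i == 1 else 1.0 - p_i
        return prob

    # The product of normalized univariate factors is itself normalized.
    print(sum(q(a) for a in itertools.product([0, 1], repeat=3)))  # 1.0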

For the optimization objective, J(q), we typically use the KL divergence, which is defined as:

    \begin{equation*} KL(q \| p) = \sum_x q(x) \log\frac{q(x)}{p(x)} \end{equation*}

The KL divergence is not symmetric (KL(q||p) \neq KL(p||q)), but it is always greater than 0 when p \neq q and exactly 0 when p = q. This second point is what the bound below relies on.
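A quick numerical check of the definition and of both properties, using two made-up distributions over the same three states:

    import numpy as np

    # Two made-up discrete distributions over the same three states.
    q = np.array([0.4, 0.4, 0.2])
    p = np.array([0.5, 0.3, 0.2])

    def kl(a, b):
        """KL(a || b) = sum_x a(x) * log(a(x) / b(x))."""
        return np.sum(a * np.log(a / b))

    print(kl(q, p))  # a small positive number
    print(kl(p, q))  # a different value: KL is not symmetric
    print(kl(q, q))  # 0.0: KL vanishes only when the distributions match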

As a specific problem, take p(x) to be a joint distribution over many variables, written (for purely pedagogical reasons) as the normalized version of an unnormalized distribution \tilde{p}(x). Here \theta represents the parameters of the distribution, and Z(\theta) is the normalizing constant (the partition function).

    \begin{equation*} p(x_1,\dots,x_n;\theta) = \frac{\tilde{p}(x_1,\dots,x_n;\theta)}{Z(\theta)} \end{equation*}
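As a sketch of this decomposition, the toy model below (my own construction, with made-up energies) defines an unnormalized \tilde{p} over two binary variables and recovers Z(\theta) by brute-force enumeration, which is exactly the step that becomes infeasible for realistically sized models.

    import itertools

    import numpy as np

    theta = 1.5  # a single illustrative parameter

    def p_tilde(x, theta):
        """Unnormalized score for a joint assignment of two binary variables."""
        x1, x2 = x
        # Reward agreement between x1 and x2, scaled by theta.
        return np.exp(theta if x1 == x2 else 0.0)

    states = list(itertools.product([0, 1], repeat=2))
    Z = sum(p_tilde(x, theta) for x in states)      # the partition function
    p = {x: p_tilde(x, theta) / Z for x in states}  # the normalized distribution

    print(Z)
    print(sum(p.values()))  # 1.0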

From here, we note that we can’t calculate p(x) or Z(\theta), but we can still evaluate \tilde{p}(x), which lets us play with our optimization objective:

    \begin{equation*} J(q) = \sum_x q(x) \log\frac{q(x)}{\tilde{p}(x)} \end{equation*}

We note that p(x) = \frac{\tilde{p}(x)}{Z(\theta)}, i.e., \tilde{p}(x) = p(x) Z(\theta). Substituting this in and using \sum_x q(x) = 1, the optimization objective can be expressed as:

    \begin{equation*} J(q) = \sum_x q(x) \log\frac{q(x)}{p(x)} - \log Z(\theta) \end{equation*}

The first sum is exactly KL(q||p), so the identity can be rearranged as:

    \begin{equation*} \log Z(\theta) = KL(q \| p) - J(q) \end{equation*}

The KL divergence between q and p is still intractable. HOWEVER, since KL divergences are always greater than or equal to zero, the identity above tells us that -J(q) \leq log(Z(\theta)). Unlike KL(q||p), J(q) only involves q and \tilde{p}, so it is tractable, and its negative gives us a computable lower bound on log(Z(\theta)). This is called the variational lower bound or, more commonly (in my reading), the Evidence Lower Bound (ELBO).
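To see both facts numerically, here is a small check with made-up numbers, using a model tiny enough that log(Z(\theta)) can be computed exactly for comparison: the identity holds for any valid q, and -J(q) never exceeds log(Z(\theta)), with equality when q = p.

    import numpy as np

    # A tiny unnormalized distribution over three states, small enough that the
    # normally intractable quantities can be computed exactly for comparison.
    p_tilde = np.array([2.0, 1.0, 0.5])
    Z = p_tilde.sum()
    p = p_tilde / Z

    for q in [np.array([0.5, 0.3, 0.2]),     # an arbitrary q
              np.array([1/3, 1/3, 1/3]),     # a uniform q
              p.copy()]:                     # q = p, where the bound is tight
        kl_q_p = np.sum(q * np.log(q / p))   # KL(q || p): needs p, intractable in general
        J = np.sum(q * np.log(q / p_tilde))  # J(q): needs only the unnormalized p~
        print(np.log(Z), kl_q_p - J)         # the identity: these two numbers agree
        print(-J, "<=", np.log(Z))           # the ELBO never exceeds log Z

Minimizing J(q) (equivalently, maximizing the ELBO) therefore pushes q toward p while tightening the bound on log(Z(\theta)), even though p and Z themselves are never evaluated.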

How is this bound useful? It shows up in the mathematics of the variational autoencoder, which I describe in the next post. It also shows up in all sorts of other places that I haven’t covered.
