Auto-Encoding Variational Bayes

In my previous post on variational inference, I gave a quick foundation of the theory behind variational autoencoders (VAEs). The dominant lay understanding of VAEs is that they’re models which let you map an image to a smaller latent space and manipulate that image in meaningful ways by changing the values of the latent variables. Most machine learning tutorials are centered around this interpretation. However, I feel this short-sells their meaning (and makes people say “just use a GAN instead”, which is disappointing). It is more correct, and more interesting, to say that VAEs are an implementation of variational inference which allows the use of arbitrarily complex functions (read: neural networks) as approximating distributions.

Like the previous post, a lot of this information was pulled from the Stanford Probabilistic Graphical Models Course, but it’s also worth reading the original (if opaque) Kingma & Welling Paper, watching this video (starting at 7 minutes or so), and reading this StackOverflow post.

Throughout this post, x refers to an input, z refers to latent variables, p(*) refers to an actual distribution, and q(*) refers to an approximating distribution.

In a variational autoencoder, we have the (true) relationship p(z|x)=\frac{p(x,z)}{p(x)}, and we want to maximize p(x), the probability of our data, which we can’t compute directly.

In order to do this, the first thing we do is re-formulate the VAE objective. We start from the identity derived in the previous post:

    \begin{equation*} log(Z(\theta)) = KL(q||p)-J(q) \end{equation*}

Which, in our variables, is:

    \begin{align*} log(p(x)) &= KL(q(z|x)||p(z|x)) - KL(q(z|x)||p(x,z)) \\ &= KL(q(z|x)||p(z|x)) + \mathbb{E}[(log(p(x,z))-log(q(z|x)))] \end{align*}

We note that the KL divergence is always greater than or equal to zero, so dropping the intractable first KL term leaves a lower bound. We then expand p(x,z) = p(x|z)p(z) with the probability chain rule (and apply the logarithm product rule).

    \begin{equation*} log(p(x)) \geq \mathbb{E}[log(p(x|z))+log(p(z))-log(q(z|x))] \end{equation*}

Exploit the linearity of expectation:

    \begin{equation*} log(p(x)) \geq \mathbb{E}[log(p(x|z))]+\mathbb{E}[log(p(z))-log(q(z|x))] \end{equation*}

Flip the sign and collect the right two terms as a KL divergence. Since the expectation is taken with respect to q(z|x), this is a straightforward application of the definition of the KL divergence.
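Written out explicitly with the definition of the KL divergence (a worked intermediate step; all symbols are as defined above):

    \begin{equation*} \mathbb{E}[log(p(z))-log(q(z|x))] = -\Sigma q(z|x) log(\frac{q(z|x)}{p(z)}) = -KL(q(z|x)||p(z)) \end{equation*}

Substituting this back into the bound gives: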

    \begin{equation*} log(p(x)) \geq \mathbb{E}[log(p(x|z))] - KL(q(z|x)||p(z)) \end{equation*}

Which is our objective function for the variational autoencoder. We can pause here for a minute and take a look at how this lines up with the common interpretation of a variational autoencoder. We note that there is a q(z|x), which corresponds to our encoder network; a p(z), which corresponds to our prior (typically a unit gaussian with independent dimensions); and a p(x|z), which corresponds to our decoder network. The expected value operator is also significant, as it means we never actually need to evaluate the expectation exactly. In practice it is approximated with a single sample of z drawn from q(z|x), the decoder maps that sample to an output, and making the probability of x higher under the decoder’s distribution corresponds to just making the input and output images match more closely!
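To make the correspondence concrete, here is a minimal sketch of this objective as a training loss in PyTorch (the tensor shapes, function name, and the choice of a Bernoulli decoder are my assumptions, not anything prescribed by the references above). The first term approximates \mathbb{E}[log(p(x|z))] with a single sample of z; the second is KL(q(z|x)||p(z)), which has a closed form when both distributions are gaussian.

    import torch
    import torch.nn.functional as F

    def vae_loss(x, x_recon, mu, logvar):
        """Negative ELBO for one batch (a sketch, not a reference implementation).

        x         : inputs flattened to (batch, dim), values in [0, 1]
        x_recon   : decoder outputs (Bernoulli means), same shape as x
        mu, logvar: encoder outputs defining q(z|x) = N(mu, exp(logvar))
        """
        # E[log(p(x|z))], approximated with the single sample of z that was
        # used to produce x_recon; for a Bernoulli decoder this is a
        # (negative) cross-entropy.
        recon_ll = -F.binary_cross_entropy(x_recon, x, reduction="sum")

        # KL(q(z|x) || p(z)) against a unit-gaussian prior has the closed form
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # We want to maximize the ELBO = recon_ll - kl, so we minimize its negative.
        return -(recon_ll - kl)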

There is one (not small) detail missing before we can perform gradient descent. Since z is a random variable, we can’t differentiate through the sampling step to train the encoder and decoder networks. This is what the reparameterization trick is for. In the reparameterization trick, the encoder produces two deterministic outputs for each latent dimension (a mean and a standard deviation); these are the outputs we want to learn. The randomness is moved into a third variable, drawn from a unit normal distribution, which we multiply by the standard deviation and add to the mean. Since this variable isn’t learned, we can still backpropagate through the network as desired despite having random variables within it.
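As a minimal sketch of that step (the function name and the log-variance parameterization are my own choices), the reparameterized sampling looks like this:

    import torch

    def reparameterize(mu, logvar):
        # sigma is a deterministic function of the encoder outputs; all of the
        # randomness lives in eps, which has no learned parameters.
        sigma = torch.exp(0.5 * logvar)
        eps = torch.randn_like(sigma)   # eps ~ N(0, I)
        # z is differentiable with respect to mu and sigma (and therefore the
        # encoder weights), even though it is a random sample.
        return mu + sigma * eps

During training, the resulting z is fed to the decoder; gradients flow through mu and sigma while eps is treated as a constant input.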

A rapid look at variational inference with KL divergence

I’ve been spending some time reading up on variational autoencoders (VAEs), which are a paradigm of machine learning that draws some interesting parallels with Bayesian inference. In order to understand why VAEs work (aside from just saying “neural networks” and refusing to think further), I had to get at least a surface-level understanding of variational methods.

This information is rephrased from the Stanford Course on Probabilistic Graphical Models here.

Variational inference seeks to cast (often intractable) inference problems as optimization problems. There are some differences in how various fields refer to inference, but in this case inference signifies extracting information from a probabilistic model. This information may be the probability of a given variable after we sum everything else out (marginal inference) or the most likely assignment to the variables in the model (maximum a-posteriori, or MAP, inference).
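As a toy illustration of those two kinds of inference (my own example, not from the course): for a small discrete joint distribution over two binary variables, marginal inference sums one variable out, while MAP inference picks the single most likely joint assignment.

    import numpy as np

    # A made-up joint distribution p(a, b) over two binary variables,
    # stored as a 2x2 table whose entries sum to 1.
    p_ab = np.array([[0.30, 0.10],
                     [0.15, 0.45]])

    # Marginal inference: p(a) = sum over b of p(a, b)
    p_a = p_ab.sum(axis=1)                                      # [0.40, 0.60]

    # Maximum a-posteriori inference: the most likely assignment of (a, b)
    a_map, b_map = np.unravel_index(p_ab.argmax(), p_ab.shape)  # (1, 1)

Both operations are trivial with two variables, but the sums and maximizations blow up as the number of variables grows, which is what motivates approximate methods like variational inference.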

Performing variational inference requires an approximating family of distributions, Q, and an optimization objective, J(q). One type of variational inference, called “mean field,” uses the product of many univariate distributions as the approximating family. However, there are many other options (for example, VAEs use neural networks to define the approximating family).
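For instance, a mean-field family over n binary variables can be represented as a product of independent Bernoulli factors (a minimal sketch; the parameterization is my own choice):

    import numpy as np

    def mean_field_q(x, phi):
        """q(x) = product over i of q_i(x_i), with q_i = Bernoulli(phi[i]).

        x   : a binary assignment, shape (n,)
        phi : one parameter per variable, each in (0, 1), shape (n,)
        """
        return np.prod(np.where(x == 1, phi, 1.0 - phi))

Optimizing over Q then just means optimizing over the vector phi; a VAE instead uses a neural network to produce the parameters of q.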

For the optimization objective, J(q), we typically use the KL divergence, which is defined as:

    \begin{equation*} KL(q||p) = \Sigma q(x) log(\frac{q(x)}{p(x)}) \end{equation*}

The KL divergence is not symmetric (KL(q||p) \neq KL(p||q)), but it is always greater than 0 when p \neq q and exactly 0 when p = q. This second point is important.
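As a quick numerical check of those properties (a toy example of my own), the KL divergence for discrete distributions can be computed directly from the definition:

    import numpy as np

    def kl(q, p):
        """KL(q || p) for discrete distributions given as arrays that sum to 1."""
        return np.sum(q * np.log(q / p))

    q = np.array([0.5, 0.3, 0.2])
    p = np.array([0.4, 0.4, 0.2])

    print(kl(q, p))   # positive
    print(kl(p, q))   # a different positive value: KL is not symmetric
    print(kl(q, q))   # 0.0: the divergence vanishes only when the distributions match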

As a specific problem, we take p(x) to be a joint distribution over many variables, which (for purely pedagogical reasons) we write as the normalized version of \tilde{p}(x). Here \theta represents the parameters of the distribution, and Z(\theta) is the normalizing constant (the partition function).

    \begin{equation*} p(x_1,...,x_n;\theta) = \frac{\tilde{p}(x_1,...,x_n;\theta)}{Z(\theta)} \end{equation*}

From here, we note that we can’t calculate p(x) or Z, but we can still play with our optimization objective.

    \begin{equation*} J(q) = \Sigma q(x)log(\frac{q(x)}{\tilde{p}(x)}) \end{equation*}

We note that p(x) = \frac{\tilde{p}(x)}{Z(\theta)}, so (pulling the constant log(Z(\theta)) out of the sum, since \Sigma q(x) = 1) our optimization objective can be expressed as:

    \begin{equation*} J(q) = \Sigma q(x) log(\frac{q(x)}{p(x)}) - log(Z(\theta)) \end{equation*}

And rearranged as:

    \begin{equation*} log(Z(\theta)) = KL(q||p)-J(q) \end{equation*}

The KL divergence between q and p is still intractable. HOWEVER, since KL divergences are always greater than or equal to zero, J(q) (the KL divergence between q and \tilde{p}) is tractable, and -J(q) is a lower bound on log(Z(\theta)). This bound is called the Variational Lower Bound (VLBO) or, more commonly (in my reading), the Evidence Lower Bound (ELBO).
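A quick numerical sanity check of this identity (a toy discrete example of my own): for any unnormalized \tilde{p} and any q, the relation log(Z(\theta)) = KL(q||p) - J(q) holds exactly, and -J(q) never exceeds log(Z(\theta)).

    import numpy as np

    # An arbitrary unnormalized distribution p~ over four states, and an
    # arbitrary approximating distribution q.
    p_tilde = np.array([2.0, 1.0, 0.5, 0.5])
    Z = p_tilde.sum()
    p = p_tilde / Z                            # the normalized distribution
    q = np.array([0.4, 0.3, 0.2, 0.1])

    kl_q_p = np.sum(q * np.log(q / p))         # KL(q||p): needs Z, intractable in general
    J_q    = np.sum(q * np.log(q / p_tilde))   # J(q) = KL(q||p~): tractable

    print(np.log(Z))          # 1.386...
    print(kl_q_p - J_q)       # 1.386..., matching the identity
    print(-J_q <= np.log(Z))  # True: -J(q) is a lower bound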

How is this equation useful? It shows up in the mathematics for the variational autoencoder, which I describe in the next post. It also shows up in all sorts of other places that I haven’t covered.