Auto-Encoding Variational Bayes

In my previous post on variational inference, I gave a quick foundation of the theory behind variational autoencoders (VAEs). The dominant lay understanding of VAEs is that they’re models which allow you to map an image to a smaller latent space and manipulate that image in meaningful ways by changing the values of the latent space. Most machine learning tutorials are centered around this interpretation. However, I feel this sells them short (and makes people say “just use a GAN instead”, which is disappointing). It is more correct, and more interesting, to say that VAEs are an implementation of variational inference which allows the use of arbitrarily complex functions (read: neural networks) as approximating distributions.

As with the previous post, a lot of this information was pulled from the Stanford Probabilistic Graphical Models Course, but it’s also worth reading the original (if opaque) Kingma & Welling Paper, watching this video (starting at 7 minutes or so), and reading this StackOverflow post.

Throughout this post, x refers to an input, z refers to latent variables, p(*) refers to an actual distribution, and q(*) refers to an approximating distribution.

In a variational autoencoder, we have the (true) relationship p(z|x)=\frac{p(x,z)}{p(x)}, and we want to estimate p(x).
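
To see why this needs any machinery at all: computing p(x) directly means marginalizing the latent variables out of the joint distribution,

    \begin{equation*} p(x) = \int p(x|z)p(z)dz \end{equation*}

and this integral is generally intractable once p(x|z) is a flexible function like a neural network.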

In order to do this, the first thing we do is re-formulate the VAE objective. We start from the same decomposition used in that post:

    \begin{equation*} log(Z(\theta)) = KL(q||p)-J(q) \end{equation*}

Which, in our variables, is:

    \begin{align*} log(p(x)) &= KL(q(z|x)||p(z|x)) - KL(q(z|x)||p(x,z)) \\ &= KL(q(z|x)||p(z|x)) + \mathbb{E}[(log(p(x,z))-log(q(z|x)))] \end{align*}

We note that a KL divergence is always greater than or equal to zero, so we can drop the intractable first KL term and keep a lower bound. We then expand p(x,z) using the probability chain rule, p(x,z)=p(x|z)p(z), together with the logarithm product rule:

    \begin{equation*} log(p(x)) \geq \mathbb{E}[log(p(x|z))+log(p(z))-log(q(z|x))] \end{equation*}

Exploit the linearity of expectation:

    \begin{equation*} log(p(x)) \geq \mathbb{E}[log(p(x|z))]+\mathbb{E}[log(p(z))-log(q(z|x))] \end{equation*}

Flip the sign of the second expectation and recognize it as a KL divergence (the expectation is taken with respect to q(z|x), so this follows directly from the definition of the KL divergence):

    \begin{equation*} log(p(x)) \geq \mathbb{E}[log(p(x|z))] - KL(q(z|x)||p(z)) \end{equation*}
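
To spell out that last step, with the expectation taken over q(z|x):

    \begin{equation*} \mathbb{E}[log(p(z))-log(q(z|x))] = -\mathbb{E}\left[log\left(\frac{q(z|x)}{p(z)}\right)\right] = -KL(q(z|x)||p(z)) \end{equation*}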

The resulting bound is our objective function for the variational autoencoder. We can pause here for a minute and take a look at how this lines up with the common interpretation of a variational autoencoder. We note that there is a q(z|x), which corresponds to our encoder network; a p(z), which corresponds to our prior (a multivariate unit Gaussian with independent components); and a p(x|z), which corresponds to our decoder network. The expected value operator is also significant: we never need to evaluate the expectation exactly. In practice it is approximated with a single sample of z drawn from q(z|x) (or, at evaluation time, simply by running the decoder on the mean of the latent variables), and making the probability of x higher under this distribution corresponds to just making the input and output images match more closely!
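
The KL term, meanwhile, needs no sampling at all. For the usual choice where q(z|x) is a diagonal Gaussian with mean mu and standard deviation sigma produced by the encoder, and p(z) is a unit Gaussian, it has the standard closed form

    \begin{equation*} KL(q(z|x)||p(z)) = \frac{1}{2}\sum_{j}\left(\mu_j^2+\sigma_j^2-log(\sigma_j^2)-1\right) \end{equation*}

where the sum runs over the latent dimensions.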

There is one (not small) detail missing before we can perform gradient descent. Since z is sampled inside the network, we can’t differentiate through the sampling step to train the encoder and decoder networks. This is what the reparameterization trick is for. In the reparameterization trick, the encoder produces two deterministic outputs in our latent space, a mean mu and a standard deviation sigma; these are the outputs we want to learn. We then draw a third, random, variable epsilon from a unit normal distribution and form the latent sample as z = mu + sigma * epsilon. Since epsilon isn’t learned, the randomness sits outside the learned path, and we can still backpropagate through the network as desired despite having random variables within it.
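
To make all of this concrete, here is a minimal sketch of a VAE in PyTorch. This is illustrative code rather than anything from the paper: the layer sizes, the Bernoulli (binary cross-entropy) decoder, and names like vae_loss are my own assumptions, but the loss is exactly the negated bound derived above.

    # A minimal VAE sketch in PyTorch (illustrative; sizes and names are assumptions).
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class VAE(nn.Module):
        def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
            super().__init__()
            # Encoder q(z|x): outputs the mean and log-variance of a diagonal Gaussian.
            self.enc = nn.Linear(input_dim, hidden_dim)
            self.mu = nn.Linear(hidden_dim, latent_dim)
            self.logvar = nn.Linear(hidden_dim, latent_dim)
            # Decoder p(x|z).
            self.dec1 = nn.Linear(latent_dim, hidden_dim)
            self.dec2 = nn.Linear(hidden_dim, input_dim)

        def encode(self, x):
            h = F.relu(self.enc(x))
            return self.mu(h), self.logvar(h)

        def reparameterize(self, mu, logvar):
            # z = mu + sigma * epsilon, with epsilon ~ N(0, I). mu and logvar are
            # deterministic outputs of the encoder, so gradients flow through them;
            # epsilon is sampled and carries no learned parameters.
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps

        def decode(self, z):
            return torch.sigmoid(self.dec2(F.relu(self.dec1(z))))

        def forward(self, x):
            mu, logvar = self.encode(x)
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar

    def vae_loss(x, x_recon, mu, logvar):
        # -E[log(p(x|z))]: reconstruction term, estimated with the single sample z.
        # The decoder here is Bernoulli, so this is a binary cross-entropy;
        # a Gaussian decoder would give a squared error instead.
        recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
        # KL(q(z|x)||p(z)) for a diagonal Gaussian q and a unit Gaussian prior,
        # using the closed form given earlier.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon + kl

Training then amounts to minimizing vae_loss with any standard optimizer; since this loss is the negative of the bound, gradient descent pushes the lower bound on log(p(x)) upward.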
