Variational Autoencoders

Variational autoencoders (VAEs)[1] are latent variable models that have been relatively successful in recent years. The mathematical background required to fully understand all the intricacies of VAEs is quite extensive and can be intimidating for beginners. Therefore, in this chapter we mostly focus on the basics and the intuition, but we also provide additional resources in case you would like to dive deeper into the theory. Fortunately, the intuitive introduction is sufficient to implement a VAE from scratch in PyTorch, so you can skip the theory for now and return to it later.

VAEs generate new data in a two-step process. First we sample a latent variable $\mathbf{z}$ from the multivariate Gaussian distribution $\mathcal{N}(\mathbf{0}, \mathbf{I})$. Each value of the latent vector is distributed according to the standard normal distribution $\mathcal{N}(0, 1)$, and the values within the vector are independent of each other. In the second step we generate the new sample $\mathbf{X}$ from the latent vector. The latent variable determines the characteristics of the generated data: in our implementation below, for example, we generate new handwritten digits from Gaussian noise.
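The two-step process can be sketched in a few lines of PyTorch. The tiny linear `toy_decoder` below is a hypothetical, untrained stand-in for a real decoder, and the sizes (a 2-dimensional latent space and 1×28×28 outputs) are assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

latent_dim = 2    # assumed latent size, for illustration only
num_samples = 4

# Hypothetical stand-in for a trained decoder: maps z to a 1x28x28 image in [0, 1]
toy_decoder = nn.Sequential(
    nn.Linear(latent_dim, 28 * 28),
    nn.Sigmoid(),
    nn.Unflatten(1, (1, 28, 28)),
)

# Step 1: sample z ~ N(0, I); every entry is an independent N(0, 1) draw
z = torch.randn(num_samples, latent_dim)

# Step 2: turn each latent vector into a new sample X
with torch.no_grad():
    X = toy_decoder(z)

print(z.shape, X.shape)  # torch.Size([4, 2]) torch.Size([4, 1, 28, 28])
```

With a trained decoder in place of `toy_decoder`, the same two lines of sampling and decoding produce new digits.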

[Figure: generating a sample X from a latent vector z]

How can we make an autoencoder learn a model that can generate data from Gaussian noise? The autoencoder we discussed in the previous section follows a very simple approach: we map an image to a latent vector $\mathbf{z}$ and reconstruct the original image as $\hat{\mathbf{X}}$ from that latent vector. While this procedure works well for simple compression tasks, the plain autoencoder is suboptimal for generating new images. The learned latent space is not necessarily smooth or continuous, so a latent vector that deviates slightly from those the decoder saw during training can produce an image that does not look like a valid digit. Moreover, there is no built-in mechanism for sampling, yet the purpose of our task is to train a generative model from which we can sample new images.

The variational autoencoder remedies these problems by mapping the image $\mathbf{X}$ to a whole distribution of latent variables $q_\phi(\mathbf{z} \mid \mathbf{X})$. Our encoder neural network with parameters $\phi$ produces two vectors: the vector of means $\boldsymbol{\mu}$ and the vector of variances $\boldsymbol{\sigma}^2$. We use these values as the parameters of a Gaussian with diagonal covariance, $\mathcal{N}(\boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2))$, in which different latent dimensions have zero covariance, and sample the latent variable $\mathbf{z}$ from it.
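As a minimal sketch, assume the encoder has already produced a mean vector and a variance vector for each image in a batch (random placeholders below). `torch.distributions.Normal` with per-dimension scales gives exactly this diagonal Gaussian, and each row of the result is one latent vector $\mathbf{z}$. Note that `.sample()` returns a tensor that is cut off from the computation graph.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

batch_size, latent_dim = 8, 2  # assumed sizes, for illustration only

# Placeholder encoder outputs; in a real VAE these come from the encoder network
mu = torch.randn(batch_size, latent_dim, requires_grad=True)
var = (torch.rand(batch_size, latent_dim) + 0.1).requires_grad_()  # variances must be positive

# Diagonal Gaussian: each latent dimension has its own mean and standard deviation
q = Normal(loc=mu, scale=var.sqrt())

z = q.sample()  # one latent vector per image
print(z.shape)          # torch.Size([8, 2])
print(z.requires_grad)  # False: no gradient can flow back to mu or var
```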

[Figure: the encoder maps X to a mean and a variance, from which z is sampled]

If you look at the figure above, you might recognize the problem with the approach described so far: how can we backpropagate our loss through a random sampling step? We can reframe the problem using the so-called reparameterization trick. We rewrite the latent variable $\mathbf{z}$ as a function of the mean, the standard deviation and Gaussian noise $\boldsymbol{\epsilon}$.

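$$\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),$$

where $\odot$ denotes element-wise multiplication.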

This rewriting does not change the fact that the latent vector $\mathbf{z}$ follows a multivariate Gaussian with mean $\boldsymbol{\mu}$ and variance $\boldsymbol{\sigma}^2$, but now we can backpropagate through the mean and the variance and treat $\boldsymbol{\epsilon}$ as a constant that does not need to be optimized.
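The reparameterization trick can be checked numerically. In the sketch below (placeholder values for $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$, assumed for illustration) the noise is drawn first, $\mathbf{z}$ is computed as an ordinary differentiable expression, and autograd then gives $\partial z_i / \partial \mu_i = 1$ and $\partial z_i / \partial \sigma_i = \epsilon_i$. PyTorch's `Normal.rsample()` performs the same reparameterized sampling internally.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Placeholder encoder outputs (assumed values)
mu = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
sigma = torch.tensor([1.0, 0.5, 0.1], requires_grad=True)

# The noise is sampled outside the graph and treated as a constant
epsilon = torch.randn_like(mu)

# Reparameterized latent vector: a deterministic function of mu and sigma
z = mu + sigma * epsilon

# Backpropagate a simple scalar (the sum of z) to inspect the gradients
z.sum().backward()
print(mu.grad)     # all ones: dz/dmu = 1
print(sigma.grad)  # equals epsilon: dz/dsigma = epsilon

# rsample() uses the same trick, so its output stays in the graph
z_r = Normal(mu, sigma).rsample()
print(z_r.requires_grad)  # True
```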

Once we start implementing the encoder in PyTorch, we will notice a problem with this approach. The encoder neural network produces $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}^2$ as the outputs of linear layers, and a linear layer can produce both positive and negative numbers. While this is not a problem for the mean, the variance and the standard deviation must always be positive. To circumvent this problem, we let the linear layer output the log-variance $\log \boldsymbol{\sigma}^2$, which can be positive or negative. To transform the log-variance back into the standard deviation we use the identity $\boldsymbol{\sigma} = \exp\left(\tfrac{1}{2} \log \boldsymbol{\sigma}^2\right)$.
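A short sketch of this conversion, with arbitrary example values standing in for the output of the log-variance layer: whatever sign the layer produces, exponentiating half of it yields a strictly positive standard deviation, and squaring that recovers $\exp(\log \boldsymbol{\sigma}^2)$.

```python
import torch

# Arbitrary example outputs of the log-variance layer (may be negative)
log_var = torch.tensor([-4.0, -0.5, 0.0, 1.5])

# sigma = exp(log_var / 2), i.e. the square root of exp(log_var)
sigma = torch.exp(0.5 * log_var)

print(bool((sigma > 0).all()))                       # True: always positive
print(torch.allclose(sigma**2, torch.exp(log_var)))  # True
```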

The decoder of the variational autoencoder works exactly like the one introduced in the previous section: given a latent vector $\mathbf{z}$, it reconstructs the original image as closely as possible.

The complete variational autoencoder looks as follows.

[Figure: the complete variational autoencoder]

The last remaining piece of the puzzle is the loss function used to train a VAE, which consists of two parts: the reconstruction loss and a regularizer term.

The reconstruction loss we use is the mean squared error (MSE) between the pixels of the original image $\mathbf{X}$ and those of the reconstructed image $\hat{\mathbf{X}}$. This is exactly the loss we used with the regular autoencoder. Be aware that binary cross-entropy is sometimes used to measure reconstruction quality instead, but for MNIST the MSE works well.
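The reconstruction term can be computed with `torch.nn.functional.mse_loss`. The sketch below uses random tensors as placeholder images (assumed shapes: a batch of 1×28×28 images with pixel values in $[0, 1]$), confirms that the default mean reduction is the average squared pixel difference, and computes the binary cross-entropy alternative for comparison.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder batch: original images X and reconstructions X_hat, both in [0, 1]
X = torch.rand(4, 1, 28, 28)
X_hat = torch.rand(4, 1, 28, 28)

# Mean squared error over all pixels (default reduction="mean")
mse = F.mse_loss(X_hat, X)
manual_mse = ((X_hat - X) ** 2).mean()
print(torch.allclose(mse, manual_mse))  # True

# Alternative: binary cross-entropy, which requires predictions in [0, 1]
bce = F.binary_cross_entropy(X_hat, X)
print(mse.item(), bce.item())
```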

The regularizer, the Kullback-Leibler (KL) divergence between the encoder's Gaussian and the standard normal distribution, pushes the means and variances that the encoder outputs toward 0 and 1 respectively. This is achieved by minimizing the following expression:

$$D_{\mathrm{KL}} = -\frac{1}{2} \sum_{i=1}^{n} \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right),$$

where $n$ is the size of the latent vector. Try replacing each mean by 0 and each variance by 1 in this expression: every term becomes $1 + 0 - 0 - 1 = 0$, so the loss is 0. This regularizer allows us to sample from the isotropic Gaussian with mean vector $\mathbf{0}$ and standard deviation vector $\mathbf{1}$ and to generate realistic images from this Gaussian noise.
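The closed-form regularizer can be checked against PyTorch's `torch.distributions.kl_divergence`, which has an analytic rule for a pair of `Normal` distributions. The helper name `kl_to_standard_normal` and the example values are assumptions for illustration; the last lines confirm that the expression vanishes for $\boldsymbol{\mu} = \mathbf{0}$ and $\boldsymbol{\sigma}^2 = \mathbf{1}$.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over the latent dimension."""
    return -0.5 * (1 + torch.log(sigma**2) - mu**2 - sigma**2).sum(dim=-1)

# Arbitrary example encoder outputs for one image with n = 3 latent dimensions
mu = torch.tensor([0.3, -1.2, 0.0])
sigma = torch.tensor([0.8, 1.5, 1.0])

closed_form = kl_to_standard_normal(mu, sigma)
reference = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0)).sum(dim=-1)
print(torch.allclose(closed_form, reference))  # True

# With mean 0 and variance 1 in every dimension the regularizer vanishes
kl_at_prior = kl_to_standard_normal(torch.zeros(3), torch.ones(3))
print(kl_at_prior.item())
```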

Below is the implementation of the VAE. Nothing in it should come as a surprise at this point.

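import torch
from torch import nn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class VAE(nn.Module):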
    def __init__(self, latent_dim):
        super().__init__()

        # Encoder: 1x28x28 MNIST image -> 64x5x5 feature map -> 1600-dim vector
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.Flatten(),
        )

        # Two linear heads: the mean and the log-variance of the latent Gaussian
        self.mu = nn.Linear(in_features=1600, out_features=latent_dim)
        self.log_var = nn.Linear(in_features=1600, out_features=latent_dim)

        # Decoder mirrors the encoder: latent_dim -> 1600 -> 64x5x5 -> 1x28x28
        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=1600),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (64, 5, 5)),
            nn.ConvTranspose2d(
                in_channels=64, out_channels=64, kernel_size=3, stride=2
            ),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                in_channels=64,
                out_channels=32,
                kernel_size=3,
                stride=2,
                output_padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3),
            nn.Sigmoid(),
        )

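    def decode(self, z):
        # Map latent vectors to images with pixel values in [0, 1]
        return self.decoder(z)

    def forward(self, x):
        x = self.encoder(x)
        # The log-variance head is unconstrained; sigma = exp(log_var / 2) is always positive
        mu, sigma = self.mu(x), torch.exp(self.log_var(x) / 2)
        # Reparameterization trick: the noise is created on the same device as mu
        epsilon = torch.randn_like(mu)
        z = mu + sigma * epsilon
        x_hat = self.decoder(z)
        return x_hat, mu, sigma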
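The `in_features=1600` of the two linear heads follows from the sizes of the feature maps. The sketch below reproduces that arithmetic with the output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (no padding, no dilation), assuming 28×28 MNIST inputs; the helper names are hypothetical.

```python
def conv_out(n, kernel_size, stride=1):
    """Output size of nn.Conv2d with padding=0 and dilation=1 (hypothetical helper)."""
    return (n - kernel_size) // stride + 1

def conv_transpose_out(n, kernel_size, stride=1, output_padding=0):
    """Output size of nn.ConvTranspose2d with padding=0 and dilation=1 (hypothetical helper)."""
    return (n - 1) * stride + kernel_size + output_padding

# Encoder: four 3x3 convolutions with strides 1, 1, 2, 2
size = 28
encoder_sizes = []
for stride in (1, 1, 2, 2):
    size = conv_out(size, kernel_size=3, stride=stride)
    encoder_sizes.append(size)
flat_features = 64 * size * size  # 64 channels of size x size

# Decoder: four 3x3 transposed convolutions, mirroring the encoder
decoder_sizes = []
for stride, output_padding in ((2, 0), (2, 1), (1, 0), (1, 0)):
    size = conv_transpose_out(size, 3, stride, output_padding)
    decoder_sizes.append(size)

print(encoder_sizes, flat_features)  # [26, 24, 11, 5] 1600
print(decoder_sizes)                 # [11, 24, 26, 28]
```

The `output_padding=1` in the second transposed convolution is what turns 23 into 24, so the decoder lands exactly on 28×28 again.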

The train function, on the other hand, still holds a little surprise for us. When we calculate the full loss, we scale the reconstruction_loss by a factor of 0.01, because the MSE loss and the regularizer (kl_loss) are on different scales: in our implementation the MSE loss can be about 100 times larger than the regularizer. Without the 0.01 factor, the regularizer improves only slowly and the quality of images sampled from Gaussian noise deteriorates.
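One way to read the 0.01 factor: scaling the reconstruction term $R$ by 0.01 gives the same gradient direction as keeping it unscaled and weighting the KL term by 100, since $0.01\,R + \mathrm{KL} = 0.01\,(R + 100\,\mathrm{KL})$. The toy check below uses a single hypothetical parameter and made-up loss terms purely to illustrate that identity; it is not part of the training code.

```python
import torch

# Hypothetical single parameter standing in for the model weights
w = torch.tensor(2.0, requires_grad=True)

def loss_terms(w):
    """Made-up reconstruction and KL terms, for illustration only."""
    reconstruction = 100 * (w - 1.0) ** 2
    kl = (w + 0.5) ** 2
    return reconstruction, kl

# Weighting used in the train function: scale the reconstruction term by 0.01
rec, kl = loss_terms(w)
(grad_a,) = torch.autograd.grad(0.01 * rec + kl, w)

# Equivalent view: keep the reconstruction unscaled, weight the KL term by 100
rec, kl = loss_terms(w)
(grad_b,) = torch.autograd.grad(rec + 100 * kl, w)

print(torch.allclose(grad_a * 100, grad_b))  # True: same direction, 100x scale
```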

def train(num_epochs, train_dataloader, model, criterion, optimizer):
    history = {"reconstruction_loss": [], "kl_loss": [], "full_loss": []}
    model.to(DEVICE)
    for epoch in range(num_epochs):
        num_batches = 0
        history["reconstruction_loss"] = []
        history["kl_loss"] = []
        history["full_loss"] = []

        model.train()
        for features, _ in train_dataloader:
            num_batches += 1

            features = features.to(DEVICE)

            # Forward Pass
            output, mu, sigma = model(features)

            # Calculate Loss

            # RECONSTRUCTION LOSS
            reconstruction_loss = criterion(output, features)
            reconstruction_loss = reconstruction_loss.mean()

            history["reconstruction_loss"].append(reconstruction_loss.cpu().item())

            # KL LOSS
            kl_loss = -0.5 * (1 + (sigma**2).log() - mu**2 - sigma**2).sum(dim=1)
            kl_loss = kl_loss.mean()

            history["kl_loss"].append(kl_loss.cpu().item())

            # FULL LOSS
            loss = 0.01 * reconstruction_loss + kl_loss

            history["full_loss"].append(loss.cpu().item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        reconstruction_loss, kl_loss, full_loss = (
            sum(history["reconstruction_loss"]) / num_batches,
            sum(history["kl_loss"]) / num_batches,
            sum(history["full_loss"]) / num_batches,
        )

        print(
            f"Epoch: {epoch+1:>2}/{num_epochs} | Reconstruction Loss: {reconstruction_loss:.5f} | KL Loss: {kl_loss:.5f} | Full Loss: {full_loss:.5f}"
        )

To sample images, we first draw random noise from the standard normal distribution and pass it through the decoder. After training the model for 50 epochs we get the following results. The quality is not ideal, but for the most part the digits are recognizable.

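import matplotlib.pyplot as plt
import torch

# Assumes the trained model `vae`, LATENT_DIM and DEVICE from the training setup
num_images = 6
vae.eval()
with torch.inference_mode():
    # Draw z ~ N(0, I) directly on the target device and decode it into images
    z = torch.randn(num_images, LATENT_DIM, device=DEVICE)
    images = vae.decode(z)

fig = plt.figure(figsize=(15, 4))
for i, img in enumerate(images):
    fig.add_subplot(1, num_images, i + 1)
    plt.imshow(img.squeeze().cpu().numpy(), cmap="gray")
    plt.axis("off")
plt.savefig("sampled.png", bbox_inches="tight")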
Handwritten digits sampled from a VAE

If you would like to dive deeper into the mathematical derivations of the VAE, we recommend a couple of sources. This blog post by Lilian Weng gives a good overview of different autoencoders. Additionally, this YouTube video by DeepMind provides a good introduction to the theory of latent variable models.

References

  1. Diederik P. Kingma, Max Welling. Auto-Encoding Variational Bayes. (2013).