Warning

This chapter is an early work in progress.

Generative Adversarial Networks

Generative adversarial networks (commonly known as GANs) are a family of generative models that were designed to be trained through an adversarial process. The easiest way to explain what that actually means is by looking at a quote from the paper by Goodfellow et al. that originally introduced the GAN architecture [1].

Quote

There is a team of counterfeiters trying to produce fake currency and use it without detection, while the police try to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistinguishable from the genuine articles.

The counterfeiters and the police are actually two separate fully connected neural networks. The counterfeiters' neural network is called the generator $G$. This model takes a vector of random noise $\mathbf{z}$ and produces an image $\mathbf{x}$.

$$\mathbf{x} = G(\mathbf{z})$$

To generate an image we can simply sample a vector $\mathbf{z}$ from the standard normal distribution and pass it through the function $G$. The original GAN implementation is based on fully connected neural networks, therefore we will have to reshape the generator output into a 2d image tensor at a later step.

import torch
from torch import nn


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(LATENT_SIZE, HIDDEN_SIZE),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(HIDDEN_SIZE, IMG_SIZE),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.generator(x)

Our implementation is fairly simple. We use two fully connected layers with a leaky ReLU and dropout in between and a tanh as the output. It is fairly common to use tanh for the output of the generator, therefore we will need to scale the training data to lie between -1 and 1.
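As a quick sanity check we can instantiate the generator and push a latent vector through it. The concrete values of LATENT_SIZE, HIDDEN_SIZE and IMG_SIZE below are placeholder assumptions for 28x28 MNIST images, not values fixed by this chapter.

# example hyperparameters, assuming 28x28 MNIST images
LATENT_SIZE = 100
HIDDEN_SIZE = 256
IMG_SIZE = 28 * 28

generator = Generator()

# draw a latent vector from the standard normal distribution
z = torch.randn(1, LATENT_SIZE)

# the generator outputs a flat vector in [-1, 1];
# reshape it into a 2d image tensor
img = generator(z).view(28, 28)
print(img.shape)  # torch.Size([28, 28])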

The policing network, the discriminator $D$, is designed to distinguish between real and generated data.

$$D(\mathbf{x}) \approx 1 \quad \text{for real images}$$
$$D(G(\mathbf{z})) \approx 0 \quad \text{for generated images}$$

If the input is a true image $\mathbf{x}$, the discriminator is expected to output a value close to 1, otherwise it should output a value close to 0. In other words, the discriminator outputs the probability that the input is a real image.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(IMG_SIZE, HIDDEN_SIZE),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            # the output is a single raw logit; the sigmoid is applied by the loss
            nn.Linear(HIDDEN_SIZE, 1),
        )

    def forward(self, x):
        return self.discriminator(x)

Normally we would need to use a sigmoid as the output layer of the discriminator, but nn.BCEWithLogitsLoss() already takes care of that, as it combines sigmoid with binary cross-entropy.
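The training loop below assumes that the dataset, the loss criterion and the optimizers have already been created. A possible setup might look as follows; the batch size, the learning rate and the use of Adam are illustrative assumptions, not choices prescribed by this chapter.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
NUM_EPOCHS = 100

# scale the images from [0, 1] to [-1, 1] to match the tanh output
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
dataset = datasets.MNIST(root=".", train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# the sigmoid is folded into the loss, so both models output raw logits
criterion = nn.BCEWithLogitsLoss()
gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)
dis_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

Note the drop_last=True flag: the loop below constructs its target tensors with BATCH_SIZE rows, so a smaller final batch would otherwise cause a shape mismatch.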

During the training process we combine the generator and the discriminator and train both jointly. We start the generation process by drawing Gaussian noise from the standard normal distribution $\mathcal{N}(0, \mathbf{I})$. Next we generate fake images by feeding a noise vector (also called a latent vector) $\mathbf{z}$ into the generator neural network $G$. The discriminator neural network $D$ receives a batch of real data $\mathbf{x}$ and a batch of fake data $G(\mathbf{z})$ and needs to predict the probability that the data is real.

$$\mathbf{z} \sim \mathcal{N}(0, \mathbf{I})$$
$$\mathbf{x}_{\text{fake}} = G(\mathbf{z})$$
$$\hat{y}_{\text{real}} = D(\mathbf{x})$$
$$\hat{y}_{\text{fake}} = D(\mathbf{x}_{\text{fake}})$$

The intuition for why this process works goes as follows. At the beginning the discriminator cannot differentiate between true and fake images, but it faces a relatively straightforward classification task. Once classification accuracy increases, the generator needs to learn to generate better images in order to fool the discriminator, which in turn forces the discriminator to get better, and so the arms race continues until the generator creates images that look real and the discriminator outputs a probability of 50%, because the generated images cannot be distinguished from real ones. Unfortunately the reality is more complicated: GANs are notoriously hard to train. In fact there are dozens of GAN architectures that were designed to improve on the original one. We will encounter several of those architectures as we move through this chapter.

In reality, during training the generator and the discriminator play a so-called min-max game with the following value function $V(D, G)$.

$$\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}(\mathbf{x})}\big[\log D(\mathbf{x})\big] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})}\big[\log\big(1 - D(G(\mathbf{z}))\big)\big]$$

The generator tries to minimize this function, while the discriminator tries to maximize it. But why does this make sense? If the discriminator faces real data $\mathbf{x}$, it makes sense to maximize $\log D(\mathbf{x})$, because $D(\mathbf{x})$ indicates the probability of the data being real. If the discriminator faces fake data $G(\mathbf{z})$, the discriminator will try to reduce the probability $D(G(\mathbf{z}))$, thereby increasing $\log(1 - D(G(\mathbf{z})))$. The generator does the exact opposite. Its goal is to fool the discriminator and to increase the probability of the fake data being seen as real, which can be achieved by maximizing $D(G(\mathbf{z}))$ and thereby minimizing $\log(1 - D(G(\mathbf{z})))$.

Deep learning frameworks work with gradient descent, yet we are expected to maximize the value function from the point of view of the discriminator, so let's transform the above problem into a format that is compatible with PyTorch. This is relatively straightforward, because maximizing an expression and minimizing the negative of that expression lead to the same result.

$$\min_D \; -\Big(\mathbb{E}_{\mathbf{x} \sim p_{\text{data}}(\mathbf{x})}\big[\log D(\mathbf{x})\big] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})}\big[\log\big(1 - D(G(\mathbf{z}))\big)\big]\Big)$$

We essentially frame the problem as a binary cross-entropy loss. If the discriminator faces a true image, the loss will collapse to $-\log D(\mathbf{x})$. If the discriminator faces a fake image, the loss will collapse to $-\log(1 - D(G(\mathbf{z})))$.
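To see why, recall the binary cross-entropy loss for a prediction $\hat{y}$ and a label $y$:

$$\mathcal{L}(\hat{y}, y) = -\big[y \log \hat{y} + (1 - y) \log(1 - \hat{y})\big]$$

Plugging in $y = 1$ with $\hat{y} = D(\mathbf{x})$ leaves only the first term, while $y = 0$ with $\hat{y} = D(G(\mathbf{z}))$ leaves only the second.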

The generator objective is already framed as a minimization problem, yet we still face a practical problem. Especially at the beginning of training, the value $1 - D(G(\mathbf{z}))$ will be close to 1, as the discriminator will have an easy time distinguishing between real and fake images. In that regime $\log(1 - D(G(\mathbf{z})))$ saturates: the expression and its gradient are close to 0, and the generator will have a hard time training. The authors therefore suggest turning the problem into a maximization problem: maximize $\log D(G(\mathbf{z}))$. This leads to the same fixed point, but as mentioned before PyTorch needs a problem to be framed in terms of gradient descent, so we minimize $-\log D(G(\mathbf{z}))$. While we are optimizing towards the same weights and biases, the gradients are much larger at the beginning, when the discriminator has an easy time. Try to think through a couple of examples to make sure you understand why the new expression is better for the training process.
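A small worked example illustrates the difference; the concrete number is made up. Write the discriminator output as $D(G(\mathbf{z})) = \sigma(\ell)$, where $\ell$ is the logit and $\sigma$ the sigmoid. The two generator losses then produce the following gradients with respect to $\ell$:

$$\frac{\partial}{\partial \ell} \log\big(1 - \sigma(\ell)\big) = -\sigma(\ell) \qquad\qquad \frac{\partial}{\partial \ell} \big(-\log \sigma(\ell)\big) = \sigma(\ell) - 1$$

Early in training the discriminator confidently rejects the fakes, say $\sigma(\ell) = 0.01$. The saturating loss then yields a gradient of magnitude 0.01, while the non-saturating loss yields a gradient of magnitude 0.99, roughly a hundred times larger.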

for epoch in range(NUM_EPOCHS):
    dis_loss_col = []
    gen_loss_col = []
    for batch_idx, (features, _) in enumerate(dataloader):
        real_images = features.view(-1, IMG_SIZE).to(DEVICE)

        # generate fake images from a standard normally distributed latent vector
        latent_vector = torch.randn(BATCH_SIZE, LATENT_SIZE, device=DEVICE)
        fake_imgs = generator(latent_vector)

        # calculate logits for real and fake images; we detach the fake
        # images so that the generator receives no gradients from this step
        fake_logits = discriminator(fake_imgs.detach())
        real_logits = discriminator(real_images)

        # calculate discriminator loss
        dis_real_loss = criterion(real_logits, torch.ones(BATCH_SIZE, 1, device=DEVICE))
        dis_fake_loss = criterion(
            fake_logits, torch.zeros(BATCH_SIZE, 1, device=DEVICE)
        )
        dis_loss = dis_real_loss + dis_fake_loss

        # optimize the discriminator
        dis_optim.zero_grad()
        dis_loss.backward()
        dis_optim.step()

        # calculate generator loss
        gen_loss = criterion(
            discriminator(fake_imgs), torch.ones(BATCH_SIZE, 1, device=DEVICE)
        )

        # optimize the generator
        gen_optim.zero_grad()
        gen_loss.backward()
        gen_optim.step()

        # collect the losses for logging
        dis_loss_col.append(dis_loss.item())
        gen_loss_col.append(gen_loss.item())

It might not be obvious at first glance, but when we calculate the generator loss we flip the labels: we use 1's instead of 0's. That trick transforms the generator loss into the non-saturating format we derived above.
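Concretely, plugging a label of 1 into the binary cross-entropy gives

$$\mathcal{L}_G = -\big[1 \cdot \log D(G(\mathbf{z})) + (1 - 1) \cdot \log\big(1 - D(G(\mathbf{z}))\big)\big] = -\log D(G(\mathbf{z}))$$

which is exactly the non-saturating generator loss.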

Below are the results that were generated by our simple GAN after 100 epochs. There is definitely room for improvement and we will look at better GAN architectures in the following sections of this chapter.

MNIST generated by a GAN
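To reproduce such samples after training we can draw a batch of latent vectors and rescale the generator output from [-1, 1] back to [0, 1]. A minimal sketch, assuming torchvision is available:

from torchvision.utils import save_image

# disable dropout for sampling
generator.eval()

with torch.no_grad():
    z = torch.randn(64, LATENT_SIZE, device=DEVICE)
    samples = generator(z).view(-1, 1, 28, 28)

# rescale from the tanh range [-1, 1] to [0, 1] before saving
save_image((samples + 1) / 2, "gan_samples.png", nrow=8)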

References

  1. Goodfellow, Ian; Pouget-Abadie, Jean; Mirza, Mehdi; Xu, Bing; Warde-Farley, David; Ozair, Sherjil; Courville, Aaron; Bengio, Yoshua. Generative Adversarial Nets. Advances in Neural Information Processing Systems, Vol. 27 (2014).