Warning
This chapter is an early work in progress.
Generative Adversarial Networks
Generative adversarial networks (commonly known as GANs) are a family of generative models that were designed to be trained through an adversarial process. The easiest way to explain what that means is a quote from the paper by Goodfellow et al. that originally introduced the GAN architecture [1].
Quote
There is a team of counterfeiters, trying to produce fake currency and use it without detection, while the police are trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistinguishable from the genuine articles.
The counterfeiters and the police are two separate fully connected neural networks. The counterfeiters' network is called the generator $G$. This model takes a vector of random noise $\mathbf{z}$ and produces an image $G(\mathbf{z})$.
To generate an image, we sample a vector $\mathbf{z}$ from the standard normal distribution and pass it through the function $G$. The original GAN implementation is based on fully connected neural networks, so the generator outputs a flat vector that we have to reshape into a 2D image at a later step.
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # maps a latent vector to a flattened image with values in [-1, 1]
        self.generator = nn.Sequential(
            nn.Linear(LATENT_SIZE, HIDDEN_SIZE),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(HIDDEN_SIZE, IMG_SIZE),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.generator(x)
Our implementation is fairly simple. We use two fully connected layers with a leaky ReLU and dropout in between, and a tanh as the output. It is fairly common to use tanh for the output of the generator. Because tanh produces values between -1 and 1, we need to scale the training data to the same range.
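Below is a minimal sketch of sampling from the generator and scaling the data. It assumes MNIST-like 28x28 grayscale images, and the values of LATENT_SIZE, HIDDEN_SIZE and IMG_SIZE are hypothetical. The generator is rebuilt with nn.Sequential using the same layers as the Generator class above, so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration: 28x28 grayscale images
LATENT_SIZE = 100
HIDDEN_SIZE = 128
IMG_SIZE = 28 * 28

# same layer stack as the Generator class above
generator = nn.Sequential(
    nn.Linear(LATENT_SIZE, HIDDEN_SIZE),
    nn.LeakyReLU(0.1),
    nn.Dropout(p=0.5),
    nn.Linear(HIDDEN_SIZE, IMG_SIZE),
    nn.Tanh(),
)
generator.eval()  # our choice: disable dropout while sampling

# sample 16 latent vectors from the standard normal distribution
z = torch.randn(16, LATENT_SIZE)
with torch.no_grad():
    flat_images = generator(z)  # shape (16, 784), values in [-1, 1]

# reshape the flat vectors into (batch, channel, height, width) images
images = flat_images.view(-1, 1, 28, 28)

# real pixels stored in [0, 1] are mapped to [-1, 1] with 2x - 1,
# so that real images share the range of the tanh output
real_pixels = torch.rand(16, 1, 28, 28)
scaled = real_pixels * 2 - 1
```

When the data is loaded with torchvision, applying transforms.Normalize((0.5,), (0.5,)) after transforms.ToTensor() performs the same mapping, because (x - 0.5) / 0.5 = 2x - 1.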
The policing network, the discriminator $D$, is designed to distinguish between real and generated data.
If the input is a real image $\mathbf{x}$, the discriminator is expected to output a value close to 1. Otherwise it should output a value close to 0. In other words, the discriminator outputs the probability that its input is a real image.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # outputs a single logit per image; nn.BCEWithLogitsLoss applies the sigmoid
        self.discriminator = nn.Sequential(
            nn.Linear(IMG_SIZE, HIDDEN_SIZE),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(HIDDEN_SIZE, 1),
        )

    def forward(self, x):
        return self.discriminator(x)
Normally we would need a sigmoid as the output layer of the discriminator, but nn.BCEWithLogitsLoss() already takes care of that. It combines a sigmoid with binary cross-entropy in a single class, which is also more numerically stable than applying the two steps separately.
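The following sketch checks that nn.BCEWithLogitsLoss on raw logits matches a sigmoid followed by nn.BCELoss, and shows where the two differ numerically. The logits and targets are random stand-ins, not outputs of the discriminator above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 1)                     # raw discriminator outputs (no sigmoid)
targets = torch.randint(0, 2, (8, 1)).float()  # 1 = real image, 0 = fake image

# sigmoid + binary cross-entropy fused into one module
fused = nn.BCEWithLogitsLoss()(logits, targets)
# the same quantity computed in two explicit steps
two_steps = nn.BCELoss()(torch.sigmoid(logits), targets)
print(fused.item(), two_steps.item())

# an extreme logit shows the numerical difference: sigmoid(200.0) rounds to 1.0
# in float32, so BCELoss has to clamp log(0) to -100, while the fused version
# still returns the exact loss of about 200
extreme_logit = torch.tensor([[200.0]])
zero_target = torch.tensor([[0.0]])
stable = nn.BCEWithLogitsLoss()(extreme_logit, zero_target)
clamped = nn.BCELoss()(torch.sigmoid(extreme_logit), zero_target)
print(stable.item(), clamped.item())
```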
During training we combine the generator and the discriminator and train both jointly. We start the generation process by drawing Gaussian noise from the standard normal distribution, $\mathbf{z} \sim \mathcal{N}(0, 1)$. Next we generate fake images by feeding the noise vector (also called the latent vector) $\mathbf{z}$ into the generator network $G$. The discriminator network $D$ receives a batch of real data $\mathbf{x}$ and a batch of fake data $G(\mathbf{z})$. For each sample, it predicts the probability that the sample is real.
The intuition for why this process works goes as follows. At the beginning the discriminator cannot differentiate between real and fake images, but it faces a relatively straightforward classification task. Once its classification accuracy increases, the generator needs to learn to generate better images in order to fool the discriminator. This in turn forces the discriminator to get better. The arms race continues until the generator creates images that look real and the discriminator outputs a probability of 50%, because the generated images can no longer be distinguished from real images. Unfortunately, reality is more complicated: GANs are notoriously hard to train. In fact, there are dozens of GAN architectures designed to improve on the original one, and we will encounter several of them as we move through this chapter.
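The 50% claim can be made concrete with a result from the original paper. For a fixed generator, the optimal discriminator is $D^*(\mathbf{x}) = p_{data}(\mathbf{x}) / (p_{data}(\mathbf{x}) + p_g(\mathbf{x}))$, where $p_g$ is the distribution of generated samples. The sketch below evaluates this formula for hypothetical one-dimensional Gaussian densities instead of images. The densities and the helper names are assumptions made for illustration.

```python
import math
from functools import partial


def normal_pdf(x, mean, std):
    """Density of a one-dimensional normal distribution."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))


def optimal_discriminator(x, p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) for a fixed generator density p_g."""
    return p_data(x) / (p_data(x) + p_g(x))


# hypothetical densities: the real data is centered at 0
p_data = partial(normal_pdf, mean=0.0, std=1.0)
poor_generator = partial(normal_pdf, mean=3.0, std=1.0)     # far away from the data
perfect_generator = partial(normal_pdf, mean=0.0, std=1.0)  # matches the data exactly

for x in (-1.0, 0.0, 1.0, 3.0):
    print(
        f"x={x:+.1f}  "
        f"poor G: D*={optimal_discriminator(x, p_data, poor_generator):.3f}  "
        f"perfect G: D*={optimal_discriminator(x, p_data, perfect_generator):.3f}"
    )
```

With the poor generator, $D^*$ is high where only real data is likely and low where only generated data is likely. With the perfect generator, $D^*$ equals 0.5 everywhere.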
More formally, during training the generator and the discriminator play a so-called minimax game with the following value function $V$:
$$\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{data}(\mathbf{x})}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})}[\log(1 - D(G(\mathbf{z})))]$$
The generator tries to minimize this function, while the discriminator tries to maximize it. Why does this make sense? When the discriminator faces real data $\mathbf{x}$, it should maximize $D(\mathbf{x})$, because $D(\mathbf{x})$ is the probability that the data is real. When it faces fake data $G(\mathbf{z})$, the discriminator tries to reduce the probability $D(G(\mathbf{z}))$, which increases $\log(1 - D(G(\mathbf{z})))$. The generator does the exact opposite. Its goal is to fool the discriminator and to increase the probability that the fake data is seen as real. It does this by maximizing $D(G(\mathbf{z}))$, which minimizes $\log(1 - D(G(\mathbf{z})))$. The generator only influences the second term, so the first term plays no role in its updates.
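In practice, the expectations in $V(D, G)$ are estimated by averages over a batch. The sketch below computes such a Monte Carlo estimate from hypothetical discriminator outputs, not from a trained model. The original paper shows that at the global optimum of the game, where $p_g = p_{data}$ and $D$ outputs 0.5 everywhere, the value is $-\log 4$.

```python
import math


def value_function(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) from discriminator outputs.

    d_real: D(x) for a batch of real samples x ~ p_data
    d_fake: D(G(z)) for a batch of generated samples, z ~ p_z
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term


# a confident discriminator: the value is close to its upper bound of 0
print(value_function([0.99, 0.95, 0.98], [0.01, 0.03, 0.02]))
# a discriminator that outputs 0.5 everywhere: log(0.5) + log(0.5) = -log 4
print(value_function([0.5] * 3, [0.5] * 3))
```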
Deep learning frameworks work with gradient descent, yet the discriminator is expected to maximize the value function. So let's transform the problem into a format that is compatible with PyTorch. This is relatively straightforward, because maximizing an expression and minimizing its negative lead to the same result. For a single sample $n$ this gives the following loss:
$$L_n = -\left[ y_n \log D(\mathbf{x}_n) + (1 - y_n) \log\left(1 - D(G(\mathbf{z}_n))\right) \right]$$
We essentially frame the problem as a binary cross-entropy loss, with the label $y_n = 1$ for real images and $y_n = 0$ for fake images. If the discriminator faces a real image, the loss collapses to $-\log D(\mathbf{x}_n)$. If it faces a fake image, the loss collapses to $-\log(1 - D(G(\mathbf{z}_n)))$.
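A short check of how the binary cross-entropy collapses for each label. The discriminator outputs 0.8 and 0.3 are hypothetical values chosen for illustration.

```python
import math


def bce(p, y):
    """Binary cross-entropy for one prediction p = D(input) and label y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


d_real = 0.8  # hypothetical D(x_n) for a real image, label y_n = 1
d_fake = 0.3  # hypothetical D(G(z_n)) for a fake image, label y_n = 0

print(bce(d_real, 1), -math.log(d_real))      # both are -log D(x_n)
print(bce(d_fake, 0), -math.log(1 - d_fake))  # both are -log(1 - D(G(z_n)))
```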
The generator's objective is already framed as a minimization problem, yet we still face a practical issue. Especially at the beginning of training, the discriminator easily distinguishes between real and fake images. As a result, $D(G(\mathbf{z}))$ is close to 0 and $\log(1 - D(G(\mathbf{z})))$ is close to $\log 1 = 0$. In this region the sigmoid output of the discriminator saturates, the gradient that flows back to the generator is close to 0, and the generator has a hard time learning. The authors therefore suggest turning the problem into a maximization problem: maximize $\log D(G(\mathbf{z}))$ instead. This objective has the same fixed point, but as mentioned before, PyTorch needs the problem framed in terms of gradient descent, so we minimize $-\log D(G(\mathbf{z}))$. Both objectives push the generator in the same direction, but the new one provides much larger gradients early on, when the discriminator has an easy time. Try to think through a couple of examples to make sure you understand why the new expression is better for the training process.
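A minimal sketch comparing the two generator losses. It assumes the discriminator output is $D(G(\mathbf{z})) = \sigma(a)$ for a logit $a$, as with the logit-producing discriminator above. Differentiating with respect to $a$ gives $\frac{\partial}{\partial a}\log(1 - \sigma(a)) = -\sigma(a)$ and $\frac{\partial}{\partial a}\left(-\log \sigma(a)\right) = \sigma(a) - 1$. When $D(G(\mathbf{z}))$ is close to 0, the first gradient vanishes while the second stays close to $-1$.

```python
import math


def sigmoid(a):
    return 1 / (1 + math.exp(-a))


def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the generator loss from the minimax game."""
    return -sigmoid(a)


def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), the alternative generator loss."""
    return sigmoid(a) - 1


# a is the discriminator logit for a fake image; a very negative logit means
# D(G(z)) is close to 0, i.e. the discriminator easily spots the fake
for a in (-6.0, -2.0, 0.0, 2.0):
    print(
        f"D(G(z))={sigmoid(a):.4f}  "
        f"saturating={grad_saturating(a):+.4f}  "
        f"non-saturating={grad_non_saturating(a):+.4f}"
    )
```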
import torch

for epoch in range(NUM_EPOCHS):
    dis_loss_col = []
    gen_loss_col = []
    for batch_idx, (features, _) in enumerate(dataloader):
        # flatten the images, because both networks are fully connected
        real_images = features.view(-1, IMG_SIZE).to(DEVICE)
        # the last batch can be smaller than BATCH_SIZE
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1, device=DEVICE)
        fake_labels = torch.zeros(batch_size, 1, device=DEVICE)

        # generate fake images from a standard normal latent vector
        latent_vector = torch.randn(batch_size, LATENT_SIZE, device=DEVICE)
        fake_imgs = generator(latent_vector)

        # calculate logits for fake and real images; detach() stops the
        # discriminator loss from propagating gradients into the generator
        fake_logits = discriminator(fake_imgs.detach())
        real_logits = discriminator(real_images)

        # calculate the discriminator loss
        dis_real_loss = criterion(real_logits, real_labels)
        dis_fake_loss = criterion(fake_logits, fake_labels)
        dis_loss = dis_real_loss + dis_fake_loss

        # optimize the discriminator
        dis_optim.zero_grad()
        dis_loss.backward()
        dis_optim.step()

        # calculate the generator loss with flipped labels (1s instead of 0s)
        gen_loss = criterion(discriminator(fake_imgs), real_labels)

        # optimize the generator
        gen_optim.zero_grad()
        gen_loss.backward()
        gen_optim.step()

        # collect the losses for monitoring
        dis_loss_col.append(dis_loss.item())
        gen_loss_col.append(gen_loss.item())
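The loop relies on several objects defined elsewhere in the chapter. Below is a minimal sketch of that setup, in which all hyperparameter values are hypothetical and a random TensorDataset stands in for a real image dataset. The two networks are rebuilt with nn.Sequential using the same layers as the classes above, and the Adam optimizer and learning rate are assumptions, not necessarily the chapter's choices.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical hyperparameters, chosen for illustration only
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LATENT_SIZE = 100
HIDDEN_SIZE = 128
IMG_SIZE = 28 * 28
BATCH_SIZE = 32
NUM_EPOCHS = 1
LEARNING_RATE = 2e-4

# Stand-in for a real image dataset: random "images" already scaled to [-1, 1]
# plus dummy labels, so that each batch unpacks as (features, _)
random_dataset = TensorDataset(
    torch.rand(256, 1, 28, 28) * 2 - 1,
    torch.zeros(256, dtype=torch.long),
)
# drop_last=True guarantees that every batch holds exactly BATCH_SIZE images
dataloader = DataLoader(random_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# same layer stacks as the Generator and Discriminator classes above
generator = nn.Sequential(
    nn.Linear(LATENT_SIZE, HIDDEN_SIZE), nn.LeakyReLU(0.1), nn.Dropout(p=0.5),
    nn.Linear(HIDDEN_SIZE, IMG_SIZE), nn.Tanh(),
).to(DEVICE)
discriminator = nn.Sequential(
    nn.Linear(IMG_SIZE, HIDDEN_SIZE), nn.LeakyReLU(0.1), nn.Dropout(p=0.5),
    nn.Linear(HIDDEN_SIZE, 1),
).to(DEVICE)

# one optimizer per network, so each step only updates its own parameters
criterion = nn.BCEWithLogitsLoss()
gen_optim = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE)
dis_optim = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
```

Keeping two separate optimizers matters because the generator loss also produces gradients for the discriminator's parameters. Those gradients are never applied, since gen_optim only holds the generator's parameters, and dis_optim.zero_grad() clears them before the next discriminator update.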
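It might not be obvious at first glance, but when we calculate the generator loss we flip the labels: the fake images get 1s instead of 0s. With this trick, the binary cross-entropy becomes $-\log D(G(\mathbf{z}))$, which is exactly the generator loss we want to minimize. Plug $y_n = 1$ into the loss $L_n$ above, with $G(\mathbf{z}_n)$ as the discriminator's input, to see why.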
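A quick numerical check of the label flip. The logits are random stand-ins for discriminator outputs on generated images. With target 1, nn.BCEWithLogitsLoss matches the mean of $-\log D(G(\mathbf{z}))$. With target 0, it would instead give the mean of $-\log(1 - D(G(\mathbf{z})))$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fake_logits = torch.randn(8, 1)  # hypothetical discriminator logits for generated images
criterion = nn.BCEWithLogitsLoss()

# generator loss as computed in the loop: fake images labeled as real (1s)
flipped = criterion(fake_logits, torch.ones_like(fake_logits))
# the non-saturating generator loss written out: mean of -log D(G(z))
non_saturating = -torch.log(torch.sigmoid(fake_logits)).mean()

# keeping the label 0 gives the discriminator's loss on fakes, -log(1 - D(G(z))),
# which the generator would have to maximize rather than minimize
kept_labels = criterion(fake_logits, torch.zeros_like(fake_logits))
print(flipped.item(), non_saturating.item(), kept_labels.item())
```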
Below are the results generated by our simple GAN after 100 epochs. There is definitely room for improvement, and we will look at better GAN architectures in the following sections of this chapter.