Gated PixelCNN
Shortly after the initial release of the PixelCNN architecture, DeepMind released the Gated PixelCNN [1]. The paper introduced several improvements at once, which together narrowed the performance gap to the recurrent PixelRNN.
Vertical and Horizontal Stacks
The PixelCNN has a limitation that is not obvious at first glance. To explain it, let's recall how a convolutional neural network usually works. The very first layer applies convolutions to a small receptive field around a particular pixel: with a 3x3 convolution, the network can only look at the immediate neighbours of that pixel. But as we stack more and more convolutional layers on top of each other, the receptive field grows.
In the following examples we assume that all calculations are made from the perspective of the black pixel, that the kernel size is 3x3, and that the padding is always 1, which keeps the size of the image constant.
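This growth can be checked numerically. For stride-1 convolutions without masks or dilation, every additional $k \times k$ layer widens the receptive field by $k - 1$ pixels, so after $L$ layers it spans $1 + L(k - 1)$ pixels per side (3, 5, 7, ... for 3x3 kernels). The sketch below is our own illustration, not part of the original code: it compares that formula with a measurement that stacks all-ones 3x3 convolutions with padding 1 and records which input pixels receive a nonzero gradient from the centre output pixel, which plays the role of the black pixel.

```python
import torch
import torch.nn as nn


def receptive_field_formula(num_layers: int, kernel_size: int = 3) -> int:
    # every stride-1, undilated layer adds (kernel_size - 1) pixels to the side length
    return 1 + num_layers * (kernel_size - 1)


def receptive_field_measured(num_layers: int, size: int = 15) -> int:
    # 3x3 convolutions with padding 1 keep the image size constant
    layers = [nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False) for _ in range(num_layers)]
    for layer in layers:
        nn.init.constant_(layer.weight, 1.0)  # positive weights: no accidental cancellation
    x = torch.zeros(1, 1, size, size, requires_grad=True)
    out = nn.Sequential(*layers)(x)
    # the centre output pixel plays the role of the black pixel
    out[0, 0, size // 2, size // 2].backward()
    seen_rows = (x.grad[0, 0] != 0).any(dim=1).nonzero()
    return int(seen_rows.max() - seen_rows.min() + 1)


for n in range(1, 5):
    print(n, receptive_field_formula(n), receptive_field_measured(n))
```

The gradient trick works because the stack is linear in its input, so a nonzero gradient marks exactly the pixels that can influence the output.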
Now let's see how the receptive field grows once we incorporate masked convolutions.
While the receptive field grows, we are left with a blind spot: several pixels above and to the right of the black pixel are never taken into account, which is likely to hurt performance.
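The blind spot can be made concrete without any training. The pure-Python sketch below, our own illustration, describes each masked 3x3 layer of a single-stack PixelCNN by the set of (row, column) offsets its mask lets through, assuming the usual convention: type A sees the rows above and the pixels to the left, type B additionally sees the centre. Stacking layers corresponds to adding offsets, and subtracting the result from every earlier pixel within reach exposes the missing ones.

```python
def pixelcnn_mask_offsets(kernel_size, mask_type):
    """(dy, dx) offsets allowed by a single-stack PixelCNN mask; dy < 0 means 'above'."""
    c = kernel_size // 2
    offsets = set()
    for dy in range(-c, c + 1):
        for dx in range(-c, c + 1):
            above = dy < 0
            left = dy == 0 and dx < 0
            centre = dy == 0 and dx == 0
            if above or left or (centre and mask_type == "B"):
                offsets.add((dy, dx))
    return offsets


def stacked_field(masks):
    """Input offsets that the output at (0, 0) depends on after applying all masks."""
    field = {(0, 0)}
    for mask in masks:
        field = {(y + dy, x + dx) for (y, x) in field for (dy, dx) in mask}
    return field


def causal_window(depth):
    """All pixels within reach of `depth` 3x3 layers that come before (0, 0)."""
    return {
        (dy, dx)
        for dy in range(-depth, 1)
        for dx in range(-depth, depth + 1)
        if dy < 0 or dx < 0
    }


layers = [pixelcnn_mask_offsets(3, "A")] + [pixelcnn_mask_offsets(3, "B")] * 4
blind_spot = causal_window(len(layers)) - stacked_field(layers)
print(sorted(blind_spot))
```

With five layers, ten offsets are missing, all forming a triangle above and to the right of the target pixel. Making the stack deeper pushes the triangle outward but never closes it: the pixel at offset (-1, 2) stays invisible at any depth.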
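import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConvolution(nn.Module):
    """2D convolution whose kernel is multiplied by a fixed binary mask."""

    def __init__(self, in_channels, out_channels, mask, dilation=1):
        super().__init__()
        kernel_size = mask.shape
        # "same" padding so that the spatial size of the image stays constant
        padding = tuple(dilation * (size - 1) // 2 for size in kernel_size)
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )
        # a buffer moves with the module between devices but is not trained
        self.register_buffer("mask", mask)

    def forward(self, x):
        # zero out the masked weights before every forward pass
        with torch.no_grad():
            self.conv.weight *= self.mask
        return self.conv(x)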
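MaskedConvolution enforces the mask by overwriting the stored weights in place before each forward pass. That is correct, but the masked entries still receive gradients, and the optimiser's updates to them are simply undone on the next call. A common alternative, sketched here as our own variant (the class name FunctionalMaskedConv2d is hypothetical), multiplies by the mask inside forward(), so the product becomes part of the autograd graph and the masked weights get exactly zero gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalMaskedConv2d(nn.Conv2d):
    """Hypothetical alternative: apply the mask functionally instead of in place."""

    def __init__(self, in_channels, out_channels, mask, dilation=1):
        kh, kw = mask.shape
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=(kh, kw),
            padding=(dilation * (kh - 1) // 2, dilation * (kw - 1) // 2),
            dilation=dilation,
        )
        self.register_buffer("mask", mask)

    def forward(self, x):
        # the masked kernel is recomputed every call; masked entries get zero gradient
        return F.conv2d(
            x, self.weight * self.mask, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )


# PixelCNN type "A" mask: rows above plus pixels to the left, centre excluded
mask_a = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
conv = FunctionalMaskedConv2d(1, 4, mask_a)
x = torch.randn(1, 1, 5, 5)
conv(x).sum().backward()
print(torch.count_nonzero(conv.weight.grad * (1 - mask_a)))  # tensor(0)
```

Both approaches give the same outputs; the functional one avoids mutating parameters, at the cost of one extra element-wise product per forward pass.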
To deal with this problem, the researchers at DeepMind split the convolution into two distinct stacks: the vertical stack, which processes the pixels above the black pixel, and the horizontal stack, which processes the pixels to its left.
You can think of the vertical stack as a regular convolution that can only access the rows above the current pixel.
The horizontal stack is a 1D convolution along the current row that processes the pixels to the left.
The combination of both produces the desired output.
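The same offset-set view shows why two stacks remove the blind spot. The sketch below is our own simplified model of the receptive fields (3x3 kernels, no dilation): the vertical masks let through whole rows above (type A) or the rows above plus the current row (type B), the horizontal masks let through pixels to the left (type A) or left plus centre (type B), and each block feeds the new vertical features into the horizontal stack through a 1x1 convolution. After a type-A first layer and four blocks, five layers in total, the horizontal stack sees every earlier pixel within reach.

```python
def minkowski_sum(a, b):
    """All pairwise sums of offsets: the effect of stacking two layers."""
    return {(ay + by, ax + bx) for (ay, ax) in a for (by, bx) in b}


def vertical_offsets(mask_type, c=1):
    # rows above (type "A"), or rows above plus the current row (type "B")
    last_row = -1 if mask_type == "A" else 0
    return {(dy, dx) for dy in range(-c, last_row + 1) for dx in range(-c, c + 1)}


def horizontal_offsets(mask_type, c=1):
    # pixels to the left (type "A"), or left plus the current pixel (type "B")
    last_col = -1 if mask_type == "A" else 0
    return {(0, dx) for dx in range(-c, last_col + 1)}


def two_stack_field(num_blocks):
    v = vertical_offsets("A")  # input offsets the vertical features at (0, 0) depend on
    h = horizontal_offsets("A")  # input offsets the horizontal features at (0, 0) depend on
    for _ in range(num_blocks):
        v = minkowski_sum(v, vertical_offsets("B"))
        # horizontal conv of the previous h, plus the 1x1 link from the new v
        h = minkowski_sum(h, horizontal_offsets("B")) | v
    return h


def causal_window(depth):
    """All pixels within reach of `depth` 3x3 layers that come before (0, 0)."""
    return {
        (dy, dx)
        for dy in range(-depth, 1)
        for dx in range(-depth, depth + 1)
        if dy < 0 or dx < 0
    }


field = two_stack_field(num_blocks=4)
print(causal_window(5) - field)  # set(): no pixel is missing
```

Type B vertical masks include the current row, yet this never leaks information: after the type-A first layer, the vertical features at a row only depend on rows strictly above it.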
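class VerticalStackConvolution(MaskedConvolution):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, mask_type="B", dilation=1
    ):
        assert mask_type in ["A", "B"]
        mask = torch.ones(kernel_size, kernel_size)
        # hide all rows below the current pixel
        mask[kernel_size // 2 + 1 :, :] = 0
        # type "A" (first layer only) also hides the current row; type "B" keeps it,
        # which is safe because after the first layer the vertical features at
        # row i only depend on input rows above i
        if mask_type == "A":
            mask[kernel_size // 2, :] = 0
        super().__init__(in_channels, out_channels, mask, dilation=dilation)


class HorizontalStackConvolution(MaskedConvolution):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, mask_type="B", dilation=1
    ):
        assert mask_type in ["A", "B"]
        # a single-row kernel that only slides along the current row
        mask = torch.ones(1, kernel_size)
        # hide all pixels to the right of the current pixel
        mask[0, kernel_size // 2 + 1 :] = 0
        # type "A" also hides the current pixel itself
        if mask_type == "A":
            mask[0, kernel_size // 2] = 0
        super().__init__(in_channels, out_channels, mask, dilation=dilation)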
Gated Architecture
The Gated PixelCNN architecture was developed to close the performance gap between the PixelCNN and the RowLSTM. The researchers hypothesised that the multiplicative units of an LSTM help the model learn more complex patterns, and introduced similar units into the convolutional layers.
Let's start our discussion with the upper part of the graph: the vertical stack. The vertical stack receives the output of the previous vertical layer and applies an $n \times n$ masked convolution of type 'B', so that it only looks at the pixels above. The convolution takes in $p$ feature maps and produces $2p$ feature maps, because one half goes into the $\tanh$ and the other half into the sigmoid activation $\sigma$. We then multiply both results element-wise. In essence, the sigmoid output acts as a gate that decides which parts of the $\tanh$ output are allowed to flow through.
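In the paper's notation the gated activation reads $y = \tanh(W_{k,f} * x) \odot \sigma(W_{k,g} * x)$, where $*$ is the (masked) convolution of layer $k$ and $\odot$ is the element-wise product. Since both convolutions are computed as one convolution with $2p$ output channels, the gate reduces to a channel split. A minimal sketch of just this step (the function name gated_activation is ours):

```python
import torch


def gated_activation(features: torch.Tensor) -> torch.Tensor:
    """Split 2p channels into p content maps and p gate maps and combine them."""
    f, g = features.chunk(2, dim=1)  # each half has shape (B, p, H, W)
    return torch.tanh(f) * torch.sigmoid(g)  # element-wise product


conv_out = torch.randn(2, 2 * 16, 28, 28)  # p = 16 feature maps, doubled by the convolution
y = gated_activation(conv_out)
print(y.shape)  # torch.Size([2, 16, 28, 28])
```

A strongly negative gate value drives the output towards zero regardless of the $\tanh$ input, while a strongly positive one lets it through almost unchanged.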
The lower part of the graph is the horizontal stack. First, we pass the output of the vertical convolution through a 1x1 convolution and add the result to the output of the horizontal convolution. That way the model can attend to all pixels above and all pixels to the left. Second, we add a residual (skip) connection in the horizontal stack to facilitate training.
Lastly, the Gated PixelCNN paper also focused on conditional models. For example, we might want to condition the model on the label of the image we would like to produce. As we are dealing with MNIST, we can use the digits 0-9 as an additional input to the model, so that it can learn to generate specific digits on demand. This should also make it easier for the model to create coherent digits.
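In the paper, a conditioning vector $h$ enters both halves of the gate after the convolution: $y = \tanh(W_{k,f} * x + V_{k,f}^T h) \odot \sigma(W_{k,g} * x + V_{k,g}^T h)$. For a one-hot class label, $V^T h$ just picks out the row of $V$ belonging to that class, i.e. a class-dependent bias at every position, so an nn.Embedding lookup implements it. The sketch below follows that formulation; the names ConditionalGate and num_classes are our own. The block code in this section takes a slightly different route and adds a $p$-dimensional embedding to the input of each masked convolution instead.

```python
import torch
import torch.nn as nn


class ConditionalGate(nn.Module):
    """Sketch of paper-style conditioning: a class embedding is added to the
    2p pre-activation channels after the (masked) convolution."""

    def __init__(self, channels: int, num_classes: int = 10):
        super().__init__()
        # one row of V per class, covering both the tanh half and the gate half
        self.embedding = nn.Embedding(num_classes, 2 * channels)

    def forward(self, conv_out: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        # conv_out: (B, 2p, H, W); label: (B,) integer class ids
        bias = self.embedding(label)[:, :, None, None]  # (B, 2p, 1, 1), broadcast over H, W
        f, g = (conv_out + bias).chunk(2, dim=1)
        return torch.tanh(f) * torch.sigmoid(g)


gate = ConditionalGate(channels=16)
y = gate(torch.randn(2, 32, 28, 28), torch.tensor([3, 7]))
print(y.shape)  # torch.Size([2, 16, 28, 28])
```

Because the label is the same for every pixel, either form of conditioning keeps the model causal.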
class ConditionalGatedResidualBlock(nn.Module):
    def __init__(self, in_channels, kernel_size=3, dilation=1):
        super().__init__()
        # both stacks output 2 * in_channels maps: one half for tanh, one for the gate
        self.v = VerticalStackConvolution(
            in_channels,
            out_channels=2 * in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.h = HorizontalStackConvolution(
            in_channels,
            out_channels=2 * in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        # 1x1 convolution feeding the vertical features into the horizontal stack
        self.v_to_h = nn.Conv2d(2 * in_channels, 2 * in_channels, kernel_size=1)
        # 1x1 convolution applied to the horizontal output before the residual sum
        self.h_to_res = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # one learned vector per class label (the ten MNIST digits)
        self.v_embedding = nn.Embedding(num_embeddings=10, embedding_dim=in_channels)
        self.h_embedding = nn.Embedding(num_embeddings=10, embedding_dim=in_channels)

    def forward(self, v_prev, h_prev, num_cls):
        # label embeddings of shape (B, C, 1, 1), broadcast over height and width;
        # unlike the paper, they are added to the input of each convolution
        v_embedding = self.v_embedding(num_cls).unsqueeze(-1).unsqueeze(-1)
        h_embedding = self.h_embedding(num_cls).unsqueeze(-1).unsqueeze(-1)

        # vertical stack: masked convolution followed by the gated activation
        v = self.v(v_prev + v_embedding)
        v_f, v_g = v.chunk(2, dim=1)
        v_out = torch.tanh(v_f) * torch.sigmoid(v_g)

        # vertical to horizontal, using the pre-activation vertical features
        v_to_h = self.v_to_h(v)

        # horizontal stack: its own masked convolution plus the vertical contribution
        h = self.h(h_prev + h_embedding) + v_to_h
        h_f, h_g = h.chunk(2, dim=1)
        h_out = torch.tanh(h_f) * torch.sigmoid(h_g)

        # residual (skip) connection in the horizontal stack
        h_out = self.h_to_res(h_out)
        h_out += h_prev
        return v_out, h_out
class ConditionalGatedPixelCNN(nn.Module):
    def __init__(self, hidden_dim, dilations=(1, 2, 1, 4, 1, 2, 1)):
        super().__init__()
        # the first layers use type "A" masks so that no pixel can see itself
        self.v = VerticalStackConvolution(
            in_channels=1, out_channels=hidden_dim, kernel_size=7, mask_type="A"
        )
        self.h = HorizontalStackConvolution(
            in_channels=1, out_channels=hidden_dim, kernel_size=7, mask_type="A"
        )
        # gated residual blocks; dilations enlarge the receptive field cheaply
        self.gated_residual_blocks = nn.ModuleList(
            [
                ConditionalGatedResidualBlock(
                    hidden_dim, kernel_size=3, dilation=dilation
                )
                for dilation in dilations
            ]
        )
        self.conv = nn.Conv2d(
            in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1
        )
        # 256 logits per pixel, one per intensity value (256-way softmax)
        self.output = nn.Conv2d(in_channels=hidden_dim, out_channels=256, kernel_size=1)

    def forward(self, x, label):
        v = self.v(x)
        h = self.h(x)
        for gated_layer in self.gated_residual_blocks:
            v, h = gated_layer(v, h, label)
        # only the horizontal stack sees the full causal context, so predict from it
        out = self.conv(F.relu(h))
        out = self.output(F.relu(out))
        # from (Batch, Classes, Height, Width) to (Batch, Classes, Channel, Height, Width)
        out = out.unsqueeze(dim=2)
        return out
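Whether a network like this is really autoregressive can be tested with gradients: the logits for pixel (i, j) must not depend on that pixel or on any pixel after it in raster order. The helper below is our own sketch (is_causal and UnmaskedBaseline are hypothetical names); it assumes a model with the forward(x, label) signature and the (B, 256, 1, H, W) output shape used above. The deliberately unmasked baseline fails the check.

```python
import torch
import torch.nn as nn


def is_causal(model, x, label, i, j):
    """True if the logits for pixel (i, j) ignore that pixel and every later one.

    Assumes model(x, label) returns logits of shape (B, 256, 1, H, W)."""
    x = x.clone().requires_grad_(True)
    logits = model(x, label)
    logits[0, :, 0, i, j].sum().backward()
    influence = x.grad[0, 0].abs()
    future = torch.zeros_like(influence, dtype=torch.bool)
    future[i, j:] = True  # the pixel itself and everything to its right
    future[i + 1 :, :] = True  # every row below
    return bool(influence[future].sum() == 0)


class UnmaskedBaseline(nn.Module):
    """Hypothetical non-causal model: a plain 3x3 convolution sees the centre pixel."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 256, kernel_size=3, padding=1)

    def forward(self, x, label):
        return self.conv(x).unsqueeze(2)


x = torch.rand(1, 1, 8, 8)
print(is_causal(UnmaskedBaseline(), x, torch.tensor([3]), i=4, j=4))
```

For the ConditionalGatedPixelCNN defined above, a call such as `is_causal(ConditionalGatedPixelCNN(hidden_dim=32), torch.rand(1, 1, 28, 28), torch.tensor([3]), i=10, j=10)` is expected to return True, since the type-A first layers and the stack layout never connect a pixel to itself or to later pixels.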
If we train our model for 25 epochs, we get images similar to those below. The generated images are clearly of much higher quality than those we created in the previous section.
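Producing such images requires sampling one pixel at a time in raster order, feeding each sampled value back into the model before predicting the next one, so an MNIST image takes 28 × 28 = 784 forward passes. The sketch below is our own illustration, not the code used for the images above; `sample` is a hypothetical helper, and it assumes the model was trained on intensities scaled to [0, 1] (the preprocessing is not shown in this section) and returns logits of shape (B, 256, 1, H, W) as ConditionalGatedPixelCNN does.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def sample(model, labels, height=28, width=28):
    """Autoregressively sample one image per label, pixel by pixel in raster order."""
    model.eval()
    x = torch.zeros(len(labels), 1, height, width)
    for i in range(height):
        for j in range(width):
            # only the logits for pixel (i, j) are used; later pixels are still zero
            logits = model(x, labels)[:, :, 0, i, j]  # (B, 256)
            probs = F.softmax(logits, dim=-1)
            values = torch.multinomial(probs, num_samples=1).squeeze(-1)
            # assumption: the model was trained on intensities scaled to [0, 1]
            x[:, 0, i, j] = values.float() / 255.0
    return x
```

With a trained model, `sample(model, torch.arange(10))` would draw one image for each digit class.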