Gated PixelCNN

Shortly after the initial release of the PixelCNN architecture, DeepMind released the Gated PixelCNN [1]. The paper introduced several improvements at once that narrowed the gap to the recurrent PixelRNN.

Vertical and Horizontal Stacks

The PixelCNN has a limitation that is not obvious at first glance. To explain it, let's recall how a convolutional neural network usually works. The very first layer applies convolutions to a tight receptive field around a particular pixel. If we apply a 3x3 convolution, the network can only look at the immediate surroundings of that pixel. But as we stack more and more convolutional layers on top of each other, the receptive field grows.

In this example all calculations are considered from the perspective of a single pixel (the black pixel). The kernel size is 3x3 and the padding is always 1, which keeps the size of the image constant.
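The growth can be checked with a few lines of arithmetic. This is a minimal sketch that assumes stride 1 and no dilation; the helper name `receptive_field` is ours and only serves as an illustration.

```python
def receptive_field(num_layers, kernel_size=3):
    """Side length of the receptive field after stacking `num_layers`
    stride-1, undilated convolutions with the given (odd) kernel size."""
    # Every layer widens the field by (kernel_size - 1) pixels in total,
    # that is (kernel_size - 1) // 2 pixels on each side.
    return 1 + num_layers * (kernel_size - 1)


for n in range(1, 5):
    side = receptive_field(n)
    print(f"{n} layer(s): {side}x{side}")
```

With 3x3 kernels the side length grows by two pixels per layer: 3x3, 5x5, 7x7, 9x9.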

Now let's see how the receptive field grows once we incorporate masked convolutions.

While the receptive field grows, we are left with a blind spot. Many pixels above and to the right of the black pixel are never taken into account, which will most likely hurt performance.
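We can make the blind spot visible without training anything. The sketch below tracks which input pixels can influence the current pixel when we stack masked 3x3 convolutions (one type 'A' layer followed by type 'B' layers). The helper names `grow` and `masked_receptive_field` and the offset lists are our own illustration, not part of the model code.

```python
def grow(field, offsets):
    """Add one layer: shift every reachable pixel by each offset
    that the masked kernel is allowed to read."""
    return {(dy + oy, dx + ox) for (dy, dx) in field for (oy, ox) in offsets}


# (row, column) offsets, relative to the current pixel, read by a masked 3x3 kernel
MASK_A = [(-1, -1), (-1, 0), (-1, 1), (0, -1)]  # centre pixel excluded
MASK_B = MASK_A + [(0, 0)]                      # centre pixel included


def masked_receptive_field(num_layers):
    """Pixels seen after one type 'A' layer followed by type 'B' layers."""
    field = grow({(0, 0)}, MASK_A)
    for _ in range(num_layers - 1):
        field = grow(field, MASK_B)
    return field


field = masked_receptive_field(num_layers=10)
# One row up and two columns to the right is valid context, but it is never seen.
print((-1, 2) in field)                    # False
# The pixel directly above and pixels to the left are covered.
print((-1, 0) in field, (0, -3) in field)  # True True
```

No matter how many layers we stack, a pixel that lies k rows above can be at most k columns to the right; everything beyond that diagonal stays invisible.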

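import torch
from torch import nn


class MaskedConvolution(nn.Module):
    def __init__(self, in_channels, out_channels, mask, dilation=1):
        super().__init__()
        # the 2D mask defines the kernel size; the padding keeps the image size constant
        kernel_size = tuple(mask.shape)
        padding = tuple(dilation * (size - 1) // 2 for size in kernel_size)

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )
        # a buffer moves with the module (e.g. to the GPU) but is not trained
        self.register_buffer("mask", mask)

    def forward(self, x):
        # zero out the masked weights in place before every forward pass
        with torch.no_grad():
            self.conv.weight *= self.mask
        return self.conv(x)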

To deal with this problem, the researchers at DeepMind split the convolution into two distinct stacks: the vertical stack, which processes the pixels above the black pixel, and the horizontal stack, which processes the pixels to its left.

You can think of the vertical stack as a regular convolution that can only access the rows above the current pixel.

The horizontal stack is a 1D convolution that processes the pixels to the left in the current row.

Combining both stacks produces the desired receptive field without a blind spot.
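To convince ourselves, we can trace the receptive field of the two stacks in the same way as for a single masked convolution. This is a sketch under simplifying assumptions: 3x3 and 1x3 kernels in every layer, no dilation, and the vertical features feeding into the horizontal stack at every layer after the first. All names (`grow`, `two_stack_receptive_field`, the offset lists) are ours.

```python
def grow(field, offsets):
    """Add one layer: shift every reachable pixel by each readable offset."""
    return {(dy + oy, dx + ox) for (dy, dx) in field for (oy, ox) in offsets}


# 3x3 vertical stack: type 'A' reads only the row above, type 'B' also the centre row
V_A = [(-1, dx) for dx in (-1, 0, 1)]
V_B = V_A + [(0, dx) for dx in (-1, 0, 1)]
# 1x3 horizontal stack: type 'A' reads only the left neighbour, type 'B' also the centre
H_A = [(0, -1)]
H_B = [(0, -1), (0, 0)]


def two_stack_receptive_field(num_layers):
    """Pixels that influence the horizontal stack output at the current pixel."""
    v = grow({(0, 0)}, V_A)
    h = grow({(0, 0)}, H_A)
    for _ in range(num_layers - 1):
        v = grow(v, V_B)
        # the horizontal stack additionally reads the vertical features
        # of the same layer at the same position
        h = grow(h, H_B) | v
    return h


field = two_stack_receptive_field(num_layers=3)
print((-1, 2) in field)  # True: the former blind spot is covered
# no pixel from the current row to the right, or from rows below, leaks in
print(all(dy < 0 or (dy == 0 and dx < 0) for dy, dx in field))  # True
```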

class VerticalStackConvolution(MaskedConvolution):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, mask_type="B", dilation=1
    ):
        assert mask_type in ["A", "B"]
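        mask = torch.ones(kernel_size, kernel_size)
        # hide all rows below the centre row
        mask[kernel_size // 2 + 1 :, :] = 0
        # type "A" (first layer) additionally hides the centre row itself
        if mask_type == "A":
            mask[kernel_size // 2, :] = 0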

        super().__init__(in_channels, out_channels, mask, dilation=dilation)


class HorizontalStackConvolution(MaskedConvolution):
    def __init__(
        self, in_channels, out_channels, kernel_size=3, mask_type="B", dilation=1
    ):
        assert mask_type in ["A", "B"]
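        mask = torch.ones(1, kernel_size)
        # hide all pixels to the right of the centre pixel
        mask[0, kernel_size // 2 + 1 :] = 0
        # type "A" (first layer) additionally hides the centre pixel itself
        if mask_type == "A":
            mask[0, kernel_size // 2] = 0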
        super().__init__(in_channels, out_channels, mask, dilation=dilation)
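To see what these masks look like, the sketch below rebuilds them as plain nested lists for a kernel size of 3. The functions `vertical_mask` and `horizontal_mask` are hypothetical helpers that mirror the indexing of the two classes above; they are not used by the model.

```python
def vertical_mask(kernel_size=3, mask_type="B"):
    centre = kernel_size // 2
    # rows up to and including the centre row are visible, rows below are hidden
    mask = [[1 if row <= centre else 0 for _ in range(kernel_size)]
            for row in range(kernel_size)]
    if mask_type == "A":
        mask[centre] = [0] * kernel_size  # also hide the centre row
    return mask


def horizontal_mask(kernel_size=3, mask_type="B"):
    centre = kernel_size // 2
    # a single row: pixels up to and including the centre are visible
    mask = [[1 if col <= centre else 0 for col in range(kernel_size)]]
    if mask_type == "A":
        mask[0][centre] = 0  # also hide the centre pixel
    return mask


print(vertical_mask(mask_type="A"))    # [[1, 1, 1], [0, 0, 0], [0, 0, 0]]
print(vertical_mask(mask_type="B"))    # [[1, 1, 1], [1, 1, 1], [0, 0, 0]]
print(horizontal_mask(mask_type="A"))  # [[1, 0, 0]]
print(horizontal_mask(mask_type="B"))  # [[1, 1, 0]]
```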

Gated Architecture

The gated PixelCNN architecture was developed to close the performance gap between the PixelCNN and the Row LSTM. The researchers hypothesised that the multiplicative units of an LSTM help the model learn more complex patterns, and introduced similar units into the convolutional layers.

Figure: a gated PixelCNN layer, with the vertical stack in the upper part and the horizontal stack in the lower part.

Let's start our discussion with the upper part of the graph: the vertical stack. The vertical layer receives the output of the previous vertical layer and applies a masked convolution of type 'B' (3x3 by default in the code below), so that the kernel only looks at the pixels above. The convolution produces twice as many feature maps as it takes in. This is done because one half goes into the tanh activation and the other half goes into the sigmoid activation σ. We multiply both results element-wise. In essence, we can interpret the sigmoid output as a gate that decides which part of the tanh output is allowed to flow.
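The gating itself is easy to try out on single numbers. The sketch below uses plain Python floats instead of feature maps; `gated_activation` is a name we made up for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated_activation(feature, gate):
    """tanh(feature) scaled by a sigmoid gate that lies between 0 and 1."""
    return math.tanh(feature) * sigmoid(gate)


# a large positive gate value lets almost all of tanh(2.0) through
print(gated_activation(2.0, 10.0))
# a large negative gate value blocks the signal almost completely
print(gated_activation(2.0, -10.0))
```

In the model the same operation is applied element-wise to the two halves of the convolution output.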

The lower part of the graph is the horizontal stack. First, we pass the output of the vertical convolution through a 1x1 convolution and add the result to the output of the horizontal convolution. That way the model can attend to all pixels above and all pixels to the left. Second, we use a skip (residual) connection in the horizontal stack in order to facilitate training.

Lastly, the Gated PixelCNN paper focused on conditional models. For example, we might want to condition the model on the label of the image we would like to produce. As we are dealing with MNIST, we can use the digit classes 0-9 as an additional input to the model, so that it learns to generate specific digits on demand. This should also make it easier for the model to create coherent digits.
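For reference, the paper [1] writes the conditional gated activation as shown below. Here x is the input of layer k, h is the conditioning vector (for example a one-hot class label), * denotes a convolution, ⊙ an element-wise product, and the subscripts f and g mark the feature and gate halves. Note that the code in this section takes a slightly different route: it adds a learned class embedding to the input of each masked convolution instead of adding a projected h inside the two activations.

```latex
\mathbf{y} = \tanh\left(W_{k,f} * \mathbf{x} + V_{k,f}^{T}\,\mathbf{h}\right) \odot \sigma\left(W_{k,g} * \mathbf{x} + V_{k,g}^{T}\,\mathbf{h}\right)
```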

class ConditionalGatedResidualBlock(nn.Module):
    def __init__(self, in_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.v = VerticalStackConvolution(
            in_channels,
            out_channels=2 * in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.h = HorizontalStackConvolution(
            in_channels,
            out_channels=2 * in_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        # 1x1 convolutions: vertical-to-horizontal link and horizontal residual path
        self.v_to_h = nn.Conv2d(2 * in_channels, 2 * in_channels, kernel_size=1)
        self.h_to_res = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # one learned embedding per MNIST class (0-9) for each stack
        self.v_embedding = nn.Embedding(num_embeddings=10, embedding_dim=in_channels)
        self.h_embedding = nn.Embedding(num_embeddings=10, embedding_dim=in_channels)

    def forward(self, v_prev, h_prev, num_cls):
        # num_cls holds the class labels, shape (batch,); the embeddings are reshaped
        # to (batch, channels, 1, 1) so that they broadcast over height and width
        v_embedding = self.v_embedding(num_cls).unsqueeze(-1).unsqueeze(-1)
        h_embedding = self.h_embedding(num_cls).unsqueeze(-1).unsqueeze(-1)

        # vertical stack: split the output into a feature half and a gate half
        v = self.v(v_prev + v_embedding)
        v_f, v_g = v.chunk(2, dim=1)
        v_out = torch.tanh(v_f) * torch.sigmoid(v_g)

        # vertical to horizontal
        v_to_h = self.v_to_h(v)

        # horizontal stack
        h = self.h(h_prev + h_embedding) + v_to_h
        h_f, h_g = h.chunk(2, dim=1)
        h_out = torch.tanh(h_f) * torch.sigmoid(h_g)

        # skip connection in the horizontal stack
        h_out = self.h_to_res(h_out) + h_prev

        return v_out, h_out


import torch.nn.functional as F


class ConditionalGatedPixelCNN(nn.Module):
    # a tuple avoids the pitfall of a mutable default argument
    def __init__(self, hidden_dim, dilations=(1, 2, 1, 4, 1, 2, 1)):
        super().__init__()
        self.v = VerticalStackConvolution(
            in_channels=1, out_channels=hidden_dim, kernel_size=7, mask_type="A"
        )
        self.h = HorizontalStackConvolution(
            in_channels=1, kernel_size=7, out_channels=hidden_dim, mask_type="A"
        )

        self.gated_residual_blocks = nn.ModuleList(
            [
                ConditionalGatedResidualBlock(
                    hidden_dim, kernel_size=3, dilation=dilation
                )
                for dilation in dilations
            ]
        )

        self.conv = nn.Conv2d(
            in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1
        )

        # we apply a 256 way softmax
        self.output = nn.Conv2d(in_channels=hidden_dim, out_channels=256, kernel_size=1)

    def forward(self, x, label):
        v = self.v(x)
        h = self.h(x)

        for gated_layer in self.gated_residual_blocks:
            v, h = gated_layer(v, h, label)

        out = self.conv(F.relu(h))
        out = self.output(F.relu(out))
        # from Batch, Classes, Height, Width to Batch, Classes, Channel, Height, Width
        out = out.unsqueeze(dim=2)
        return out

If we train our model for 25 epochs, we get images similar to those below. The quality of the generated images is clearly better than that of the images we created in the previous section.

Generated MNIST Images

References

  1. van den Oord, Aäron and Kalchbrenner, Nal and Vinyals, Oriol and Espeholt, Lasse and Graves, Alex and Kavukcuoglu, Koray. Conditional Image Generation with PixelCNN Decoders. (2016).