Transformer
In 2017, researchers from Google introduced the so-called Transformer[1]. Since their initial release, transformers have taken the world by storm, starting with NLP and slowly but surely spreading into computer vision, reinforcement learning and other domains. Nowadays transformers dominate much of deep learning research and are an integral part of most state-of-the-art models.
Encoder and Decoder
The original paper introduced the transformer as a language translation tool. Similar to recurrent seq-to-seq models, the transformer is structured as an encoder-decoder architecture. The encoder takes the original sentence, processes each word in a series of layers and passes the results to the decoder, which in turn produces a translated version of the input sentence.
The source text and the output text are embedded by separate embedding layers before they are passed to the encoder and the decoder respectively. We depict the encoder slightly smaller because the decoder is somewhat more complex, but the components of the encoder and the decoder are actually almost identical. The Nx to the right of the encoder and to the left of the decoder indicates that both blocks are made up of several stacked layers. The original paper used 6 encoder and 6 decoder layers.
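For orientation, PyTorch ships a ready-made encoder-decoder module, `torch.nn.Transformer`. The sketch below configures it with the hyperparameters of the original paper (6 encoder and 6 decoder layers, 512-dimensional embeddings, 8 attention heads, 2048 hidden units, dropout 0.1). The vocabulary size, sentence lengths and random token ids are made-up values, and since the built-in module expects already embedded inputs, the two embedding layers are created separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# hyperparameters of the original paper; vocab_size is an arbitrary assumption
vocab_size, d_model = 1000, 512
transformer = nn.Transformer(
    d_model=d_model,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    batch_first=True,  # tensors are (batch size, sequence length, embedding dimension)
)
# separate embedding layers for the source and the target text (no positional information here)
src_embedding = nn.Embedding(vocab_size, d_model)
trg_embedding = nn.Embedding(vocab_size, d_model)

src_tokens = torch.randint(0, vocab_size, (2, 7))  # 2 source sentences with 7 tokens each
trg_tokens = torch.randint(0, vocab_size, (2, 5))  # 2 target sentences with 5 tokens each

out = transformer(src_embedding(src_tokens), trg_embedding(trg_tokens))
print(out.shape)  # torch.Size([2, 5, 512]): one vector per target position
```

The rest of this section builds these components by hand. Note that the built-in module normalizes after each sublayer by default (`norm_first=False`), and positional information would still have to be added to its inputs.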
Embeddings
When we use a recurrent net, the relative position of a word in a sentence is implicitly conveyed to the network, because the words are processed in order. A transformer, on the other hand, processes all words of a sentence at the same time and therefore has no built-in notion of their relative positions. Yet the order in which words appear does matter for the meaning of a sentence, so we need to inject additional positional information into the embeddings.
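To see why order information gets lost, here is a small sketch that assumes a simplified self-attention without learned weights (softmax of scaled dot products between the embeddings, followed by a weighted sum of the embeddings). The random vectors stand in for real word embeddings. Shuffling the input words merely shuffles the outputs in the same way, so the mechanism cannot tell the two word orders apart.

```python
import torch

torch.manual_seed(0)

def simple_self_attention(x):
    # simplified self-attention (assumption): no learned weights and no positional information,
    # x has the shape (sequence length, embedding dimension)
    weights = torch.softmax(x @ x.T / x.shape[-1] ** 0.5, dim=-1)
    return weights @ x

embeddings = torch.randn(4, 8)            # 4 made-up "words" with 8-dimensional embeddings
permutation = torch.tensor([2, 0, 3, 1])  # the same 4 words in a different order

out = simple_self_attention(embeddings)
out_shuffled = simple_self_attention(embeddings[permutation])

# every word receives exactly the same output vector, no matter where it appears in the sequence
print(torch.allclose(out[permutation], out_shuffled, atol=1e-6))  # True
```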
To inject this positional information, we use an additional embedding layer that has as many embeddings as the maximal sentence length requires. If you expect the longest sentence to consist of 100 tokens, you need 100 positional embeddings. The first token in the sentence gets the embedding that corresponds to index 0, the second token the embedding that corresponds to index 1, and so on. Both the token embedding and the positional embedding are 512-dimensional vectors, and we add them to get our final embedding.
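```python
import torch
import torch.nn as nn

# Hyperparameters shared by the snippets below: values from the original paper,
# except max_len, an assumed upper bound on the number of tokens per sentence.
embedding_dim, num_heads, fc_dim = 512, 8, 2048
num_layers, dropout, max_len = 6, 0.1, 100
head_dim = embedding_dim // num_heads  # 64 dimensions per attention head

class Embeddings(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.position_embedding = nn.Embedding(max_len, embedding_dim)
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        # x: (batch size, sequence length) tensor of token ids
        seq_len = x.shape[1]
        token_embedding = self.token_embedding(x)
        # positions 0, 1, ..., seq_len - 1 as a (1, seq_len) tensor, broadcast over the batch
        positions = torch.arange(seq_len, device=x.device).view(1, seq_len)
        position_embedding = self.position_embedding(positions)
        return token_embedding + position_embedding
```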
Attention
The type of attention that the transformer uses is called self-attention. Given a sequence of tokens, each token focuses on all parts of the sequence at the same time (including itself), but with different levels of attention, called attention weights.
We multiply each token embedding by its attention weight and add up the results, thereby creating a new embedding that is more aware of the context surrounding the word.
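As a minimal numeric sketch of this weighted sum, assume three tokens with made-up 2-dimensional embeddings and made-up attention weights of the first token towards all three tokens (in the transformer these weights are computed from the embeddings, as described below):

```python
import torch

# three made-up 2-dimensional token embeddings
embeddings = torch.tensor([[1.0, 0.0],
                           [0.0, 1.0],
                           [1.0, 1.0]])
# made-up attention weights of the first token towards all three tokens (including itself)
weights = torch.tensor([0.7, 0.2, 0.1])

# multiply every embedding by its weight and add up the results
context_aware = (weights[:, None] * embeddings).sum(dim=0)
print(context_aware)  # tensor([0.8000, 0.3000])

# the same weighted sum written as a matrix product
print(torch.allclose(context_aware, weights @ embeddings))  # True
```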
The purpose of self-attention is to produce context-aware embeddings.
The easiest way to explain what that means is to look at so-called homonyms: words that are written the same but have different meanings. Let's look, for example, at the meaning of the word "date".
What is the date of your birthday?
The date is my favourite fruit.
In the first sentence, the word "date" pays attention to itself but also to "birthday", and incorporates both the word itself and the information that relates to time into a single vector. In the second sentence, "date" pays attention to itself and to the word "fruit", incorporating the "fruitiness" aspect into its vector.
Without the self-attention mechanism we would not be able to differentiate between the two meanings, because word embeddings produce the same vector for the same word, regardless of the context that surrounds it. But attention is also useful for words other than homonyms, because it allows us to create an embedding for each word that is specific to the exact context the word appears in.
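The homonym example can be simulated with a toy sketch. The vocabulary, the random static embeddings and the simplified attention (scaled dot products between the embeddings, no learned weights) are all assumptions made for illustration. The word "date" starts with the same static vector in both sentences, but ends up with a different context-aware vector in each.

```python
import torch

torch.manual_seed(0)

# toy vocabulary with random static 8-dimensional word embeddings (made up for illustration)
vocab = ["what", "is", "the", "date", "of", "your", "birthday", "my", "favourite", "fruit"]
static_embedding = {word: torch.randn(8) for word in vocab}

def contextualize(sentence):
    # simplified self-attention: no learned weights, scaled dot products, weighted sum of embeddings
    x = torch.stack([static_embedding[word] for word in sentence])
    weights = torch.softmax(x @ x.T / x.shape[-1] ** 0.5, dim=-1)
    return weights @ x

first = "what is the date of your birthday".split()
second = "the date is my favourite fruit".split()

date_first = contextualize(first)[first.index("date")]
date_second = contextualize(second)[second.index("date")]

# the static vector of "date" is shared, the context-aware vectors differ between the sentences
print(date_first)
print(date_second)
```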
In practice, the self-attention mechanism in transformers is inspired by information retrieval systems like database queries or search engines. These systems are based on the notions of a query, a key and a value.
In a classical database, like the one above, it is clear which values you will get back from your query: a value is returned if the query matches its key. If, for example, we use the query "SELECT value WHERE key='key 1'", we get value 1 in return.
When we deal with transformers, we can think of a more "fuzzy" database, where a query does not return a single value but a weighted sum of all values in the database. For simplicity, let's assume that we have only two entries in the database with the following vector-based keys.
We use the following vector-based query.
We determine the similarity between the query and each of the keys by calculating their dot product, and we end up with the following results.
The similarity between the query and the first key is larger than the similarity with the second key, because the query and the first key are identical. The query and the second key are also somewhat related, because they have identical values in some of the vector positions.
We turn these similarity scores into attention weights by using them as inputs to the softmax function.
Finally, we use the attention weights to calculate the weighted sum of the values in the database. This weighted sum is the value that we retrieve from the database.
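The following sketch replays the fuzzy lookup with made-up vectors that have the properties described above: the query is identical to key 1 and shares some positions with key 2. The values are also made up, and a classical exact-match lookup is included for contrast.

```python
import torch

# classical lookup: a value is returned only if the query matches its key exactly
database = {"key 1": "value 1", "key 2": "value 2"}
print(database["key 1"])  # value 1

# fuzzy lookup with hypothetical keys, query and values (numbers made up for illustration)
keys = torch.tensor([[1.0, 0.0, 1.0, 0.0],   # key 1
                     [1.0, 1.0, 0.0, 0.0]])  # key 2, shares some positions with the query
query = torch.tensor([1.0, 0.0, 1.0, 0.0])   # identical to key 1
values = torch.tensor([[1.0, 0.0],           # value 1
                       [0.0, 1.0]])          # value 2

similarities = keys @ query                   # dot product with every key
print(similarities)                           # tensor([2., 1.])

weights = torch.softmax(similarities, dim=0)  # attention weights that sum to 1
print(weights)                                # tensor([0.7311, 0.2689])

result = weights @ values                     # weighted sum of all values
print(result)                                 # tensor([0.7311, 0.2689])
```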
The transformer is loosely based on this idea. To calculate the attention, the transformer takes embeddings E as input. These can be the original embeddings from the embedding layer or the outputs of a previous encoder/decoder layer. The embeddings are passed through three different linear layers (without any activation functions), producing the queries Q, keys K and values V respectively. These three are used to calculate the attention A. As the queries, keys and values are all derived from the same inputs, we are still dealing with self-attention, but the linear layers introduce learnable weights that make the attention mechanism more powerful.
In self-attention, the three matrices have identical dimensions: (batch size, sequence length, embedding dimension). This allows us to calculate the attention for all tokens and all sequences in the batch in parallel.
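Written out with these symbols, the scaled dot-product attention of the original paper is as follows, where d_k denotes the dimension of the keys and the softmax is applied to each row of the similarity matrix:

$$
A = \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V
$$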
The only variable we have not introduced yet is d_k, the dimension of the keys. If each key is a 64-dimensional vector, for example, we divide the similarities by the square root of 64, which is 8. According to the authors, this scaling is needed because for large d_k the dot products grow large in magnitude, which pushes the softmax into regions where its gradients are extremely small. The whole expression above is called scaled dot-product attention.
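The effect of the scaling can be checked numerically. The sketch assumes query and key components that are independent with mean 0 and variance 1 (the assumption the paper uses to motivate the scaling). Their dot product then has a variance of d_k, and dividing by the square root of d_k brings the standard deviation back to about 1.

```python
import torch

torch.manual_seed(0)
d_k = 64

# 10,000 random query/key pairs whose components have mean 0 and variance 1
queries = torch.randn(10_000, d_k)
keys = torch.randn(10_000, d_k)

dot_products = (queries * keys).sum(dim=-1)  # unscaled similarities
scaled = dot_products / d_k ** 0.5           # scaled as in the transformer

print(dot_products.std())  # roughly 8, the square root of 64
print(scaled.std())        # roughly 1

# softmax over 10 similarities: unscaled scores tend to concentrate the weight on a single key
row = dot_products[:10]
print(torch.softmax(row, dim=0).max(), torch.softmax(row / d_k ** 0.5, dim=0).max())
```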
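```python
def attention(query, key, value, mask=None):
    # query: (batch, query length, d_k); key and value: (batch, key length, d_k)
    d_k = key.shape[-1]
    # similarity of every query with every key, scaled by sqrt(d_k): (batch, query length, key length)
    scores = (query @ key.transpose(-2, -1)) / d_k ** 0.5
    if mask is not None:
        # positions where the mask is 0 get -inf and therefore a weight of 0 after the softmax
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn_weights = torch.softmax(scores, dim=-1)
    # weighted sum of the values
    attention = attn_weights @ value
    return attention
```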
In the code snippet above we additionally use a so-called attention mask. The mask is used when we want the transformer to ignore certain parts of the sentence, for example padding tokens. Wherever the mask is 0, we replace the score with minus infinity, which results in an attention weight of 0 after the softmax.
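A small sketch of a padding mask, assuming a single made-up sentence of four tokens whose last token is padding, and random attention scores:

```python
import torch

torch.manual_seed(0)

scores = torch.randn(1, 4, 4)          # (batch, query length, key length)
mask = torch.tensor([[[1, 1, 1, 0]]])  # (batch, 1, key length), 0 marks the padding token

masked_scores = scores.masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(weights[0, :, 3])     # tensor([0., 0., 0., 0.]): no token attends to the padding
print(weights.sum(dim=-1))  # every row still sums to 1
```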
There is still one caveat we need to discuss. Instead of calculating a single attention A, the transformer calculates a so-called multi-head attention. Each attention head calculates its own Q, K and V, but with a reduced dimensionality: instead of the full 512-dimensional embeddings, each head works with 64-dimensional vectors. Altogether the transformer uses 8 heads, whose outputs are concatenated (8 × 64 = 512 dimensions) and passed through a final linear layer.
This might be useful because each head can learn to focus on a different aspect of the context, thereby improving the performance of the transformer.
```python
class AttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        # project the 512-dimensional inputs down to head_dim (64) dimensions
        self.query = nn.Linear(embedding_dim, head_dim)
        self.key = nn.Linear(embedding_dim, head_dim)
        self.value = nn.Linear(embedding_dim, head_dim)

    def forward(self, query, key, value, mask=None):
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        return attention(query, key, value, mask)


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead() for _ in range(num_heads)])
        # num_heads * head_dim == embedding_dim, so the concatenated heads have 512 dimensions
        self.output = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        # each head returns (batch, query length, head_dim); concatenate along the last dimension
        x = [head(query, key, value, mask) for head in self.heads]
        x = torch.cat(x, dim=-1)
        x = self.dropout(self.output(x))
        return x
```
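The loop over separate AttentionHead modules is easy to read, but it is not the only way to compute the heads. A common alternative is to use one 512 x 512 projection and split its output into 8 chunks of 64 dimensions. The sketch below assumes bias-free query projections with random weights and checks that both variants produce the same per-head queries, and that undoing the split equals the concatenation of the heads.

```python
import torch

torch.manual_seed(0)
batch, seq_len, embedding_dim, num_heads = 2, 5, 512, 8
head_dim = embedding_dim // num_heads  # 64

x = torch.randn(batch, seq_len, embedding_dim)
# one big bias-free query projection with random weights (assumption for this sketch)
W_q = torch.randn(embedding_dim, embedding_dim) / embedding_dim ** 0.5

# variant 1: a separate 512 x 64 projection per head (the corresponding 64 columns of W_q)
per_head = [x @ W_q[:, h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]

# variant 2: project once, then split the last dimension into (num_heads, head_dim)
fused = (x @ W_q).view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
print(fused.shape)  # torch.Size([2, 8, 5, 64]): (batch, heads, sequence length, head_dim)

same = all(torch.allclose(fused[:, h], per_head[h], atol=1e-4) for h in range(num_heads))
print(same)  # True

# undoing the split gives the same result as concatenating the per-head outputs
merged = fused.transpose(1, 2).reshape(batch, seq_len, embedding_dim)
print(torch.allclose(merged, torch.cat(per_head, dim=-1), atol=1e-4))  # True
```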
Position-wise Feed-Forward Networks
The encoder and decoder apply a so-called position-wise feed-forward neural network. In essence, this means that the same network, with the same weights, is applied to each position of the sentence individually. Each embedded word in the sequence is passed through the network without interacting with any other word.
The position-wise network is a two-layer neural network that takes an embedding of size 512, increases the dimensionality to 2048 in the first linear layer, applies a ReLU activation function, and then uses a second linear layer to transform the embeddings back to length 512.
PyTorch handles this automatically: a linear layer treats every dimension of a tensor except the last one like a batch dimension. Only the last dimension, the embedding dimension, is processed by the neural network, while all other dimensions are regarded as additional samples, which can be processed simultaneously on the GPU.
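To check that applying the network to the whole tensor is the same as applying it to every position separately, the sketch below compares both. The 512 → 2048 → 512 sizes come from the text; the input is random, and dropout is left out so that the outputs are deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# position-wise feed-forward network with the sizes from the text (dropout omitted)
ffn = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))

x = torch.randn(2, 6, 512)  # (batch size, sequence length, embedding dimension)

# apply the network to the whole tensor at once ...
out_all = ffn(x)
# ... and to every position of every sentence separately
out_loop = torch.stack([
    torch.stack([ffn(x[b, t]) for t in range(x.shape[1])])
    for b in range(x.shape[0])
])

print(out_all.shape)                                 # torch.Size([2, 6, 512])
print(torch.allclose(out_all, out_loop, atol=1e-5))  # True
```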
```python
class PWFeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim, fc_dim),
            nn.ReLU(),
            nn.Linear(fc_dim, embedding_dim),
            nn.Dropout(p=dropout),
        )

    def forward(self, x):
        return self.layers(x)
```
Encoder Layer
The encoder layer combines two sublayers: a multi-head attention and a position-wise feed-forward neural network. Together they make up an encoder layer, which is stacked N times.
After each sublayer we use an "Add & Norm" block. The "Add" component indicates that we use skip connections to mitigate vanishing gradients and stabilize training. The "Norm" part indicates that we normalize the values before sending the results to the next layer or sublayer. In the original paper the authors used so-called layer normalization[2]. Unlike batch normalization, which calculates the mean and the standard deviation of each feature across the samples of a batch, layer norm calculates them across the different features of each individual sample.
Assuming we use a batch size of 5 and 10 features, the two approaches would differ in the following way.
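The difference can be sketched with a random tensor of 5 samples and 10 features. The manual computations are compared against PyTorch's `nn.BatchNorm1d` and `nn.LayerNorm` with their learnable scale and shift disabled; in training mode, batch norm normalizes with the biased variance of the current batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(5, 10)  # batch size 5, 10 features
eps = 1e-5

# batch norm: one mean and variance per feature, computed over the 5 samples of the batch
bn_mean = x.mean(dim=0, keepdim=True)  # shape (1, 10)
bn_var = ((x - bn_mean) ** 2).mean(dim=0, keepdim=True)
# layer norm: one mean and variance per sample, computed over its 10 features
ln_mean = x.mean(dim=1, keepdim=True)  # shape (5, 1)
ln_var = ((x - ln_mean) ** 2).mean(dim=1, keepdim=True)
print(bn_mean.shape, ln_mean.shape)  # torch.Size([1, 10]) torch.Size([5, 1])

manual_bn = (x - bn_mean) / torch.sqrt(bn_var + eps)
manual_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)

# PyTorch's modules without learnable scale and shift (batch norm in training mode)
batch_norm = nn.BatchNorm1d(10, affine=False, eps=eps)
layer_norm = nn.LayerNorm(10, elementwise_affine=False, eps=eps)
print(torch.allclose(manual_bn, batch_norm(x), atol=1e-5))  # True
print(torch.allclose(manual_ln, layer_norm(x), atol=1e-5))  # True
```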
In practice, you will notice that many modern implementations deviate from the original by normalizing the values first, before they are used as inputs to the sublayers. This pre-norm arrangement has been found to work better empirically, and we do the same in the code snippets below.
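The two orderings can be written as two hypothetical helper functions. A sublayer that outputs zeros isolates the skip connection and shows one commonly cited explanation for the better behavior of pre-norm: its skip connection passes the input through completely unchanged, whereas post-norm normalizes the sum after every sublayer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
norm = nn.LayerNorm(512)

def post_norm(x, sublayer):
    # original paper ("Add & Norm"): add the skip connection first, then normalize
    return norm(x + sublayer(x))

def pre_norm(x, sublayer):
    # variant used in the code below: normalize the sublayer input, keep the raw skip connection
    return x + sublayer(norm(x))

def zero_sublayer(t):
    # hypothetical sublayer that contributes nothing, to isolate the skip connection
    return torch.zeros_like(t)

x = torch.randn(2, 5, 512)
print(torch.equal(pre_norm(x, zero_sublayer), x))            # True: x passes through unchanged
print(torch.allclose(post_norm(x, zero_sublayer), norm(x)))  # True: x is normalized anyway
```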
```python
class EncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.self_attention = MultiHeadAttention()
        self.feed_forward = PWFeedForward()

    def forward(self, src, mask=None):
        normalized = self.norm1(src)
        src = src + self.self_attention(normalized, normalized, normalized, mask)
        normalized = self.norm2(src)
        src = src + self.feed_forward(normalized)
        return src


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(num_layers)])

    def forward(self, src, mask=None):
        for encoder in self.layers:
            src = encoder(src, mask)
        return src
```
Decoder Layer
The decoder layer also stacks multi-head attention and position-wise feed-forward networks, but the implementation details are slightly different.
The embeddings of the target text are masked. This means that in the first multi-head attention, each position is only allowed to pay attention to words that have already been generated. Otherwise the transformer could cheat by looking at the words it is expected to produce.
When the transformer is about to produce the first word, it is only allowed to see the start-of-sequence token. When it is about to produce the word "is", it is additionally allowed to see the word "what". The transformer can pay attention to the words that came before, but never to future words. In practice we accomplish this with a mask that contains zeros at the future positions.
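A causal mask in the convention of the attention function above (1 = visible, 0 = hidden) can be built as a lower-triangular matrix. The sketch assumes a target sequence of four tokens and random attention scores.

```python
import torch

torch.manual_seed(0)
seq_len = 4  # e.g. the start-of-sequence token plus three target words

# lower-triangular matrix: row i (the current position) may see columns 0..i, future positions are 0
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
print(causal_mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

scores = torch.randn(1, seq_len, seq_len)  # (batch, query length, key length)
weights = torch.softmax(scores.masked_fill(causal_mask == 0, float("-inf")), dim=-1)
print(weights[0, 0])  # tensor([1., 0., 0., 0.]): the first position only sees itself
```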
You have probably already noticed that the decoder has an additional attention layer. The second multi-head attention layer connects the encoder with the decoder. This time the queries, keys and values do not come from the same embeddings: the queries are based on the decoder embeddings, while the keys and values are based on the output of the last encoder layer. This attention mechanism is called cross-attention.
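Because the queries come from the decoder and the keys and values from the encoder, the two sequence lengths do not have to match. The sketch below computes cross-attention for random encoder outputs of 7 source tokens and 3 decoder positions; the learned query, key and value projections are omitted for brevity (an assumption of this sketch, in the full model they are applied as in AttentionHead).

```python
import torch

torch.manual_seed(0)
batch, src_len, trg_len, d_k = 2, 7, 3, 64

encoder_output = torch.randn(batch, src_len, d_k)  # keys and values come from the encoder
decoder_states = torch.randn(batch, trg_len, d_k)  # queries come from the decoder

scores = decoder_states @ encoder_output.transpose(-2, -1) / d_k ** 0.5
weights = torch.softmax(scores, dim=-1)
out = weights @ encoder_output

print(scores.shape)  # torch.Size([2, 3, 7]): every target position scores every source position
print(out.shape)     # torch.Size([2, 3, 64]): one context vector per target position
```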
The rest of the implementation is similar to the encoder.
```python
class DecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.norm3 = nn.LayerNorm(embedding_dim)
        self.self_attention = MultiHeadAttention()
        self.cross_attention = MultiHeadAttention()
        self.feed_forward = PWFeedForward()

    def forward(self, src, trg, src_mask, trg_mask):
        # src: output of the last encoder layer, trg: decoder embeddings
        # masked self-attention: trg_mask hides future positions
        normalized = self.norm1(trg)
        trg = trg + self.self_attention(normalized, normalized, normalized, trg_mask)
        # cross-attention: normalized decoder states as queries, encoder output as keys and values
        normalized = self.norm2(trg)
        trg = trg + self.cross_attention(normalized, src, src, src_mask)
        normalized = self.norm3(trg)
        trg = trg + self.feed_forward(normalized)
        return trg


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(num_layers)])

    def forward(self, src, trg, src_mask=None, trg_mask=None):
        for decoder in self.layers:
            trg = decoder(src, trg, src_mask, trg_mask)
        return trg
```
Further Sources
Understanding every detail of the transformer is not an easy task, and the section above is unlikely to cover this architecture completely. You should therefore study as many sources as possible. The transformer remains one of the most important architectures in deep learning, so it is essential to have a solid understanding of its basic principles.
You should definitely read the original paper by Vaswani et al. We had to omit some of the implementation details, so if you want to implement the transformer on your own, reading this paper is a must.
"The Illustrated Transformer" by Jay Alammar is a great resource if you need additional intuitive illustrations and explanations.
"The Annotated Transformer" from Harvard is a great choice if you need an in-depth PyTorch implementation.
The book "Natural Language Processing with Transformers" covers theory and applications of different transformer models in a very approachable manner.