LSTM
Often a recurrent neural network has to process very long sequences. With each recurrent step, more information is incorporated into the hidden state, while the distance to the start of the sequence keeps growing. Let's take the following language task as an example.
Computers have been a passion of mine from a very early age. I was fascinated by the idea that I could create anything out of nothing using only my skills and my imagination. So when the time came to select my major, I did not hesitate and picked -----------.
In this task we need to predict the last word(s). The key to the solution is to realize that the very first word of the very first sentence provides the necessary context. If that person has been interested in computers for such a long time, it is reasonable to assume that computer science or programming would be good choices to fill in the blank. Yet the distance between the first word and the prediction is roughly 50 words. Each time the recurrent neural network processes a piece of the text, it adjusts the hidden state, so that by the end of the passage hardly any information about its beginning remains. The network forgets the beginning of the passage long before it is done reading. Recurrent neural networks struggle with so-called long-term dependencies.
Theoretically speaking, nothing prevents a recurrent neural network from learning such dependencies. Given the perfect weights for a particular task, an RNN should have the capacity to model long-term dependencies. It is the learning part that makes recurrent networks less attractive. A recurrent neural network shares its weights across time steps, so during backpropagation through time the gradient is repeatedly multiplied by the same recurrent weights; for long sequences the gradients therefore tend to explode or vanish. We can use gradient clipping to deal with exploding gradients, but many of the techniques that we used with feed-forward neural networks to deal with vanishing gradients do not work here. Batch normalization, for example, calculates the mean and the standard deviation per feature and layer, but in a recurrent neural network the weights are shared across time steps, and we might theoretically need different statistics for different parts of the recurrent loop. Specialized techniques were developed to deal with the vanishing gradient problem of RNNs, and in this section we are going to cover a type of recurrent neural network called long short-term memory, or LSTM[1] for short.
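The effect of shared weights can be seen in a toy setting. The sketch below is our own illustration, not part of the original material: it uses a linear recurrence with a single scalar weight (no activation function), so the gradient of the final value with respect to the first one is simply the weight raised to the number of steps. The second part applies PyTorch's `torch.nn.utils.clip_grad_norm_` to a small `nn.RNN`; the sizes, the sequence length of 100, the squared-output loss and `max_norm=1.0` are arbitrary choices for demonstration.

```python
import torch
import torch.nn as nn

def gradient_through_time(weight: float, steps: int = 50) -> float:
    """Gradient d h_steps / d h_0 for the toy recurrence h_t = weight * h_{t-1}."""
    h0 = torch.tensor(1.0, requires_grad=True)
    h = h0
    for _ in range(steps):
        h = weight * h  # the same (shared) weight is applied at every step
    h.backward()
    return h0.grad.item()  # mathematically equal to weight ** steps

print(gradient_through_time(0.9))  # roughly 0.005 -> vanishing gradient
print(gradient_through_time(1.1))  # roughly 117 -> exploding gradient

# Gradient clipping on a small vanilla RNN (hypothetical sizes)
torch.manual_seed(0)
rnn = nn.RNN(input_size=2, hidden_size=4)
sequence = torch.randn(100, 1, 2)  # (sequence_length, batch_size, input_size)
output, h_n = rnn(sequence)
output.pow(2).sum().backward()

# rescale all gradients together so that their combined L2 norm is at most 1.0
norm_before = nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)
clipped_norm = torch.linalg.vector_norm(
    torch.stack([p.grad.norm() for p in rnn.parameters()])
).item()
print(norm_before.item(), clipped_norm)
```

Clipping only rescales gradients that are too large; it does nothing for gradients that shrink towards zero, which is why the vanishing case needs a different remedy.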
For the most part we do not change the overarching design of a recurrent neural network. The LSTM cell produces outputs that are used as inputs in the next iteration. Unlike the regular RNN cell, however, an LSTM cell produces two vectors: the short-term memory $\mathbf{h}_t$ (hidden value) and the long-term memory $\mathbf{c}_t$ (cell value).
An LSTM cell makes heavy use of so-called gates. Depending on its state, a gate either lets information keep flowing or stops the flow of information.
When the gate is closed, no information is allowed to flow and data cannot get past the gate.
When the gate is open, information can flow without interruption.
Essentially, an LSTM cell determines which parts of the sequence data should be processed and saved for future reference and which parts are irrelevant.
For example, let's assume that our data is a two-dimensional vector with the values 2 and 5.
The gate is a vector of the same size that contains values of either 0 or 1.
In order to determine which parts of the data are allowed to flow, we use elementwise multiplication.
The parts of the vector that are multiplied by 0 are essentially erased, while the parts that are multiplied by 1 keep flowing. Applying the gate (0, 1) to the data (2, 5), for instance, yields (0, 5). In practice, LSTM gates do not work in a completely binary fashion but contain continuous values between 0 and 1, which allows a gate to pass just a fraction of the information.
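The gating step is nothing more than an elementwise product. A minimal sketch in PyTorch, using the data vector (2, 5) from the example; the soft gate values 0.5 and 0.25 are arbitrary numbers chosen for illustration:

```python
import torch

data = torch.tensor([2.0, 5.0])
hard_gate = torch.tensor([0.0, 1.0])   # closed for the first entry, open for the second
soft_gate = torch.tensor([0.5, 0.25])  # lets through only a fraction of each entry

print(data * hard_gate)  # tensor([0., 5.])
print(data * soft_gate)  # tensor([1.0000, 1.2500])
```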
The beauty of an LSTM cell is its ability to learn and to calculate the values of the different gates automatically. That means that an LSTM cell decides on the fly which information is important for the future and should be saved and which information should be discarded. The gates are essentially fully connected neural networks that use a sigmoid activation function in order to scale their outputs to values between 0 and 1.
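As a sketch of a single learned gate, assume the gate is one fully connected layer (`nn.Linear`) followed by a sigmoid, applied to the concatenation of the previous hidden state and the current input. The sizes and the random inputs are hypothetical; PyTorch's own LSTM modules store their gate weights differently, so this only illustrates the idea:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 6, 3  # hypothetical sizes

# one gate: a fully connected layer followed by a sigmoid
gate_layer = nn.Linear(hidden_size + input_size, hidden_size)

h_prev = torch.randn(1, hidden_size)  # previous hidden state (batch of 1)
x_t = torch.randn(1, input_size)      # current piece of the sequence

gate = torch.sigmoid(gate_layer(torch.cat([h_prev, x_t], dim=1)))
print(gate)  # one value between 0 and 1 per hidden unit
```

Because the layer's weights are trained together with the rest of the network, the gate values are computed anew for every input rather than being fixed by hand.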
Now let's have a look at the inner workings of the LSTM cell. The design might look intimidating at first glance, but we will take it one step at a time.
The LSTM cell outputs the long-term memory $\mathbf{c}_t$ and the short-term memory $\mathbf{h}_t$. For that purpose the LSTM cell contains four fully connected neural networks. The red networks $f$, $i$ and $o$ use a sigmoid activation function and act as gates, while the violet network $g$ applies a tanh activation function and generates values that can be used to adjust the long-term memory. All four networks take the same input: a vector that concatenates the previous hidden state $\mathbf{h}_{t-1}$ and the current piece of the sequence $\mathbf{x}_t$.
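If each of the four networks is a single fully connected layer (an assumption for illustration; $\mathbf{W}$ and $\mathbf{b}$ are hypothetical names for the weights and biases, $\sigma$ is the sigmoid function, and $[\mathbf{h}_{t-1}, \mathbf{x}_t]$ denotes concatenation), their outputs at time step $t$ can be written as:

$$
\begin{aligned}
\mathbf{f}_t &= \sigma\left(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f\right) \\
\mathbf{i}_t &= \sigma\left(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i\right) \\
\mathbf{o}_t &= \sigma\left(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o\right) \\
\mathbf{g}_t &= \tanh\left(\mathbf{W}_g [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_g\right)
\end{aligned}
$$

The three gates share the same form and differ only in their parameters; $\mathbf{g}_t$ uses tanh, so its entries lie between -1 and 1 and can both increase and decrease the long-term memory.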
If you look at the flow of the long-term memory $\mathbf{c}$, you should notice that it flows in a straight line from one part of the sequence to the next. The general idea is to adjust that flow only when it is warranted, which allows the LSTM cell to establish long-term dependencies.
The neural network $f$ calculates the forget gate. We multiply each component of the long-term memory vector $\mathbf{c}_{t-1}$ by the corresponding component of the output of $f$. Because the gate uses the sigmoid activation function, it can theoretically reduce or even completely erase the long-term memory if the LSTM cell deems this necessary. The closer the outputs of the fully connected neural network are to 1, the more of the long-term memory is kept.
In the second step we decide whether we should add anything to the long-term memory. First we calculate the memories that can be used as potential additions to the long-term memory using the fully connected neural network $g$.
Then we use the neural network $i$, which acts as a gate for those "potential memories". This gate is called the input gate. The elementwise product of the outputs of the two networks is the actual adjustment to the long-term memory, and it is added elementwise to the values that passed through the forget gate.
Together, the forget gate, the input gate and the "potential memories" determine the long-term memory $\mathbf{c}_t$ that is passed on to the next time step.
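Writing $\mathbf{f}_t$, $\mathbf{i}_t$ and $\mathbf{g}_t$ for the outputs of the networks $f$, $i$ and $g$ at time step $t$, and $\odot$ for elementwise multiplication, the update of the long-term memory described above reads:

$$
\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \mathbf{g}_t
$$

In the extreme case $\mathbf{f}_t = \mathbf{1}$ and $\mathbf{i}_t = \mathbf{0}$, this reduces to $\mathbf{c}_t = \mathbf{c}_{t-1}$: the long-term memory passes through the cell unchanged, which is the straight-line flow mentioned earlier.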
The final neural network $o$ is used to determine which values are suitable for the short-term memory $\mathbf{h}_t$. This gate is called the output gate. For that purpose, a copy of the long-term memory $\mathbf{c}_t$ is passed through the tanh activation function, and the result is multiplied elementwise by the output gate.
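With $\mathbf{o}_t$ denoting the output of $o$, the last step is $\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$. To check the whole cell end to end, here is a sketch (our own, not from the original text) that performs one LSTM step by hand while reusing the parameters of PyTorch's `nn.LSTMCell`, then compares the result with the module itself. It relies on PyTorch stacking the four networks in the order i, f, g, o inside `weight_ih`, `weight_hh`, `bias_ih` and `bias_hh`. PyTorch also splits each network's weights into an input part and a hidden-state part, which is equivalent to a single layer applied to the concatenation of $\mathbf{h}_{t-1}$ and $\mathbf{x}_t$. The sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def lstm_cell_step(x_t, h_prev, c_prev, cell: nn.LSTMCell):
    """One LSTM step written out by hand, reusing the weights of an nn.LSTMCell."""
    # all four networks see the same inputs: x_t and h_{t-1}
    z = (x_t @ cell.weight_ih.T + cell.bias_ih
         + h_prev @ cell.weight_hh.T + cell.bias_hh)
    z_i, z_f, z_g, z_o = z.chunk(4, dim=1)  # PyTorch order: i, f, g, o
    i = torch.sigmoid(z_i)  # input gate
    f = torch.sigmoid(z_f)  # forget gate
    g = torch.tanh(z_g)     # "potential memories"
    o = torch.sigmoid(z_o)  # output gate
    c_t = f * c_prev + i * g     # new long-term memory
    h_t = o * torch.tanh(c_t)    # new short-term memory
    return h_t, c_t

batch_size, input_size, hidden_size = 4, 6, 3
cell = nn.LSTMCell(input_size, hidden_size)
x_t = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)
c_prev = torch.randn(batch_size, hidden_size)

with torch.no_grad():
    h_manual, c_manual = lstm_cell_step(x_t, h_prev, c_prev, cell)
    h_torch, c_torch = cell(x_t, (h_prev, c_prev))

print(torch.allclose(h_manual, h_torch, atol=1e-5),
      torch.allclose(c_manual, c_torch, atol=1e-5))
```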
If we want to use an LSTM instead of a plain vanilla recurrent neural network, we use the `nn.LSTM` module instead of `nn.RNN`.
```python
import torch
import torch.nn as nn

batch_size = 4
sequence_length = 5
input_size = 6
hidden_size = 3
num_layers = 2

lstm = nn.LSTM(input_size=input_size,
               hidden_size=hidden_size,
               num_layers=num_layers)
```
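We also need to account for the long-term memory: the LSTM expects the tuple `(h_0, c_0)` as its initial state and returns the final states as the tuple `(h_n, c_n)`. Otherwise the implementation is almost identical.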
```python
# create inputs to the LSTM: (sequence_length, batch_size, input_size)
sequence = torch.randn(sequence_length, batch_size, input_size)
# initial short-term and long-term memory: (num_layers, batch_size, hidden_size)
h_0 = torch.zeros(num_layers, batch_size, hidden_size)
c_0 = torch.zeros(num_layers, batch_size, hidden_size)
with torch.inference_mode():
    output, (h_n, c_n) = lstm(sequence, (h_0, c_0))
print(output.shape, h_n.shape, c_n.shape)
```

```text
torch.Size([5, 4, 3]) torch.Size([2, 4, 3]) torch.Size([2, 4, 3])
```
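The shapes hint at how the three outputs relate. For a unidirectional LSTM, `output` holds the short-term memory of the last layer at every time step, `h_n` holds the final short-term memory of every layer, and the long-term memory appears only in `c_n`. A short sketch (same sizes as above, arbitrary seed) confirms that the last time step of `output` matches the last layer's entry of `h_n`. It also relies on the documented behavior that `nn.LSTM` starts from zero states when `(h_0, c_0)` is omitted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sequence_length, batch_size, input_size, hidden_size, num_layers = 5, 4, 6, 3, 2

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
sequence = torch.randn(sequence_length, batch_size, input_size)

with torch.no_grad():
    # omitting (h_0, c_0) makes PyTorch start from all-zero states
    output, (h_n, c_n) = lstm(sequence)

# final time step of the last layer's short-term memory
print(torch.allclose(output[-1], h_n[-1]))
```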