Bidirectional Recurrent Neural Networks
A recurrent neural network processes one part of the sequence at a time. When we are dealing with a sentence, for example, the neural network starts with the very first word and moves forward through the sentence. A bidirectional recurrent neural network traverses the sequence in two directions: as usual from start to finish, and additionally in reverse, from finish to start. The output of the network at time step $t$, $\mathbf{y}_t$, simply concatenates the two hidden state vectors that come from the two directions, $\mathbf{y}_t = [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]$, where $\overrightarrow{\mathbf{h}}_t$ is the hidden state of the forward pass and $\overleftarrow{\mathbf{h}}_t$ is the hidden state of the backward pass.
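To make the two passes concrete, here is a minimal from-scratch sketch of a single-layer bidirectional RNN for one unbatched sequence. It assumes a plain tanh RNN cell, and the helper names (`make_params`, `rnn_pass`, `birnn`) are hypothetical, chosen for illustration only. Each direction has its own weights, and the backward states are flipped back so that row `t` of both halves refers to the same input $\mathbf{x}_t$.

```python
import torch

torch.manual_seed(0)

seq_len, input_size, hidden_size = 5, 6, 3

def make_params():
    """Hypothetical helper: random weights for one direction (W_ih, W_hh, b)."""
    return (torch.randn(hidden_size, input_size),
            torch.randn(hidden_size, hidden_size),
            torch.randn(hidden_size))

def rnn_pass(xs, W_ih, W_hh, b):
    """Run a tanh RNN cell over xs (seq_len, input_size), starting from a zero state."""
    h = torch.zeros(hidden_size)
    states = []
    for x_t in xs:  # iterating a 2D tensor yields one time step (row) at a time
        h = torch.tanh(W_ih @ x_t + W_hh @ h + b)
        states.append(h)
    return torch.stack(states)  # (seq_len, hidden_size)

def birnn(xs, fwd_params, bwd_params):
    # Forward direction: from start to finish
    h_fwd = rnn_pass(xs, *fwd_params)
    # Backward direction: from finish to start, then flipped so row t belongs to x_t
    h_bwd = torch.flip(rnn_pass(torch.flip(xs, dims=[0]), *bwd_params), dims=[0])
    # y_t = [forward h_t ; backward h_t]
    return torch.cat([h_fwd, h_bwd], dim=-1)

fwd, bwd = make_params(), make_params()
xs = torch.randn(seq_len, input_size)
y = birnn(xs, fwd, bwd)
print(y.shape)  # torch.Size([5, 6])
```

PyTorch's `nn.RNN` keeps two bias vectors per direction (`bias_ih` and `bias_hh`); the single `b` in this sketch plays the role of their sum.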
A bidirectional recurrent neural network is especially well suited for language tasks. Look at the two sentences below.
The bank opens ...
The bank of the river ...
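While the sentences start out with the same two words, the meaning of "bank" can only be understood by reading through the whole sentence. A network that only reads forward has seen exactly the same words when it reaches "bank" in both sentences; the backward direction is what brings in the words that follow.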
A bidirectional RNN is not suited for every task, though. If you intend to predict future points of time series data and you use a bidirectional RNN, you will introduce data leakage. Data leakage means that during training your network has access to information that is not available during inference. Using a bidirectional RNN would imply that you use future time series information to train your neural network, like training an RNN to predict a stock price that the network has already observed.
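The leakage can be made visible with a small check. This is a sketch under stated assumptions: a randomly initialised, untrained `nn.RNN` and a synthetic one-dimensional series. We change only the last value of the series and compare the outputs at the first time step. A unidirectional RNN's output at step 0 cannot depend on later inputs, while the backward direction of a bidirectional RNN carries the change all the way back to the start.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch_size, input_size, hidden_size = 5, 1, 1, 3
series = torch.randn(seq_len, batch_size, input_size)  # synthetic series, not real data

# Same series, but the last ("future") value is changed
perturbed = series.clone()
perturbed[-1] += 10.0

t = 0  # a time step that lies before the changed value
changed = {}
for bidirectional in (False, True):
    rnn = nn.RNN(input_size, hidden_size, bidirectional=bidirectional)
    with torch.no_grad():
        out_a, _ = rnn(series)
        out_b, _ = rnn(perturbed)
    # Does the output at step t react to a value that only appears later?
    changed[bidirectional] = not torch.allclose(out_a[t], out_b[t])
    print(f"bidirectional={bidirectional}: output at t={t} changed -> {changed[bidirectional]}")
```

For the unidirectional model the output at step 0 is computed from the first input alone, so it stays identical; for the bidirectional model it shifts, which is exactly the future information a forecasting model must not see.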
We can implement a bidirectional RNN in PyTorch by simply setting the `bidirectional` flag to `True`.
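```python
import torch
import torch.nn as nn

batch_size = 4
sequence_length = 5
input_size = 6
hidden_size = 3
num_layers = 2

rnn = nn.RNN(input_size=input_size,
             hidden_size=hidden_size,
             num_layers=num_layers,
             bidirectional=True)

# Input layout is (sequence_length, batch_size, input_size) since batch_first=False by default
sequence = torch.randn(sequence_length, batch_size, input_size)

# 2*num_layers initial hidden states: one per layer and per direction
h_0 = torch.zeros(2 * num_layers, batch_size, hidden_size)
```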
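Due to the bidirectional nature of the recurrent neural network, the dimensions of the outputs and the hidden states increase: every time step of `output` holds the forward and backward hidden states side by side, so its last dimension is 2 * hidden_size = 6, and `h_n` stores one final hidden state per layer and direction, so its first dimension is 2 * num_layers = 4.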
```python
with torch.inference_mode():
    output, h_n = rnn(sequence, h_0)
print(output.shape, h_n.shape)
```

```
torch.Size([5, 4, 6]) torch.Size([4, 4, 3])
```
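The two directions can also be taken apart again. The sketch below re-creates a model with the same sizes (relying on the default zero initial state) and follows the layout described in the PyTorch `nn.RNN` documentation: the last dimension of `output` holds the forward states followed by the backward states, and `h_n` is ordered by layer, then by direction. Because the forward pass ends at the last time step and the backward pass ends at the first, the top layer's final states should match those slices of `output`. The variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, sequence_length = 4, 5
input_size, hidden_size, num_layers = 6, 3, 2

rnn = nn.RNN(input_size=input_size,
             hidden_size=hidden_size,
             num_layers=num_layers,
             bidirectional=True)
sequence = torch.randn(sequence_length, batch_size, input_size)

with torch.no_grad():
    output, h_n = rnn(sequence)  # h_0 defaults to zeros

# Last dimension of output: [forward states | backward states]
forward_out = output[..., :hidden_size]
backward_out = output[..., hidden_size:]

# h_n is ordered (layer, direction); split the two axes apart
h_n_split = h_n.reshape(num_layers, 2, batch_size, hidden_size)
last_layer_fwd = h_n_split[-1, 0]
last_layer_bwd = h_n_split[-1, 1]

# The forward pass ends at the last time step, the backward pass at the first
fwd_matches = torch.allclose(forward_out[-1], last_layer_fwd)
bwd_matches = torch.allclose(backward_out[0], last_layer_bwd)
print(fwd_matches, bwd_matches)
```

This is also why, for sequence classification, a common choice is to combine `h_n` of the top layer's two directions rather than to take `output[-1]` alone: the backward half of `output[-1]` has only seen the final element.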