Gradient Clipping
The exploding gradients problem arises when gradients keep growing during training, in the extreme case until they exceed the largest value the tensor datatype can represent.
A simple remedy is to determine a maximum allowed gradient value: if a gradient moves beyond that threshold, we cut it back to the allowed value. This technique is called gradient clipping, or value clipping to be exact.
The table below demonstrates how value clipping works in theory. We set the threshold to 1, and if the absolute value of a gradient exceeds the threshold, we clip it to the threshold while keeping its sign.
Original Gradient | Clipped Gradient |
---|---|
1 | 1 |
0.5 | 0.5 |
2 | 1 |
-3 | -1 |
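The table can be reproduced with a few lines of plain Python. This is only a minimal sketch: the helper name `clip_value` is made up for illustration and is not a library function (PyTorch offers `nn.utils.clip_grad_value_` for the same purpose on real parameters).

```python
def clip_value(grad, threshold=1.0):
    """Clamp a single gradient to the range [-threshold, threshold]."""
    return max(-threshold, min(threshold, grad))


# The gradients from the table above
original = [1.0, 0.5, 2.0, -3.0]
clipped = [clip_value(g) for g in original]

print(clipped)  # [1.0, 0.5, 1.0, -1.0]
```

Each gradient is clipped independently of the others.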
Value clipping is problematic, because it changes the direction of gradient descent. Below is a simulation to demonstrate the problem. When you start the simulation, the gradient vector (dashed line) will start to move randomly in the 2D coordinate system. If the absolute value of either of the two gradients is larger than 1, we clip that gradient to the threshold and thus create a new gradient vector (red line). So if one gradient is 3 and the other is 1.5, we clip both to 1, thereby disregarding the relative magnitude of the vector components and changing the direction of the vector. This is not what we actually desire.
A better solution is to use norm clipping. When we clip the norm of the gradient vector, we clip all the gradients proportionally, such that the direction remains the same.
Below is a simulation of norm clipping. When the magnitude of the gradient vector is reduced to the threshold value, the direction remains unchanged.
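As a static counterpart to the comparison above, the following sketch applies both clipping variants to the gradient vector (3, 1.5) and measures the angle of the result. The function names are hypothetical helpers written for this illustration, and the code assumes a 2D gradient and the Euclidean (L2) norm.

```python
import math


def clip_by_value(grad, threshold=1.0):
    """Clip every component independently to [-threshold, threshold]."""
    return [max(-threshold, min(threshold, g)) for g in grad]


def clip_by_norm(grad, max_norm=1.0):
    """Rescale the whole vector so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm
    return [g * scale for g in grad]


def angle_deg(grad):
    """Direction of a 2D vector in degrees."""
    return math.degrees(math.atan2(grad[1], grad[0]))


grad = [3.0, 1.5]
by_value = clip_by_value(grad)
by_norm = clip_by_norm(grad)

print(angle_deg(grad))      # direction of the original gradient
print(angle_deg(by_value))  # value clipping yields [1.0, 1.0], a different direction
print(angle_deg(by_norm))   # norm clipping keeps the original direction
```

Value clipping turns the vector towards the diagonal, while norm clipping only shortens it.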
Norm clipping often feels like a hack, but it is actually quite practical. You might not be able to solve all your problems with gradient clipping, but it should be part of your toolbox.
The implementation of gradient clipping in PyTorch is astonishingly simple.
All we have to do is to add the following line of code after we call loss.backward() but before we call optimizer.step().
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
The line above treats all parameter gradients as a single concatenated vector, calculates the norm of that vector and finally rescales the gradients in-place if the norm is above 1.
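To show where the call sits in a training step, here is a minimal sketch. The tiny linear model, the SGD optimizer, the random inputs and the deliberately large targets are assumptions chosen only to provoke large gradients; they are not part of the original example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy setup, chosen only for illustration
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

features = torch.randn(8, 4)
targets = torch.full((8, 1), 1000.0)  # large targets lead to large gradients

optimizer.zero_grad()
loss = criterion(model(features), targets)
loss.backward()

# Clip after backward() and before step().
# The function returns the total norm measured before clipping.
norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

# Recompute the norm of the concatenated gradient vector by hand
norm_after = torch.cat([p.grad.flatten() for p in model.parameters()]).norm()

optimizer.step()

print(float(norm_before), float(norm_after))
```

The value returned by the function is handy for logging, because it shows how often and how strongly the gradients are actually clipped.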