Early Stopping
A simple strategy to deal with overfitting is to interrupt training once the validation loss has been increasing for a certain number of epochs. When the validation loss starts increasing while the training loss keeps decreasing, it is reasonable to assume that the training process has entered the overfitting phase. At that point we should not waste time watching the gap between the training and validation loss grow. This strategy is called early stopping. After stopping the training, you go back to the weights that produced the lowest validation loss, under the assumption that those weights will exhibit the best generalization capabilities.
PyTorch does not support early stopping out of the box, but if you know how to save and restore a model, you can easily implement this logic yourself; a sketch follows at the end of this section.
import torch
from torch import nn, optim
The two functions that PyTorch provides are torch.save() and torch.load(). Below, for example, we save and load a simple tensor.
t = torch.ones(3, 3)
torch.save(t, f="tensor.pt")
loaded_t = torch.load(f="tensor.pt")
Usually we are not interested in saving just tensors, but whole internal states. Module states, for example, include all the weights and biases, but also layer-specific parameters such as the dropout probability. Often we also need to save the state of the optimizer, so that we can resume training at a later time. To retrieve a state, modules and optimizers provide a state_dict() method. A state can be restored using the load_state_dict() method.
model = nn.Sequential(
nn.Linear(10, 50), nn.Sigmoid(), nn.Linear(50, 10), nn.Sigmoid(), nn.Linear(10, 1)
)
optimizer = optim.SGD(model.parameters(), lr=0.01)
model_state = model.state_dict()
optim_state = optimizer.state_dict()
torch.save({"model": model_state, "optmim": optim_state}, f="state.py")
state = torch.load(f="state.py")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optim"])
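To close the loop, below is a minimal sketch of early stopping built from these functions, reusing the model and optimizer defined above. The toy random data, the patience of five epochs, and the file name best_state.pt are illustrative assumptions, not PyTorch conventions.
criterion = nn.MSELoss()

# toy data standing in for a real training and validation set
X_train, y_train = torch.randn(100, 10), torch.randn(100, 1)
X_val, y_val = torch.randn(30, 10), torch.randn(30, 1)

best_val_loss = float("inf")
patience = 5  # number of epochs we tolerate without improvement
epochs_without_improvement = 0

for epoch in range(100):
    # one full-batch training step
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(X_train), y_train)
    loss.backward()
    optimizer.step()

    # measure the validation loss
    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(X_val), y_val).item()

    if val_loss < best_val_loss:
        # new best: checkpoint the states and reset the counter
        best_val_loss = val_loss
        epochs_without_improvement = 0
        torch.save(
            {"model": model.state_dict(), "optim": optimizer.state_dict()},
            f="best_state.pt",
        )
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break

# go back to the weights with the lowest validation loss
best_state = torch.load(f="best_state.pt")
model.load_state_dict(best_state["model"])
optimizer.load_state_dict(best_state["optim"])
No matter how long training continued after the best epoch, the model now holds the checkpointed weights.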
We will not be using early stopping in the deep learning module, as this technique is generally considered bad practice. Other techniques, like learning rate schedulers, which we will encounter in future chapters, will give us better options for deciding whether we have found a good set of weights.