Learning Rate Scheduling

There is probably no hyperparameter that is more important than the learning rate. If the learning rate is too high, we might overshoot the minimum or oscillate around it. If the learning rate is too low, training might be too slow, or we might get stuck in a local minimum.
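To make the trade-off concrete, here is a minimal sketch of plain gradient descent (without momentum) on the toy function f(x) = x², whose minimum is at x = 0. The function, the starting point x0 = 10 and the three learning rates are illustrative assumptions, not the values behind the plots in this section. A quadratic has no local minima, so the sketch only covers the overshooting and slow-training cases.

```python
def gradient_descent(lr, x0=10.0, steps=50):
    """Minimise the toy function f(x) = x**2 with a constant learning rate."""
    x = x0
    history = [x]
    for _ in range(steps):
        grad = 2 * x          # derivative of x**2
        x = x - lr * grad     # plain gradient descent update
        history.append(x)
    return history


for lr in (1.1, 0.1, 0.001):
    xs = gradient_descent(lr)
    print(f"lr={lr}: x after 50 steps = {xs[-1]:.4g}")
```

With lr=1.1 every update jumps past the minimum and lands farther away than before, so x flips sign on every step and diverges. With lr=0.001, x has barely moved away from 10 after 50 steps, while lr=0.1 ends up close to 0.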

In the example below, we pick a relatively large learning rate. Gradient descent with momentum overshoots the minimum and keeps oscillating for a while before settling on it.
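The oscillation can be reproduced with a minimal sketch of gradient descent with momentum. The function behind the plot is not given, so this assumes the toy function f(x) = x², a starting point of 10, a learning rate of 0.1 and a momentum coefficient of 0.9. The velocity keeps accumulating gradients, so x shoots past the minimum at 0 and crosses it several times before settling.

```python
def momentum_descent(lr, beta=0.9, x0=10.0, steps=60):
    """Gradient descent with momentum on the toy function f(x) = x**2."""
    x, velocity = x0, 0.0
    history = [x]
    for _ in range(steps):
        grad = 2 * x                        # derivative of x**2
        velocity = beta * velocity + grad   # accumulate past gradients
        x = x - lr * velocity               # step along the accumulated direction
        history.append(x)
    return history


def sign_changes(xs):
    """Count how often the trajectory crosses the minimum at x = 0."""
    return sum(1 for a, b in zip(xs, xs[1:]) if a * b < 0)


xs = momentum_descent(lr=0.1)
print("crossings of the minimum:", sign_changes(xs))
print("final x:", round(xs[-1], 4))
```

With the same learning rate but without momentum, x would shrink as 10 * 0.8**t and approach 0 from one side without ever crossing it, so in this toy setting the overshooting comes from combining a large step with the accumulated velocity.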

[Figure: Constant Learning Rate 0.10]

A single constant learning rate is not necessarily the optimal choice. What if we start out with a relatively large learning rate to make fast progress at the beginning of training, and then decrease the learning rate either over time or at specific events? In deep learning this is called learning rate decay or learning rate scheduling. There are dozens of different schedulers (see the PyTorch documentation for more info). You could, for example, decay the learning rate by subtracting a constant amount after a fixed number of epochs, or you could multiply the learning rate at the end of each epoch by a constant factor smaller than 1.

Below we use a popular learning rate decay technique called reduce learning rate on plateau: once a metric (like a loss) stops improving for a certain number of epochs, we decrease the learning rate by a predetermined factor. In our example this reduces the learning rate once the algorithm overshoots. It almost looks like the ball "glides" into the optimal value.
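The subtractive and multiplicative decay rules mentioned above can each be written as a small function of the epoch number. This is a minimal sketch; the function names and the parameters lr0, step_size, amount, gamma and min_lr are hypothetical and chosen only for illustration. The min_lr floor keeps the subtractive schedule from ever turning negative.

```python
def subtractive_decay(lr0, epoch, step_size, amount, min_lr=0.0):
    """Subtract `amount` from the initial rate once every `step_size` epochs."""
    num_decays = epoch // step_size
    return max(min_lr, lr0 - amount * num_decays)   # never go below min_lr


def multiplicative_decay(lr0, epoch, gamma):
    """Multiply the rate by `gamma` (< 1) at the end of every epoch."""
    return lr0 * gamma ** epoch


for epoch in range(8):
    print(epoch,
          round(subtractive_decay(0.1, epoch, step_size=3, amount=0.02), 4),
          round(multiplicative_decay(0.1, epoch, gamma=0.9), 4))
```

In PyTorch, the multiplicative rule is what optim.lr_scheduler.ExponentialLR implements, so you would normally use that class rather than a hand-written function.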

[Figure: Variable Learning Rate 0.100]

Deep learning frameworks like PyTorch or Keras make it extremely easy to create learning rate schedulers. Usually it involves no more than 2-3 lines of code.

Schedulers in PyTorch are located in optim.lr_scheduler; in our example we pick optim.lr_scheduler.ReduceLROnPlateau. All schedulers take an optimizer as input. This is necessary because the learning rate is part of the optimizer, and the scheduler has to modify that parameter. The patience attribute is specific to ReduceLROnPlateau: it indicates for how many epochs the performance metric (like the cross-entropy loss) may fail to improve before the learning rate is multiplied by the factor parameter, here 0.1. With patience=2, the first two epochs without improvement are tolerated, and the learning rate is reduced after the third.

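import torch.nn as nn
import torch.optim as optim

model = Model().to(DEVICE)  # Model and DEVICE are assumed to be defined earlier
criterion = nn.CrossEntropyLoss(reduction="sum")
optimizer = optim.Adam(model.parameters(), lr=0.01)
# reduce the LR tenfold once the metric has not improved for more than 2 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.1)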
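To see what patience and factor do, here is a deliberately simplified, hypothetical re-implementation of the plateau rule in plain Python. It is not PyTorch's code: the real ReduceLROnPlateau also has options such as mode, threshold, cooldown and min_lr, and by default only counts an epoch as an improvement if the metric beats the best value by a small relative threshold. The validation losses are made up.

```python
class SimplePlateauScheduler:
    """Hypothetical, simplified plateau rule (minimising a metric)."""

    def __init__(self, lr, patience=2, factor=0.1):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best:          # improvement: remember it, reset counter
            self.best = metric
            self.num_bad_epochs = 0
        else:                           # no improvement in this epoch
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor      # plateau detected: shrink the LR
            self.num_bad_epochs = 0
        return self.lr


val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.7, 0.69]   # made-up validation losses
scheduler = SimplePlateauScheduler(lr=0.01, patience=2, factor=0.1)
lrs = [scheduler.step(loss) for loss in val_losses]
print(lrs)
```

Epochs 4, 5 and 6 bring no improvement over the best loss of 0.7. The first two are tolerated (patience=2), and the third triggers the reduction from 0.01 to 0.001.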

We have to adjust the train function to introduce the scheduler logic. The function might, for example, look as below. Similar to the optimizer.step() method, there is a scheduler.step() method. For ReduceLROnPlateau it takes a performance measure like the validation loss and reduces the learning rate if necessary; most other schedulers are stepped without any argument.

def train(train_dataloader, val_dataloader, model, criterion, optimizer, scheduler):
    # NUM_EPOCHS, NUM_FEATURES, DEVICE and track_performance are defined earlier
    for epoch in range(NUM_EPOCHS):
        # switch to training mode once per epoch (track_performance may switch to eval mode)
        model.train()
        for batch_idx, (features, labels) in enumerate(train_dataloader):
            # flatten the features and move features and labels to DEVICE
            features = features.view(-1, NUM_FEATURES).to(DEVICE)
            labels = labels.to(DEVICE)

            # ------ FORWARD PASS --------
            # raw scores (logits): nn.CrossEntropyLoss applies log-softmax itself
            logits = model(features)

            # ------ CALCULATE LOSS --------
            loss = criterion(logits, labels)

            # ------ BACKPROPAGATION --------
            loss.backward()

            # ------ GRADIENT DESCENT --------
            optimizer.step()

            # ------ CLEAR GRADIENTS --------
            optimizer.zero_grad()

        # ------ TRACK LOSS --------
        train_loss, train_acc = track_performance(train_dataloader, model, criterion)
        val_loss, val_acc = track_performance(val_dataloader, model, criterion)

        # ------ ADJUST LEARNING RATE --------
        scheduler.step(val_loss)

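There are no hard rules about which scheduler to use in which situation, but in PyTorch you should always call optimizer.step() before scheduler.step(). The PyTorch documentation notes that, since version 1.1.0, calling them in the reverse order skips the first value of the learning rate schedule.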
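The following minimal sketch checks both points on a real ReduceLROnPlateau object: optimizer.step() is called before scheduler.step(), and the current learning rate is read from optimizer.param_groups. The tiny linear model, the random data and the constant "validation loss" are stand-ins made up for illustration, not the model from this section.

```python
import torch
from torch import nn, optim

# Hypothetical stand-in model, data and loss, only for illustration.
model = nn.Linear(4, 2)
features = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.1)

# Pretend the validation loss never improves after the first epoch.
fake_val_losses = [1.0, 1.0, 1.0, 1.0, 1.0]
lrs = []
for val_loss in fake_val_losses:
    optimizer.zero_grad()
    loss = criterion(model(features), labels)
    loss.backward()
    optimizer.step()           # update the weights first ...
    scheduler.step(val_loss)   # ... then let the scheduler inspect the metric
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

The first call records 1.0 as the best value, the next two non-improving epochs are tolerated, and the third non-improving epoch (the fourth call) multiplies the learning rate by 0.1.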