Mixed Precision Training

In the next section we will begin looking at different CNN architectures. While the older architectures are relatively easy to train, more modern architectures require considerably more computational power. There are different ways to deal with those requirements, but in this section we will focus specifically on mixed precision training.

So far, we have trained neural networks using the torch.float32 datatype. However, some operations, like linear layers and convolutions, can run much faster in the lower torch.float16 precision, especially on GPUs with Tensor Cores.
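The practical difference between the two formats is easy to check. The following sketch uses only Python's struct module, whose "e" and "f" format characters pack a value as IEEE 754 half and single precision, to round-trip a few numbers. The helper names to_fp16 and to_fp32 are hypothetical. The sketch shows the largest finite float16 value, the smallest positive one, and how tiny values and fine differences are lost in float16 but survive in float32.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round a Python float to IEEE 754 single precision and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# Largest finite float16 value
print(to_fp16(65504.0))                    # 65504.0
# Smallest positive (subnormal) float16 value is exactly 2**-24
print(to_fp16(2.0 ** -24) == 2.0 ** -24)   # True
# Values far below 2**-24 underflow to zero in float16, but not in float32
print(to_fp16(1e-8), to_fp32(1e-8) > 0)    # 0.0 True
# float16 keeps only 10 fraction bits, so 1 + 2**-12 rounds back to 1.0
print(to_fp16(1.0 + 2.0 ** -12), to_fp32(1.0 + 2.0 ** -12) > 1.0)  # 1.0 True

# Values above the float16 range cannot be packed at all
# (struct raises an error; a float16 tensor would store inf instead)
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 overflows float16")
```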


Mixed precision training allows us to train a neural network using different levels of precision for different operations.

Mixed precision training has at least two advantages.

  1. Some operations are faster in torch.float16 precision, so the training process as a whole can become significantly faster.
  2. Operations in torch.float16 require less memory than torch.float32 operations. This reduces the VRAM requirements and allows us to use a larger batch size.
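A rough calculation shows the size of the memory saving. The sketch below assumes a hypothetical batch of 256 MNIST images passing through a convolution with 64 output channels at the full 28x28 resolution, and computes the size of that single activation tensor; tensor_megabytes is a hypothetical helper. Keep in mind that with PyTorch's automatic mixed precision the model parameters themselves stay in float32, so the savings mainly come from activations like this one.

```python
import struct
from math import prod

# Bytes per element: "f" is IEEE 754 single precision, "e" is half precision
BYTES_FP32 = struct.calcsize("f")  # 4
BYTES_FP16 = struct.calcsize("e")  # 2

def tensor_megabytes(shape, bytes_per_element):
    """Memory needed to store a dense tensor of the given shape, in MiB."""
    return prod(shape) * bytes_per_element / 2**20

# Hypothetical activation tensor: (batch, channels, height, width)
shape = (256, 64, 28, 28)
fp32_mb = tensor_megabytes(shape, BYTES_FP32)
fp16_mb = tensor_megabytes(shape, BYTES_FP16)
print(f"float32: {fp32_mb:.1f} MiB, float16: {fp16_mb:.1f} MiB")
# float32: 49.0 MiB, float16: 24.5 MiB
```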

PyTorch provides so-called automatic mixed precision (AMP), which automatically decides which operations run in which precision, so we do not have to make any of those decisions manually. The official PyTorch documentation provides more information on the topic.
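To see what "automatically decides" means in practice, the following sketch runs a single linear layer under autocast. It assumes a CPU-only setup, where autocast uses torch.bfloat16 as its lower-precision type instead of torch.float16, so it runs without a GPU; the layer sizes are arbitrary. The layer's weights stay in float32, but the output of the matrix multiplication comes back in the lower precision.

```python
import torch
import torch.nn as nn

# A tiny layer on the CPU; its parameters are created in float32
layer = nn.Linear(in_features=8, out_features=4)
x = torch.randn(2, 8)

# On the CPU, autocast uses bfloat16 as its lower-precision type
# (the training code in this section uses device_type="cuda" with torch.float16)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)

print(x.dtype)             # torch.float32, the input tensor is not modified
print(layer.weight.dtype)  # torch.float32, parameters stay in full precision
print(y.dtype)             # torch.bfloat16, the linear layer ran in lower precision
```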

We will demonstrate the performance boost from mixed precision training with the help of the MNIST dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms as T

import time
assert torch.cuda.is_available()
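train_dataset = MNIST(root="../datasets", train=True, download=True, transform=T.ToTensor())
BATCH_SIZE = 128  # assumed batch size; the original value is not shown
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)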

We use a much larger network than MNIST requires for good performance, in order to demonstrate the potential of mixed precision training.

# each row: [in_channels, out_channels, kernel_size, stride, padding]
cfg = [[1, 32, 3, 1, 1],
       [32, 64, 3, 1, 1],
       [64, 64, 2, 2, 0],
       [64, 128, 3, 1, 1],
       [128, 128, 3, 1, 1],
       [128, 128, 3, 1, 1],
       [128, 128, 2, 2, 0],
       [128, 256, 3, 1, 1],
       [256, 256, 2, 1, 0],
       [256, 512, 3, 1, 1],
       [512, 512, 3, 1, 1],
       [512, 512, 3, 1, 1],
       [512, 512, 2, 2, 0],
       [512, 1024, 3, 1, 1]]

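class BasicBlock(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # assumed block composition: convolution -> batch norm -> ReLU
        self.block = nn.Sequential(
            nn.Conv2d(**kwargs),
            nn.BatchNorm2d(num_features=kwargs["out_channels"]),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)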

class Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.features = self._build_layers(cfg)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),  # (N, 1024, 1, 1) -> (N, 1024)
            nn.Linear(in_features=1024, out_features=1000),
            nn.Linear(in_features=1000, out_features=1000),
            nn.Linear(in_features=1000, out_features=10),
        )
    def _build_layers(self, cfg):
        layers = []
        for layer in cfg:
            layers += [BasicBlock(in_channels=layer[0],
                                  out_channels=layer[1],
                                  kernel_size=layer[2],
                                  stride=layer[3],
                                  padding=layer[4])]
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x
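DEVICE = torch.device('cuda')
NUM_EPOCHS = 10  # the training logs below show ten epochs
LR = 0.001  # assumed learning rate; the original value is not shown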
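As a sanity check on the architecture, we can trace the spatial size of the feature maps through cfg. The sketch assumes each row is [in_channels, out_channels, kernel_size, stride, padding] and applies the standard Conv2d output-size formula with dilation 1; conv_output_size is a hypothetical helper. The 28x28 input shrinks to 3x3 with 1024 channels, which AdaptiveAvgPool2d(1) reduces to the 1024 features expected by the first linear layer.

```python
# Each row is assumed to be [in_channels, out_channels, kernel_size, stride, padding]
cfg = [[1, 32, 3, 1, 1], [32, 64, 3, 1, 1], [64, 64, 2, 2, 0],
       [64, 128, 3, 1, 1], [128, 128, 3, 1, 1], [128, 128, 3, 1, 1],
       [128, 128, 2, 2, 0], [128, 256, 3, 1, 1], [256, 256, 2, 1, 0],
       [256, 512, 3, 1, 1], [512, 512, 3, 1, 1], [512, 512, 3, 1, 1],
       [512, 512, 2, 2, 0], [512, 1024, 3, 1, 1]]

def conv_output_size(size, kernel_size, stride, padding):
    """Output height/width of a square Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 28  # MNIST images are 28x28
for in_ch, out_ch, k, s, p in cfg:
    size = conv_output_size(size, k, s, p)
    print(f"{in_ch:>4} -> {out_ch:<4} feature map: {size}x{size}")
```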

We start by training the neural network in the familiar manner and measure how long each epoch takes. These values serve as our baseline.

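def train(data_loader, model, optimizer, criterion):
    model.train()
    for epoch in range(NUM_EPOCHS):
        start_time = time.time()
        losses = []
        for img, label in data_loader:
            img = img.to(DEVICE)
            label = label.to(DEVICE)
            prediction = model(img)
            loss = criterion(prediction, label)

            # regular float32 backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() copies the loss to the CPU, which also waits for the GPU
            losses.append(loss.item())

        end_time = time.time()
        s = f'Epoch: {epoch+1}, ' \
            f'Loss: {sum(losses)/len(losses):.4f}, ' \
            f'Elapsed Time: {end_time-start_time:.2f}sec'
        print(s)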
model = Model(cfg).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

Each epoch takes about 22 seconds to complete.

train(train_dataloader, model, optimizer, criterion)
Epoch: 1, Loss: 0.2528, Elapsed Time: 22.82sec
Epoch: 2, Loss: 0.0316, Elapsed Time: 21.99sec
Epoch: 3, Loss: 0.0201, Elapsed Time: 22.11sec
Epoch: 4, Loss: 0.0155, Elapsed Time: 22.15sec
Epoch: 5, Loss: 0.0123, Elapsed Time: 22.14sec
Epoch: 6, Loss: 0.0106, Elapsed Time: 22.18sec
Epoch: 7, Loss: 0.0112, Elapsed Time: 22.11sec
Epoch: 8, Loss: 0.0084, Elapsed Time: 22.15sec
Epoch: 9, Loss: 0.0083, Elapsed Time: 22.17sec
Epoch: 10, Loss: 0.0078, Elapsed Time: 22.14sec

We repeat the training procedure, this time with mixed precision. For that we use the torch.amp module. The torch.amp.autocast context manager runs the code inside its with block in mixed precision; in our case, that is the forward pass and the loss calculation. The backward pass is called outside the context manager, but operations that ran in float16 during the forward pass also compute their gradients in float16.

Those float16 gradients can cause problems. Many gradient values are very small, and values much smaller than 2^-24 (about 6e-8), the smallest positive float16 number, underflow to zero, so the corresponding updates are lost. To remedy this, we use a torch.cuda.amp.GradScaler object (recent PyTorch releases also expose the same class as torch.amp.GradScaler("cuda")). The scaler multiplies the loss by a scale factor before backpropagation, which scales all gradients by the same factor and keeps them in the representable range. Before the optimizer step, the scaler divides the gradients by that factor again, and it skips the step entirely if it finds inf or NaN values. Finally, the scaler updates its scale factor for the next iteration: it reduces the factor after an overflow and increases it again after a series of steps without one. The following three lines do exactly that:

  • scaler.scale(loss).backward()
  • scaler.step(optimizer)
  • scaler.update()
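The underflow problem and the effect of loss scaling can be reproduced for a single number. This is a minimal sketch, not what GradScaler does internally: it assumes a hypothetical gradient of 1e-8, uses struct's half-precision format to emulate float16 storage (to_fp16 is a hypothetical helper), and uses 2**16 as the scale factor, which is the default initial scale of GradScaler.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a float to IEEE 754 half precision, as a float16 tensor would store it."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8        # hypothetical tiny gradient value
scale = 2.0 ** 16  # assumed loss-scale factor (GradScaler's default initial scale)

# Without scaling, the gradient underflows to zero in float16
unscaled = to_fp16(grad)

# Scaling the loss scales every gradient by the same factor (chain rule),
# which lifts the value into float16's representable range
scaled = to_fp16(grad * scale)

# Before the optimizer step, the gradient is divided by the scale again in
# higher precision (here a Python float), recovering roughly the original value
recovered = scaled / scale

print(unscaled)   # 0.0
print(recovered)  # close to 1e-8
```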
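def optimized_train(data_loader, model, optimizer, criterion):
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    for epoch in range(NUM_EPOCHS):
        start_time = time.time()
        losses = []
        for img, label in data_loader:
            img = img.to(DEVICE)
            label = label.to(DEVICE)
            # forward pass and loss in mixed precision
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                prediction = model(img)
                loss = criterion(prediction, label)

            # backward pass outside of autocast, on the scaled loss
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)  # unscales gradients, skips the step on inf/NaN
            scaler.update()         # adjusts the scale factor for the next iteration
            losses.append(loss.item())

        end_time = time.time()
        s = f'Epoch: {epoch+1}, ' \
            f'Loss: {sum(losses)/len(losses):.4f}, ' \
            f'Elapsed Time: {end_time-start_time:.2f}sec'
        print(s)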
model = Model(cfg).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
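The scale factor itself is adjusted dynamically. The class below is a simplified, hypothetical sketch of the rule that scaler.step and scaler.update follow, working on plain Python lists of gradient values instead of tensors; SimpleLossScaler is not part of PyTorch. Its default arguments mirror the documented defaults of PyTorch's GradScaler (init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000).

```python
import math

class SimpleLossScaler:
    """Simplified sketch of dynamic loss scaling (not PyTorch's GradScaler)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0      # consecutive steps without inf/NaN gradients
        self._found_inf = False

    def step(self, scaled_grads):
        """Unscale the gradients, or return None if the step must be skipped."""
        self._found_inf = any(not math.isfinite(g) for g in scaled_grads)
        if self._found_inf:
            return None  # skip the optimizer step for this batch
        return [g / self.scale for g in scaled_grads]

    def update(self):
        """Adjust the scale factor for the next iteration."""
        if self._found_inf:
            self.scale *= self.backoff_factor  # overflow: shrink the scale
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor  # stable for a while: grow
                self._good_steps = 0

# Usage with hypothetical scaled gradients: one overflowing batch, then three clean ones
scaler = SimpleLossScaler(growth_interval=3)
results = []
for grads in ([float("inf"), 1.0], [2.0, 4.0], [2.0, 4.0], [2.0, 4.0]):
    unscaled = scaler.step(grads)  # None means "skip optimizer.step()"
    scaler.update()
    results.append((scaler.scale, unscaled))
```

With growth_interval=3, the overflow in the first batch halves the scale to 32768.0, and after three clean batches it doubles back to 65536.0. Starting with a large scale and backing off on overflow keeps the factor as large as possible without producing inf values.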

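Mixed precision improves the training speed significantly: the time per epoch drops from about 22 seconds to about 13 seconds, while the loss develops much like in the float32 run. The overhead of automatic mixed precision is small compared to these benefits. The exact speedup depends on the GPU; cards with Tensor Cores benefit the most.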

optimized_train(train_dataloader, model, optimizer, criterion)
Epoch: 1, Loss: 0.2699, Elapsed Time: 13.00sec
Epoch: 2, Loss: 0.0319, Elapsed Time: 12.95sec
Epoch: 3, Loss: 0.0206, Elapsed Time: 12.93sec
Epoch: 4, Loss: 0.0144, Elapsed Time: 12.95sec
Epoch: 5, Loss: 0.0117, Elapsed Time: 12.95sec
Epoch: 6, Loss: 0.0104, Elapsed Time: 12.96sec
Epoch: 7, Loss: 0.0083, Elapsed Time: 12.95sec
Epoch: 8, Loss: 0.0095, Elapsed Time: 13.01sec
Epoch: 9, Loss: 0.0053, Elapsed Time: 12.97sec
Epoch: 10, Loss: 0.0091, Elapsed Time: 12.99sec
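To quantify the improvement, we can average the epoch times from the two logs. The sketch below only copies the reported numbers into lists; no new measurements are involved.

```python
from statistics import mean

# Epoch times in seconds, copied from the two training logs above
fp32_times = [22.82, 21.99, 22.11, 22.15, 22.14, 22.18, 22.11, 22.15, 22.17, 22.14]
amp_times = [13.00, 12.95, 12.93, 12.95, 12.95, 12.96, 12.95, 13.01, 12.97, 12.99]

fp32_mean = mean(fp32_times)
amp_mean = mean(amp_times)
print(f"float32: {fp32_mean:.2f}sec/epoch, AMP: {amp_mean:.2f}sec/epoch")
print(f"speedup: {fp32_mean / amp_mean:.2f}x, "
      f"time saved per epoch: {1 - amp_mean / fp32_mean:.0%}")
```

By this calculation, mixed precision makes each epoch roughly 1.7 times faster, saving a little over 40% of the training time.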