Saving and Recovering a PyTorch Checkpoint During Training

In non-demo scenarios, training a neural network can take hours, days, weeks, or even longer. It’s not uncommon for machines to crash, so you should always save checkpoint information during training so that if your training machine crashes or hangs, you can recover without having to start from the beginning of training.

In pseudo-code, to save a state checkpoint every 100 training epochs:

set random number generator seeds
create network
create training optimizer

for epoch in range(0, max_epochs):
  reset random seed
  for each batch of training items:
    use batch inputs to compute predicteds
    compare predicteds to targets to get loss
    use loss to update network weights

  if epoch mod 100 == 0:
    print loss to monitor progress
    get RNGs, network, optimizer states
    save states to file
print("training complete")

In code, one of zillions of possibilities is:

  if epoch % 100 == 0:
    print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))

    # save training checkpoint
    dt = time.strftime("%Y_%m_%d-%H_%M_%S")
    # parentheses let the expression continue on the next line
    fn = (".\\Log\\" + str(dt) + str("-") + str(epoch) +
      "_checkpoint.pt")

    info_dict = { 'epoch' : epoch,
      'torch_random_state' : T.random.get_rng_state(),
      'numpy_random_state' : np.random.get_state(),
      'net_state' : net.state_dict(),
      'optimizer_state' : optimizer.state_dict() }

    T.save(info_dict, fn)

The current date and time are used to create a year-month-day-hour-minute-second-epoch filename that looks like “2020_11_16-12_36_46-500_checkpoint.pt”. It’s usually a good idea to fetch and save the PyTorch and NumPy random number generator states, because anything that consumes random numbers after a recovery (shuffling, dropout, NumPy-based data preparation) can then pick up from the same point. The net.state_dict() method gets all the current weights and biases, and the optimizer.state_dict() method gets the optimizer state, such as momentum values and the learning rate settings.
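The filename logic can be pulled into a small helper. This is only a sketch: the name make_checkpoint_path and its parameters are made up for illustration, and it uses os.path.join rather than a hard-coded Windows path. It also creates the log directory if needed, because T.save() does not create missing directories.

```python
import os
import time


def make_checkpoint_path(log_dir, epoch, when=None):
    # when: a time.struct_time; defaults to the current local time
    if when is None:
        when = time.localtime()
    dt = time.strftime("%Y_%m_%d-%H_%M_%S", when)
    # make sure the target directory exists before anything is saved there
    os.makedirs(log_dir, exist_ok=True)
    return os.path.join(log_dir, dt + "-" + str(epoch) + "_checkpoint.pt")
```

In the training loop, the result would be passed to T.save(info_dict, fn) in place of the hand-built string.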

Unfortunately, there is no easy way to get the state of a PyTorch DataLoader object, so when you recover from a crash, the order in which training items are served will differ, and your subsequent training will be a tiny bit different from what it would have been without the crash. This usually doesn’t matter. You can avoid the difference and get reproducible results by resetting the PyTorch random number generator seed at the beginning of each epoch:

  net.train()  # or net = net.train()
  for epoch in range(0, max_epochs):
    T.manual_seed(1 + epoch)  # for recovery reproducibility
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch['predictors']  # inputs
      Y = batch['targets']     # correct class/label/job
      . . .
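To see why the per-epoch seed helps, here is a minimal sketch. It assumes a DataLoader created with shuffle=True and no generator argument of its own, so the shuffle order comes from the global PyTorch RNG. The function name epoch_batch_order and the toy dataset are made up for illustration.

```python
import torch as T
from torch.utils.data import DataLoader, TensorDataset


def epoch_batch_order(epoch, n_items=10, bat_sz=4):
    # same reseeding scheme as the training loop above
    T.manual_seed(1 + epoch)
    ds = TensorDataset(T.arange(n_items))  # items are just their own indices
    ldr = DataLoader(ds, batch_size=bat_sz, shuffle=True)
    # collect the item indices in the order the loader serves them
    return [batch[0].tolist() for batch in ldr]
```

Calling epoch_batch_order(3) twice gives the same batches both times, which is the property a restarted run relies on: epoch 401 is served in the same order whether or not training was interrupted before it.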


If your training machine crashes during training, you can recover by resetting the RNGs (if necessary), the network, the optimizer, and the training epoch with code like this:

  print("\nBegin recover failed training \n")
  fn = ".\\Log\\2020_11_16-12_36_46_checkpoint.pt"
  chkpt = T.load(fn)  # get saved states dictionary
  # reset RNGs if needed
  np.random.set_state(chkpt['numpy_random_state'])
  T.random.set_rng_state(chkpt['torch_random_state'])
  . . .
  net = Net().to(device)  # reset network
  net.load_state_dict(chkpt['net_state'])
  . . .
  loss_func = T.nn.CrossEntropyLoss()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  optimizer.load_state_dict(chkpt['optimizer_state'])
  . . .
  epoch_saved = chkpt['epoch'] + 1
  for epoch in range(epoch_saved, max_epochs):
    T.manual_seed(1 + epoch)  # for recovery reproducibility
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      . . .

Essentially, your code is exactly like normal training code except that you adjust the RNGs, the network, the optimizer, and the starting epoch, using the saved states.
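Here is a small end-to-end sketch of the whole idea: train, checkpoint, pretend to crash, recover, and compare. Everything in it is a stand-in: make_net, train, and demo are made-up names, the data is random, and T.randperm plays the role of DataLoader shuffling. One caveat: PyTorch 2.6 changed the default of T.load() to weights_only=True, which can refuse a checkpoint holding NumPy state, so the sketch passes weights_only=False. That is only appropriate for files you created yourself, and the argument is assumed to be available in your PyTorch version.

```python
import os
import tempfile

import numpy as np
import torch as T


def make_net():
    # tiny stand-in for the Net class (hypothetical architecture)
    return T.nn.Sequential(T.nn.Linear(4, 8), T.nn.Tanh(), T.nn.Linear(8, 3))


def train(net, optimizer, X, Y, start_epoch, max_epochs, save_at=None, fn=None):
    loss_func = T.nn.CrossEntropyLoss()
    net.train()
    for epoch in range(start_epoch, max_epochs):
        T.manual_seed(1 + epoch)   # for recovery reproducibility
        perm = T.randperm(len(X))  # stands in for DataLoader shuffling
        for i in range(0, len(X), 4):
            idx = perm[i:i + 4]
            optimizer.zero_grad()
            loss = loss_func(net(X[idx]), Y[idx])
            loss.backward()
            optimizer.step()
        if save_at is not None and epoch == save_at:
            T.save({'epoch': epoch,
                    'torch_random_state': T.random.get_rng_state(),
                    'numpy_random_state': np.random.get_state(),
                    'net_state': net.state_dict(),
                    'optimizer_state': optimizer.state_dict()}, fn)


def demo():
    T.manual_seed(0)
    X = T.randn(12, 4)
    Y = T.randint(0, 3, (12,))
    fn = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pt")

    # run A: uninterrupted, epochs 0..5, checkpoint written after epoch 2
    net_a = make_net()
    opt_a = T.optim.SGD(net_a.parameters(), lr=0.05, momentum=0.9)
    train(net_a, opt_a, X, Y, 0, 6, save_at=2, fn=fn)

    # run B: simulate a crash after epoch 2, then resume from the checkpoint
    chkpt = T.load(fn, weights_only=False)  # trusted, self-made file only
    np.random.set_state(chkpt['numpy_random_state'])
    T.random.set_rng_state(chkpt['torch_random_state'])
    net_b = make_net()
    net_b.load_state_dict(chkpt['net_state'])
    opt_b = T.optim.SGD(net_b.parameters(), lr=0.05, momentum=0.9)
    opt_b.load_state_dict(chkpt['optimizer_state'])
    train(net_b, opt_b, X, Y, chkpt['epoch'] + 1, 6)

    # the recovered run should end with the same weights and biases
    return all(T.allclose(p, q)
               for p, q in zip(net_a.parameters(), net_b.parameters()))
```

The sketch uses SGD with momentum on purpose. The momentum values live in the optimizer state, so leaving out optimizer.load_state_dict() would be expected to make the recovered run drift away from the uninterrupted one.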

In the two screenshots above, you can see that I saved states every 100 epochs. I then ran a second program that loaded the checkpoint saved at epoch 400 and restarted training from epoch 401, to simulate what would happen if training had crashed after epoch 400. Notice that the results are the same.


Disaster can strike unexpectedly when training a neural network. Disaster can also strike whenever you pose for a photograph with animals. Duck, goat, stingray.
