Saving a PyTorch Model

Bottom line: There are two main ways to save a trained PyTorch neural network model. You should use the newer “state_dict” approach rather than the older “full” approach.

The recommended way to save a PyTorch model looks like:

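import torch as T

class Net(T.nn.Module):
  # define neural network here; these layers are a hypothetical example
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)

  def forward(self, x):
    z = T.tanh(self.hid(x))
    return self.oupt(z)

def main():
  net = Net()  # create
  # train network

  path = ".\\Models\\my_model.pth"  # Windows-style path; Models folder must exist
  T.save(net.state_dict(), path)  # save only the learned parameters

if __name__ == "__main__":
  main()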
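To see what the state_dict approach actually writes to disk, here is a minimal sketch that lists the contents of a state_dict. The 4-8-3 Net with tanh is a hypothetical stand-in for "define neural network here". A state_dict is an ordered dictionary that maps each parameter (and buffer) name to its tensor.

```python
import torch as T

class Net(T.nn.Module):
  # hypothetical 4-8-3 network, a stand-in for your own architecture
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)

  def forward(self, x):
    z = T.tanh(self.hid(x))
    return self.oupt(z)

def describe_state_dict(net):
  # map each entry name in the state_dict to the shape of its tensor
  return {name: tuple(t.shape) for (name, t) in net.state_dict().items()}

if __name__ == "__main__":
  for (name, shape) in describe_state_dict(Net()).items():
    print(name, shape)
  # hid.weight (8, 4)
  # hid.bias (8,)
  # oupt.weight (3, 8)
  # oupt.bias (3,)
```

Only these named tensors are saved, not the Net class itself, which is why the loading program has to define the class again.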

Then to use the saved model in another file:

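import torch as T

class Net(T.nn.Module):
  # exactly the same definition as in the saving program
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)

  def forward(self, x):
    z = T.tanh(self.hid(x))
    return self.oupt(z)

def main():
  print("\nLoad using state_dict approach (preferred)")
  path = ".\\Models\\my_model.pth"
  model = Net()  # create an untrained network with the same structure
  model.load_state_dict(T.load(path))  # copy the saved weights and biases into it
  model.eval()  # set evaluation mode before making predictions

  # use the model to make predictions

if __name__ == "__main__":
  main()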
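The "use the model to make predictions" step can be sketched as follows. This assumes the same hypothetical 4-8-3 Net, a temporary file path instead of .\\Models\\, and a classification network whose predicted class is the index of the largest output value. The helper name load_and_predict is made up for illustration.

```python
import os
import tempfile
import torch as T

class Net(T.nn.Module):
  # hypothetical 4-8-3 network; must match the class used when saving
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)

  def forward(self, x):
    z = T.tanh(self.hid(x))
    return self.oupt(z)

def load_and_predict(path, x):
  # hypothetical helper: rebuild the network, load weights, predict classes
  model = Net()
  model.load_state_dict(T.load(path))
  model.eval()  # evaluation mode (matters for dropout and batch norm layers)
  with T.no_grad():  # no gradient tracking is needed for inference
    logits = model(x)
  return T.argmax(logits, dim=1)  # predicted class index for each input row

if __name__ == "__main__":
  T.manual_seed(1)
  path = os.path.join(tempfile.mkdtemp(), "my_model.pth")
  T.save(Net().state_dict(), path)  # stand-in for a trained, saved model
  x = T.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=T.float32)
  print(load_and_predict(path, x))
```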

The older approach looks very similar:

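import torch as T
class Net(T.nn.Module):  # define neural network here
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)  # hypothetical example layers
    self.oupt = T.nn.Linear(8, 3)
  def forward(self, x):
    return self.oupt(T.tanh(self.hid(x)))

# save old way (not preferred): pickles the entire model object
net = Net()  # create and train network
path = ".\\Models\\my_model.pth"
T.save(net, path)

# in another file:
import torch as T
class Net(T.nn.Module):  # exactly the same as above
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)
  def forward(self, x):
    return self.oupt(T.tanh(self.hid(x)))
path = ".\\Models\\my_model.pth"
model = T.load(path, weights_only=False)  # needed on PyTorch 2.6+, where weights_only defaults to True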

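You have to look at the code carefully to see the difference between the two approaches. The state_dict approach saves only the learned parameters with T.save(net.state_dict(), path), and loading requires creating a Net object and then calling its load_state_dict() method. The older approach saves the entire model object with T.save(net, path), and T.load() returns the model directly. Notice that in both techniques, you must have the class definition of the neural network in the file that saves the model, and also in the file that loads the model.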
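Because a state_dict file holds only named tensors, the loading program must rebuild a structurally identical network before calling load_state_dict(). Here is a minimal sketch of what happens otherwise. WiderNet is a hypothetical class that represents an edited version of Net with 10 hidden nodes instead of 8; load_state_dict() raises a RuntimeError when tensor names or shapes don't match.

```python
import torch as T

class Net(T.nn.Module):
  # hypothetical 4-8-3 network used when saving
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)

  def forward(self, x):
    return self.oupt(T.tanh(self.hid(x)))

class WiderNet(T.nn.Module):
  # hypothetical "edited" class: hidden layer changed from 8 to 10 nodes
  def __init__(self):
    super(WiderNet, self).__init__()
    self.hid = T.nn.Linear(4, 10)
    self.oupt = T.nn.Linear(10, 3)

  def forward(self, x):
    return self.oupt(T.tanh(self.hid(x)))

def try_load(model_class, state):
  # returns True if the saved tensors fit the given class definition
  model = model_class()
  try:
    model.load_state_dict(state)
    return True
  except RuntimeError as e:
    print("load failed:", str(e).splitlines()[0])
    return False

if __name__ == "__main__":
  saved = Net().state_dict()
  print(try_load(Net, saved))       # True
  print(try_load(WiderNet, saved))  # False, after a "load failed" message
```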

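I won’t go into all the low-level details of why the newer state_dict approach is preferred. Briefly, the full approach relies on Python’s pickle module, which ties the saved file to the exact class definition and directory structure used when the model was saved, so the file can break if that code is later refactored or moved.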
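One concrete, checkable difference between the two file formats is sketched below, assuming PyTorch 1.13 or later (where T.load() accepts a weights_only argument) and the same hypothetical 4-8-3 Net. With weights_only=True, T.load() only unpickles tensors and basic containers. A state_dict file loads fine under that restriction, but a full pickled model is rejected. The helper name safe_load_works is made up for illustration.

```python
import os
import tempfile
import torch as T

class Net(T.nn.Module):
  # hypothetical 4-8-3 network
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)

  def forward(self, x):
    return self.oupt(T.tanh(self.hid(x)))

def safe_load_works(path):
  # hypothetical helper: can the file be loaded with the restricted unpickler?
  try:
    T.load(path, weights_only=True)
    return True
  except Exception:  # custom classes such as Net are not allowed
    return False

if __name__ == "__main__":
  folder = tempfile.mkdtemp()
  net = Net()
  sd_path = os.path.join(folder, "model_sd.pth")
  full_path = os.path.join(folder, "model_full.pth")
  T.save(net.state_dict(), sd_path)
  T.save(net, full_path)
  print("state_dict file:", safe_load_works(sd_path))    # True
  print("full model file:", safe_load_works(full_path))  # False
```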

Just for fun, I coded up three complete working PyTorch programs to demonstrate. The first program creates a dummy neural network, computes an example output, and saves the model using both the state_dict way and also the older “full” way. The second program loads the state_dict model and computes an example output. The third program loads the older-format model and computes an example output. All three output values are the same.
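Here is a compressed, single-file sketch of the same three-step experiment. It is not the author’s actual set of programs; it assumes a hypothetical 4-8-3 Net, temporary file paths in place of .\\Models\\, and PyTorch 1.13 or later for the weights_only argument. Each "program" is a commented section inside run_demo().

```python
import os
import tempfile
import torch as T

class Net(T.nn.Module):
  # hypothetical dummy 4-8-3 network (not the author's demo network)
  def __init__(self):
    super(Net, self).__init__()
    self.hid = T.nn.Linear(4, 8)
    self.oupt = T.nn.Linear(8, 3)

  def forward(self, x):
    return self.oupt(T.tanh(self.hid(x)))

def run_demo(folder):
  x = T.tensor([[0.5, -0.5, 0.25, 1.0]], dtype=T.float32)

  # "program 1": create a network, compute an output, save it both ways
  T.manual_seed(0)
  net = Net()
  net.eval()
  with T.no_grad():
    y0 = net(x)
  sd_path = os.path.join(folder, "model_sd.pth")
  full_path = os.path.join(folder, "model_full.pth")
  T.save(net.state_dict(), sd_path)
  T.save(net, full_path)

  # "program 2": load the state_dict version into a fresh Net
  model2 = Net()
  model2.load_state_dict(T.load(sd_path))
  model2.eval()

  # "program 3": load the full (older-format) model
  model3 = T.load(full_path, weights_only=False)
  model3.eval()

  with T.no_grad():
    y2 = model2(x)
    y3 = model3(x)
  return y0, y2, y3

if __name__ == "__main__":
  y0, y2, y3 = run_demo(tempfile.mkdtemp())
  print(y0, y2, y3, sep="\n")
  print(T.equal(y0, y2) and T.equal(y0, y3))  # True
```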

In addition to the two ways I’ve explained here, you can also save a PyTorch model in the ONNX format, which I don’t recommend at this time. I’ll explain ONNX in another blog post sometime. Briefly, ONNX is new and still immature, so not everything is fully supported, and you can’t run a saved ONNX model using PyTorch itself; you have to use an entirely different system, such as ONNX Runtime, to run the saved model.





Three (fashion) models saved (via photography). The photos were taken by Nina Leen (1910 – 1995), a famous photographer best known for her contributions to Life Magazine. Life Magazine was one of the most important means of communication in the world, especially from 1936 to 1972. These three old photos of models from the 1950s hold up very well today, in my opinion.


This entry was posted in PyTorch.