Saving and Using a PyTorch Model in ONNX Format

ONNX (Open Neural Network Exchange) is a format for saving a neural network model. The idea is that a trained neural network, regardless of which library produced it — PyTorch, Keras, scikit-learn, and so on — can be saved in one universal format.

I’m skeptical about the viability of ONNX, but ONNX is still immature, so my opinion could change. There are several reasons for my skepticism. For example, Keras and TensorFlow don’t currently support ONNX officially (you must use an add-on tool). And if you save a PyTorch model in ONNX format, you currently can’t load the model back into PyTorch; you must consume the saved model with something else, such as the separate ONNX Runtime library.

Anyway, this past weekend I sat down to code an example. My goal was to train a neural network model using PyTorch, save the trained model in ONNX format, then write a second program to use the saved ONNX model to make a prediction.

I had a neural network up and running that I was using for other experiments, so I used it for my ONNX experiment. My existing neural network was a binary classifier for the Banknote Authentication problem. The data has four input values derived from digital images of banknotes (euros, I think), and the output is a single value between 0 and 1, where a value less than 0.5 indicates an authentic note and a value greater than 0.5 indicates a forgery.
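The thresholding decision rule above can be sketched as a tiny helper function. This is just an illustration, not code from the demo, and the function name interpret is my own:

```python
# Hypothetical helper (not part of the demo): map the model's single
# output value in [0, 1] to a class label, thresholding at 0.5.
def interpret(p):
    return "authentic" if p < 0.5 else "forgery"

print(interpret(0.00002071))  # far below 0.5, so "authentic"
print(interpret(0.9))         # above 0.5, so "forgery"
```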

The key lines of code to save the trained model are:

# banknote_bnn.py
import torch as T
device = T.device("cpu")

(first, create and train a NN named 'net')

print("\nSaving trained model as ONNX \n")
path = ".\\Models\\banknote_onnx_model.onnx"
dummy = T.tensor([[0.5, 0.5, 0.5, 0.5]],
  dtype=T.float32).to(device)
T.onnx.export(net, dummy, path, input_names=["input1"],
  output_names=["output1"])

The model is saved using the export() function. You supply a dummy input value so that export() can trace the network and record the shape it should expect for general input.
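As a minimal illustration (NumPy only, not part of the demo), the dummy input above has shape (1, 4): a batch of one banknote with four feature values. Any input fed to the exported model later must match that shape:

```python
import numpy as np

# The dummy traced by export() is a 1-by-4 tensor:
# one batch row holding the four banknote feature values.
dummy = np.array([[0.5, 0.5, 0.5, 0.5]], dtype=np.float32)
print(dummy.shape)  # (1, 4)
```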

Here’s a second program that loads the saved ONNX model into memory and then uses it to make a prediction:

# banknote_use_onnx.py

import onnx          # installed 1.7.0 separately
import onnxruntime   # installed 1.4.0 separately
import numpy as np

def main():
  print("\nBegin ONNX from PyTorch demo \n")
  path = ".\\Models\\banknote_onnx_model.onnx"
  model = onnx.load(path)
  onnx.checker.check_model(model)

  sess = onnxruntime.InferenceSession(path)
  input_name = sess.get_inputs()[0].name
  print("input name", input_name)
  input_shape = sess.get_inputs()[0].shape
  print("input shape", input_shape)

  np.set_printoptions(precision=8, suppress=True)
  x = np.array([[0.5, 0.5, 0.5, 0.5]], dtype=np.float32)
  print("\nInput values to ONNX model = ")
  print(x)
  res = sess.run(None, {"input1": x})
  print("\nOutput value from ONNX model = ")
  print(res)

  print("\nEnd demo ")

if __name__ == "__main__":
  main()

The demo uses the onnx and onnxruntime libraries. The output of the saved ONNX model (0.00002071) is very close to, but not exactly the same as, the output from the source PyTorch model (0.00002068), due to floating point rounding differences.
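A practical way to sanity-check such tiny discrepancies is a tolerance-based comparison rather than exact equality. Here is a minimal sketch using the two output values from the demo:

```python
import math

pytorch_out = 0.00002068  # output from the source PyTorch model
onnx_out = 0.00002071     # output from the saved ONNX model

# Exact equality fails, but the values agree within a small tolerance.
print(pytorch_out == onnx_out)                            # False
print(math.isclose(pytorch_out, onnx_out, abs_tol=1e-6))  # True
```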

My hunch is that neural networks are just too complex for a universal format to be completely feasible. But I’ve been wrong before on issues like this.


All water clocks have round-off error. Left: A recreation of a clepsydra (“water thief”) designed by Greek inventor Ctesibius in the 3rd century BC. His clock was the most accurate for over 1,800 years, until the invention of the pendulum clock. Center: A design by artist Bernard Gitton. The clock is functional but is mostly a work of art. Right: Simple water clocks are essentially hourglasses with water instead of sand.

This entry was posted in PyTorch.

1 Response to Saving and Using a PyTorch Model in ONNX Format

  1. Thorsten Kleppe says:

    Hello James,

    I’ve created a full production demo in C# for you that runs and tests a well-trained neural network on the MNIST training and test data.

    The cool thing is, it needs only the information about the neural network, like 784,16,16,10, and the related trained weights. You can examine different networks and their weights in the saved txt files.

    I call it level 0 = neural network + trained weights

    One problem when saving the network was float precision.
    My solution was this function:

    void NeuralNetworkSave(string fullPath)
    {
        string[] weightString = new string[weightLen + 1];
        weightString[0] = string.Join(",", u); // neural network architecture on the first line
        for (int i = 1; i < weightLen + 1; i++)
            weightString[i] = ((decimal)((double)weight[i - 1])).ToString(); // cast to decimal for precision
        File.WriteAllLines(fullPath, weightString);
    }

    I have so many questions: how many layers can a neural network train, and how many hidden neurons does a neural network need to work with ten output classes? You can find some cool examples of trained neural networks inside that show what I found out.
    https://github.com/grensen/goodgameExportDemo

    Hope that’s a bam for you. 🙂
