“Multi-Class Classification Using New PyTorch Best Practices, Part 2: Training, Accuracy, Predictions” in Visual Studio Magazine

I wrote an article titled “Multi-Class Classification Using New PyTorch Best Practices, Part 2: Training, Accuracy, Predictions” in the September 2022 edition of Microsoft Visual Studio Magazine. See https://visualstudiomagazine.com/articles/2022/09/12/multi-class-pytorch-2.aspx.

The article is the second in a two-part series that explains how to create a PyTorch multi-class classifier. The article's demo program predicts the political leaning (conservative, moderate, liberal) of a person. The first article in the series explained how to prepare the training and test data, and how to define the neural network classifier. The second article explains how to train the network, compute the accuracy of the trained network, use the network to make predictions, and save the network for use by other programs.

The demo begins by loading a 200-item file of training data and a 40-item set of test data. Each tab-delimited line represents a person. The fields are sex, age, state of residence (Michigan, Nebraska or Oklahoma), annual income and politics type (0 = conservative, 1 = moderate, 2 = liberal). The goal is to predict politics type from sex, age, state and income.
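Before data like this can be fed to a neural network, the non-numeric fields have to be converted to numbers. The snippet below is a minimal sketch of one way to do that in plain Python. The raw line format, the `encode_person` helper, and the specific encoding choices (male = -1, female = +1, age divided by 100, one-hot state, income divided by 100,000) are assumptions made for illustration and are not taken from the text above.

```python
# Assumed raw format (hypothetical): sex, age, state, income, politics, tab-delimited.
STATES = ["michigan", "nebraska", "oklahoma"]  # order defines the one-hot positions

def encode_person(line):
    """Convert one tab-delimited line into (feature list, class label)."""
    sex, age, state, income, politics = line.rstrip("\n").split("\t")
    features = [-1.0 if sex == "M" else 1.0]           # assumed: male = -1, female = +1
    features.append(int(age) / 100.0)                   # assumed: scale age to roughly 0..1
    features.extend(1.0 if state.lower() == s else 0.0  # one-hot encode the state
                    for s in STATES)
    features.append(float(income) / 100_000.0)          # assumed: scale income
    return features, int(politics)                      # label: 0, 1 or 2

features, label = encode_person("M\t30\tOklahoma\t50000\t0")
print(features, label)
```

Each person becomes six numeric predictor values plus an integer class label, which is the shape a multi-class classifier typically expects.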

After 1,000 training epochs, the demo program computes the accuracy of the trained model on the training data as 81.50 percent (163 out of 200 correct). The model accuracy on the test data is 75.00 percent (30 out of 40 correct).
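Classification accuracy is just the number of correct predictions divided by the total number of items. Here is a minimal sketch in plain Python, assuming the model's outputs are available as lists of per-class scores. The `accuracy` function name and the toy data are hypothetical; the last two lines simply reproduce the arithmetic for the figures quoted above.

```python
def accuracy(outputs, targets):
    """outputs: list of per-class score lists; targets: list of int class labels."""
    n_correct = 0
    for scores, target in zip(outputs, targets):
        # The predicted class is the index of the largest score.
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        if predicted == target:
            n_correct += 1
    return n_correct / len(targets)

# Toy data (made up): four items, three classes.
toy_outputs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.6, 0.3, 0.1]]
toy_targets = [0, 1, 2, 1]
toy_acc = accuracy(toy_outputs, toy_targets)
print(toy_acc)

# Arithmetic for the accuracy values reported by the demo.
print(f"{163 / 200:.2%}")  # training data
print(f"{30 / 40:.2%}")    # test data
```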

After evaluating the trained network, the demo predicts the politics type for a person who is male, 30 years old, from Oklahoma and who makes $50,000 annually. The prediction is [0.6905, 0.3049, 0.0047]. These values are pseudo-probabilities. The largest value (0.6905) is at index [0] so the prediction is class 0 = conservative.
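Turning the output vector into a class label is an argmax operation. The sketch below applies it to the values quoted above. The `softmax` helper and the example raw outputs are assumptions added to illustrate one common way pseudo-probabilities that sum to 1 can be produced from raw network outputs; the text above does not say how the demo computes them.

```python
import math

LABELS = ["conservative", "moderate", "liberal"]  # index = class ID

def argmax(values):
    """Return the index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])

def softmax(raw):
    """Scale raw scores to positive values that sum to 1 (numerically stable form)."""
    m = max(raw)
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]

probs = [0.6905, 0.3049, 0.0047]  # prediction reported by the demo
idx = argmax(probs)
print(idx, LABELS[idx])

# Hypothetical raw outputs, shown only to illustrate softmax.
pseudo = softmax([2.0, 1.0, -3.0])
print(sum(pseudo))
```

Note that the three reported values sum to 1.0001 rather than exactly 1 because each has been rounded to four decimals.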

The demo concludes by saving the trained model to a file so that it can be used without having to retrain the network from scratch. There are two different ways to save a PyTorch model: save the entire model object, or save just the model's state (its weights and biases). The demo uses the save-state approach.
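Here is a minimal sketch of saving and restoring a model's state with `torch.save()`, `state_dict()` and `load_state_dict()`. The `Net` class is a hypothetical stand-in (6 inputs, 10 hidden nodes, 3 outputs) and the file name is made up; neither is taken from the demo program.

```python
import os
import tempfile
import torch

class Net(torch.nn.Module):
    """Hypothetical stand-in network, not the demo's architecture."""
    def __init__(self):
        super().__init__()
        self.hid = torch.nn.Linear(6, 10)
        self.out = torch.nn.Linear(10, 3)

    def forward(self, x):
        return self.out(torch.tanh(self.hid(x)))

net = Net()  # in a real program this would be the trained network
path = os.path.join(tempfile.gettempdir(), "politics_model.pt")  # made-up file name

# Save only the weights and biases, not the whole model object.
torch.save(net.state_dict(), path)

# Later, possibly in another program: rebuild the same architecture,
# then load the saved weights into it.
restored = Net()
restored.load_state_dict(torch.load(path))
restored.eval()  # switch to evaluation mode before making predictions

x = torch.tensor([[-1.0, 0.30, 0.0, 0.0, 1.0, 0.50]])
with torch.no_grad():
    same = torch.equal(net(x), restored(x))
print(same)
```

With this approach, the program that loads the file needs access to the class definition, because only the numeric state is stored. The alternative, `torch.save(net, path)`, pickles the entire model object.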

The three basic types of PyTorch systems are multi-class classification, binary classification, and regression. Multi-class classification is used when the variable to predict has three or more possible values. When the variable to predict has just two possible values, the problem is called binary classification. Binary classification uses techniques that are different from multi-class classification. Regression problems are ones where the goal is to predict a single numeric value, such as annual income.



Gender and politics.

