Computing and Displaying a Fancy Confusion Matrix for a PyTorch Multi-Class Classifier

One morning, my brain was tired. I had just finished a mini-project that had taken many days of intense concentration. I decided to give my brain a little rest by looking at a problem that isn’t difficult but would keep my programming skills sharp. Specifically, I put together a pair of functions to compute and display a fancy confusion matrix for a PyTorch multi-class classifier.

By fancy, I mean a confusion matrix that has counts of correct and incorrect classifications, but which also has accuracy values. For example, for a PyTorch model with three classes, a fancy confusion matrix might look like:

actual     0:  6  4  1  =   0.5455
actual     1:  1 13  0  =   0.9286
actual     2:  2  2 11  =   0.7333
------------
predicted      0  1  2  =   0.7500

For class 0 items, 6 were correctly predicted, 4 were incorrectly predicted as class 1, and 1 was incorrectly predicted as class 2. So the accuracy for class 0 items is 6/11 = 0.5455.

The overall accuracy is (6 + 13 + 11) / 40 = 30/40 = 0.7500.
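These calculations can be checked with a few lines of NumPy, using the counts from the example matrix above:

```python
import numpy as np

# counts from the example 3-class confusion matrix above
cm = np.array([[ 6,  4,  1],
               [ 1, 13,  0],
               [ 2,  2, 11]], dtype=np.int64)

overall = np.sum(np.diag(cm)) / np.sum(cm)   # (6+13+11) / 40
by_class = np.diag(cm) / cm.sum(axis=1)      # row-wise accuracies

print("%0.4f" % overall)            # 0.7500
print(np.round(by_class, 4))        # [0.5455 0.9286 0.7333]
```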

I wrote a confusion_matrix_accs() function that iterates through a PyTorch Dataset and computes (1) a confusion matrix of counts, (2) the overall accuracy, and (3) a vector of accuracies by class, then returns those three objects as a Tuple:

def confusion_matrix_accs(model, ds, n_classes):
  # compute confusion matrix, overall acc, and acc-by-class 
  if n_classes <= 2:  # less-than-or-equal
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # actual class 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # logits form
    pred_class = T.argmax(oupt)  # 0,1,2
    cm[Y.item()][pred_class.item()] += 1  # plain ints; tensor indexing makes a copy

  # overall acc
  N = np.sum(cm)  # total count
  n_correct = np.sum(np.diag(cm))  # using diag()
  overall = n_correct / N

  # acc by-class
  accs = np.zeros(n_classes, dtype=np.float32)  # by class
  row_sums = cm.sum(axis=1)
  for i in range(n_classes):
    accs[i] = cm[i][i] / row_sums[i]

  return (cm, overall, accs)  # as a Tuple
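One subtle point about the update line: indexing a NumPy array with a tensor (or any array-like index) triggers advanced indexing, which returns a copy, so an increment chained through such an index is silently lost. Converting to plain Python ints via .item() avoids the problem. A minimal sketch of the pitfall, using a NumPy index array to stand in for the 1-D tensor Y:

```python
import numpy as np

cm = np.zeros((3, 3), dtype=np.int64)

idx = np.array([1])   # array index, like a 1-D tensor
cm[idx][0, 2] += 1    # advanced indexing: cm[idx] is a copy
print(cm[1][2])       # 0 -- the increment was lost

cm[1][2] += 1         # basic int indexing: cm[1] is a view
print(cm[1][2])       # 1 -- cm really was updated
```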

Next, I wrote a show_confusion_accs() function that accepts the return Tuple from the confusion_matrix_accs() function and displays the results:

def show_confusion_accs(cm_accs):
  # display confusion matrix with accuracies
  cm = cm_accs[0]        # the confusion matrix of counts
  overall = cm_accs[1]   # overall accuracy
  accs = cm_accs[2]      # accuracies by class

  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("  = %8.4f" % accs[i], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("  = %8.4f" % overall, end="")
  print("")

I used one of my standard synthetic datasets for my demo. The data looks like:

 1   0.24   1  0  0   0.2950   2
-1   0.39   0  0  1   0.5120   1
 1   0.63   0  1  0   0.7580   0
-1   0.36   1  0  0   0.4450   1
. . .

The fields are sex (male = -1, female = +1), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by $100,000), and political leaning (conservative = 0, moderate = 1, liberal = 2). The goal is to predict political leaning from sex, age, State, and income. There are 200 training items and 40 test items.
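To make the encoding concrete, here is a sketch of how the first data line above maps to numeric form (the intermediate variable names are just for illustration):

```python
# first data line: female, age 24, Michigan, $29,500, liberal
sex = 1.0                  # male = -1, female = +1
age = 24 / 100.0           # 0.24
state = [1, 0, 0]          # Michigan = 1 0 0
income = 29_500 / 100_000  # 0.2950

x = [sex, age] + state + [income]  # the six predictor values
print(x)   # [1.0, 0.24, 1, 0, 0, 0.295]
# the seventh value, politics = 2 (liberal), is the target
```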

There’s no real moral to this blog post, but as I was coding I remembered (as I’ve known for a long time) that I don’t get much satisfaction from UI programming; I get much more satisfaction from creating algorithms that produce the information for UI functions to display.



The book “Alice’s Adventures in Wonderland” (aka “Alice in Wonderland”) was written in 1865 by Lewis Carroll, a mathematics professor at Oxford University. Alice falls through a rabbit hole into a fantasy world of confusion and chaos. I like the book and various movie adaptations of Alice, but all of them make me feel vaguely uncomfortable — I like order and logic.

The Alice theme is very popular with artists. Left: by artist “Zolaida”. Right: by artist “Renata-S-Art”.


Demo code. Replace “lt” (less-than), “gt”, etc. with Boolean operator symbols. The data can be found at: https://jamesmccaffreyblog.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.

# people_politics_confusion.py
# predict politics type from sex, age, state, income

# fancy confusion matrix with accuracies

# PyTorch 2.0.0-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]     # 1-D

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    return z

# -----------------------------------------------------------

def accuracy_quick(model, dataset):
  # assumes model.eval()
  X = dataset[0:len(dataset)][0]
  # Y = T.flatten(dataset[0:len(dataset)][1])
  Y = dataset[0:len(dataset)][1]
  with T.no_grad():
    oupt = model(X)  #  [40,3]  logits

  # (_, arg_maxs) = T.max(oupt, dim=1)
  arg_maxs = T.argmax(oupt, dim=1)  # argmax() is new
  num_correct = T.sum(Y==arg_maxs)
  acc = (num_correct * 1.0 / len(dataset))
  return acc.item()

# -----------------------------------------------------------

def confusion_matrix_accs(model, ds, n_classes):
  # compute confusion matrix, overall acc, and acc-by-class 
  if n_classes <= 2:  # less-than-or-equal
    print("ERROR: n_classes must be 3 or greater ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # actual class 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # logits form
    pred_class = T.argmax(oupt)  # 0,1,2
    cm[Y.item()][pred_class.item()] += 1  # plain ints; tensor indexing makes a copy

  # overall acc
  N = np.sum(cm)  # total count
  n_correct = np.sum(np.diag(cm)) 
  overall = n_correct / N

  # acc by-class
  accs = np.zeros(n_classes, dtype=np.float32)  # by class
  row_sums = cm.sum(axis=1)
  for i in range(n_classes):
    accs[i] = cm[i][i] / row_sums[i]

  return (cm, overall, accs)

# -----------------------------------------------------------

def show_confusion_accs(cm_accs):
  # display confusion matrix with accuracies
  cm = cm_accs[0]        # the confusion matrix of counts
  overall = cm_accs[1]   # overall accuracy
  accs = cm_accs[2]           # accuracies by class

  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("  = %8.4f" % accs[i], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("  = %8.4f" % overall, end="")
  print("")

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin People predict politics type ")
  T.manual_seed(1)
  np.random.seed(1)
  
  # 1. create DataLoader objects
  print("\nCreating People Datasets ")

  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

  # 2. create network
  print("\nCreating 6-(10-10)-3 neural network ")
  net = Net().to(device)
  net.train()

# -----------------------------------------------------------

  # 3. train model
  max_epochs = 1000
  ep_log_interval = 200
  lrn_rate = 0.01

  loss_func = T.nn.NLLLoss()  # assumes log_softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    # T.manual_seed(epoch+1)  # checkpoint reproducibility
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # inputs
      Y = batch[1]  # correct class/label/politics

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %5d  |  loss = %10.4f" % \
        (epoch, epoch_loss))

  print("Training done ")

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy_quick(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy_quick(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  print("\nFancy confusion matrix: ")
  cm_accs = confusion_matrix_accs(net, test_ds, n_classes=3)
  show_confusion_accs(cm_accs)

  # 5. make a prediction
  print("\nPredicting politics for M  30  oklahoma  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    logits = net(X)  # do not sum to 1.0
  probs = T.exp(logits)  # sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

  # 6. TODO: save model (state_dict approach)
 
  print("\nEnd People predict politics demo")

if __name__ == "__main__":
  main()