Computing and Displaying a Confusion Matrix for a PyTorch Neural Network Binary Classifier

After training a PyTorch binary classifier, it’s important to evaluate the accuracy of the trained model. Simple classification accuracy is OK but in many scenarios you want a so-called confusion matrix that gives details of the number of correct and wrong predictions for each of the two target classes. You also want precision, recall, and F1 metrics.

For example, suppose you’re predicting the sex (0 = male, 1 = female) of a person based on their age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by $100,000), and political leaning (conservative = 100, moderate = 010, liberal = 001). An example of a formatted confusion matrix and metrics computed from the matrix might look like:

Computing confusion matrix

actual     0: 21  5
actual     1:  1 13
------------
predicted      0  1

Computing metrics from confusion
acc = 0.8500 pre = 0.7222 rec = 0.9286 f1 = 0.8125
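
As a quick sanity check, the four metrics can be reproduced by hand from the four counts in the sample matrix. This is only a sketch of the arithmetic, with class 1 (female) treated as positive; the variable names are mine.

```python
# counts read off the sample confusion matrix above
tn, fp = 21, 5   # actual 0: predicted 0, predicted 1
fn, tp = 1, 13   # actual 1: predicted 0, predicted 1

n = tn + fp + fn + tp                # 40 test items
acc = (tp + tn) / n                  # fraction on the diagonal
pre = tp / (tp + fp)                 # of predicted 1s, how many are right
rec = tp / (tp + fn)                 # of actual 1s, how many were found
f1 = 2.0 * pre * rec / (pre + rec)   # harmonic mean of pre and rec

print("acc = %0.4f pre = %0.4f rec = %0.4f f1 = %0.4f" % \
  (acc, pre, rec, f1))
# acc = 0.8500 pre = 0.7222 rec = 0.9286 f1 = 0.8125
```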

Here’s my function to compute a raw confusion matrix for a binary classifier:

def confusion_matrix_bin(model, ds, n_classes):
  if n_classes != 2:
    print("ERROR n_classes must be 2 ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = int(ds[i][1].item())    # actual class 0.0 or 1.0 as an int index
    with T.no_grad():
      oupt = model(X)  # probability in [0.0, 1.0] (sigmoid output)
    if oupt.item() < 0.5: pred_class = 0
    else: pred_class = 1

    cm[Y][pred_class] += 1
  return cm

The function accepts a trained PyTorch classifier and a PyTorch Dataset object whose items are returned as a Tuple, with the predictors at [0] and the target label at [1]. The model is assumed to apply sigmoid() to its output node, so the computed value is a probability that gets compared against 0.5. The n_classes could be determined programmatically but it’s easier to pass that value in as a parameter.
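
The counting logic doesn't depend on PyTorch at all. Here is a minimal sketch of the same idea, assuming you already have a list of actual labels and a list of predicted probabilities. The function name confusion_from_probs and the small data lists are made up for illustration.

```python
def confusion_from_probs(actuals, probs, thresh=0.5):
  # actuals: list of 0/1 ints; probs: list of floats in [0.0, 1.0]
  # returns a 2x2 nested list, rows = actual, cols = predicted
  cm = [[0, 0], [0, 0]]
  for (y, p) in zip(actuals, probs):
    pred_class = 0 if p < thresh else 1
    cm[y][pred_class] += 1
  return cm

# hypothetical labels and model outputs
actuals = [0, 0, 1, 1, 1, 0]
probs = [0.10, 0.70, 0.80, 0.40, 0.95, 0.30]
cm = confusion_from_probs(actuals, probs)
print(cm)  # [[2, 1], [1, 2]]
```

The thresh parameter plays the same role as the hard-coded 0.5 in confusion_matrix_bin().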

Note: A function to compute a confusion matrix for a multi-class classifier, where there are three or more possible outcomes, uses slightly different code. See https://jamesmccaffreyblog.com/2023/03/15/computing-and-displaying-a-confusion-matrix-for-a-pytorch-neural-network-classifier/

The raw confusion matrix is difficult to interpret so I wrote a function to format the matrix by adding some labels:

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")
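
To see how the width logic plays out, here is a sketch of a variant that builds the display as a string instead of printing it, which makes it easy to inspect. The name confusion_to_string is hypothetical, and the matrix is a plain nested list holding the sample counts from the top of this post.

```python
def confusion_to_string(cm):
  # same layout as show_confusion() but returns a string
  dim = len(cm)
  mx = max(max(row) for row in cm)  # largest count in cm
  wid = len(str(mx)) + 1            # width to print
  fmt = "%" + str(wid) + "d"        # like "%3d"
  lines = []
  for i in range(dim):
    cells = "".join(fmt % cm[i][j] for j in range(dim))
    lines.append("actual   " + ("%3d:" % i) + cells)
  lines.append("------------")
  lines.append("predicted    " + "".join(fmt % j for j in range(dim)))
  return "\n".join(lines)

s = confusion_to_string([[21, 5], [1, 13]])
print(s)
```

The largest count is 21, which has two digits, so every cell is printed in a field of width three.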

If you have a binary confusion matrix, you can compute precision, recall, and F1 score from it. The true negatives are at [0][0]. The true positives are at [1][1]. The false positives (“incorrectly predicted as positive/1”) are at [0][1]. The false negatives (“incorrectly predicted as negative/0”) are at [1][0].

Note that in a binary classification scenario, which outcome you specify is positive and which outcome is negative is arbitrary, so the meanings/values of precision and recall are arbitrary too. This is one reason why the F1 score, the harmonic mean of precision and recall, is often used.
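
To make the point about the arbitrary positive class concrete, here is a small sketch that computes precision, recall, and F1 twice from the sample matrix, once with class 1 as positive and once with class 0 as positive. The helper name pre_rec_f1 and its pos parameter are my own, for illustration only.

```python
def pre_rec_f1(cm, pos):
  # cm: 2x2 nested list, rows = actual, cols = predicted
  # pos: which class index (0 or 1) is treated as positive
  neg = 1 - pos
  tp = cm[pos][pos]
  fp = cm[neg][pos]  # actual negative, predicted positive
  fn = cm[pos][neg]  # actual positive, predicted negative
  pre = tp / (tp + fp)
  rec = tp / (tp + fn)
  f1 = 2.0 * pre * rec / (pre + rec)
  return (pre, rec, f1)

cm = [[21, 5], [1, 13]]
print("positive = 1: pre = %0.4f rec = %0.4f f1 = %0.4f" % \
  pre_rec_f1(cm, 1))
# positive = 1: pre = 0.7222 rec = 0.9286 f1 = 0.8125
print("positive = 0: pre = %0.4f rec = %0.4f f1 = %0.4f" % \
  pre_rec_f1(cm, 0))
# positive = 0: pre = 0.9545 rec = 0.8077 f1 = 0.8750
```

Same model, same predictions, different numbers. F1 shifts too, but it folds precision and recall into a single value, so there is only one number to report.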

Here is a function that computes accuracy, precision, recall and F1 from a raw binary confusion matrix. It assumes a particular geometry of the matrix: rows are actual classes, columns are predicted classes, and class 1 is the positive class.

def metrics_from_confusion_bin(cm):
  # return (accuracy, precision, recall, F1)
  N = 0  # total count
  dim = len(cm)
  for i in range(dim):
    for j in range(dim):
      N += cm[i][j]
  n_correct = 0
  for i in range(dim):
    n_correct += cm[i][i]  # on the diagonal
  acc = n_correct / N

#        pred 0  pred 1
# act 0    tn      fp
# act 1    fn      tp

  tp = cm[1][1]
  tn = cm[0][0]
  fp = cm[0][1]  # falsely predicted as positive
  fn = cm[1][0]  # falsely predicted as negative
  pre = tp / (tp + fp)
  rec = tp / (tp + fn)
  f1 = 1.0 / ( ((1.0 / pre) + (1.0 / rec)) / 2.0 )
  
  return (acc, pre, rec, f1)
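
One thing the function above doesn't handle is a zero denominator, which happens for example when the model never predicts class 1. Here is a sketch of a guarded version. Returning 0.0 in the undefined cases is just a convention I'm assuming here, and the name safe_metrics_bin is hypothetical.

```python
def safe_metrics_bin(cm):
  # cm: 2x2, rows = actual, cols = predicted, class 1 = positive
  # returns (accuracy, precision, recall, F1); 0.0 when undefined
  tn = int(cm[0][0]); fp = int(cm[0][1])
  fn = int(cm[1][0]); tp = int(cm[1][1])
  n = tn + fp + fn + tp
  acc = (tp + tn) / n if n > 0 else 0.0
  pre = tp / (tp + fp) if (tp + fp) > 0 else 0.0
  rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
  f1 = 2.0 * pre * rec / (pre + rec) if (pre + rec) > 0.0 else 0.0
  return (acc, pre, rec, f1)

# a model that predicts class 0 for every item
print(safe_metrics_bin([[26, 0], [14, 0]]))
# (0.65, 0.0, 0.0, 0.0)
```

The int() conversions mean the function works the same way with a NumPy matrix or a plain nested list.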

Good fun. Demo code below. If any comparison operators in the listing show up as “lt”, “gt”, “lte”, “gte”, replace them with the corresponding Boolean operator symbols (my lame blog editor chokes on symbols). The training and test data can be found at https://jamesmccaffreyblog.com/2022/09/23/binary-classification-using-pytorch-1-12-1-on-windows-10-11/.



There are all kinds of social science research about the many biological and behavioral differences between men and women. For example, women will often tilt their heads down and look up. According to the research, this pose is appealing and shows vulnerability. Men do not ever use such a posture, at least none of the guys I know. Facts like this are interesting but they’re not terribly useful for prediction.


# people_gender.py
# binary classification

# confusion matrix and metrics demo

# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

class PeopleDataset(T.utils.data.Dataset):
  # sex age   state    income  politics
  #  0  0.27  0  1  0  0.7610  0 0 1
  #  1  0.19  0  0  1  0.6550  1 0 0
  # sex: 0 = male, 1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_data = np.loadtxt(src_file, usecols=range(0,9),
      delimiter=",", comments="#", dtype=np.float32) 

    self.x_data = T.tensor(all_data[:,1:9],
      dtype=T.float32).to(device)
    self.y_data = T.tensor(all_data[:,0],
      dtype=T.float32).to(device)  # float32 required

    self.y_data = self.y_data.reshape(-1,1)  # 2-D required

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    feats = self.x_data[idx,:]  # idx row, all 8 cols
    sex = self.y_data[idx,:]    # idx row, the only col
    return feats, sex  # as a Tuple

# ---------------------------------------------------------

def metrics(model, ds, thresh=0.5):
  # compute accuracy, precision, recall, F1 directly
  # note: N = total number of items = TP + FP + TN + FN
  # accuracy  = (TP + TN)  / N
  # precision = TP / (TP + FP)
  # recall    = TP / (TP + FN)
  # F1        = 2 / [(1 / precision) + (1 / recall)]

  tp = 0; tn = 0; fp = 0; fn = 0
  for i in range(len(ds)):
    inpts = ds[i][0]         # Tuple style: predictors at [0]
    target = ds[i][1]        # float32  [0.0] or [1.0]
    target = target.type(T.int64)  # make it an int
    with T.no_grad():
      p = model(inpts)       # between 0.0 and 1.0

    # should really avoid 'target == 1.0'
    if target == 1 and p >= thresh:    # TP
      tp += 1
    elif target == 1 and p < thresh:   # FN
      fn += 1
    elif target == 0 and p < thresh:   # TN
      tn += 1
    elif target == 0 and p >= thresh:  # FP
      fp += 1

  N = tp + fp + tn + fn
  if N != len(ds):
    print("FATAL LOGIC ERROR in metrics()")

  accuracy = (tp + tn) / (N * 1.0)
  precision = (1.0 * tp) / (tp + fp)
  recall = (1.0 * tp) / (tp + fn)
  f1 = 2.0 / ((1.0 / precision) + (1.0 / recall))
  return (accuracy, precision, recall, f1)  # as a Tuple

# -----------------------------------------------------------

def confusion_matrix_bin(model, ds, n_classes):
  if n_classes != 2:
    print("ERROR n_classes must be 2 ")
    return None

  cm = np.zeros((n_classes,n_classes), dtype=np.int64)
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = int(ds[i][1].item())    # actual class 0.0 or 1.0 as an int index
    with T.no_grad():
      oupt = model(X)  # probability in [0.0, 1.0] (sigmoid output)
    if oupt.item() < 0.5: pred_class = 0
    else: pred_class = 1

    cm[Y][pred_class] += 1
  return cm

# -----------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

# -----------------------------------------------------------  

def metrics_from_confusion_bin(cm):
  # return (accuracy, precision, recall, F1)
  N = 0  # total count
  dim = len(cm)
  for i in range(dim):
    for j in range(dim):
      N += cm[i][j]
  n_correct = 0
  for i in range(dim):
    n_correct += cm[i][i]  # on the diagonal
  acc = n_correct / N

#        pred 0  pred 1
# act 0    tn      fp
# act 1    fn      tp

  tp = cm[1][1]
  tn = cm[0][0]
  fp = cm[0][1]  # falsely predicted as positive
  fn = cm[1][0]  # falsely predicted as negative
  pre = tp / (tp + fp)
  rec = tp / (tp + fn)
  f1 = 1.0 / ( ((1.0 / pre) + (1.0 / rec)) / 2.0 )
  
  return (acc, pre, rec, f1)

# ----------------------------------------------------------- 

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(8, 10)  # 8-(10-10)-1
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 1)

    T.nn.init.xavier_uniform_(self.hid1.weight) 
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight) 
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight) 
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.sigmoid(self.oupt(z))  # for BCELoss()
    return z

# ----------------------------------------------------------

def main():
  # 0. get started
  print("\nPeople gender using PyTorch ")
  T.manual_seed(1)
  np.random.seed(1)

  # 1. create Dataset and DataLoader objects
  print("\nCreating People train and test Datasets ")

  train_file = ".\\Data\\people_train.txt"
  test_file = ".\\Data\\people_test.txt"

  train_ds = PeopleDataset(train_file)  # 200 rows
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create neural network
  print("\nCreating 8-(10-10)-1 binary NN classifier \n")
  net = Net().to(device)

  # 3. train network
  net.train()  # set training mode
  lrn_rate = 0.01
  loss_func = T.nn.BCELoss()  # binary cross entropy
  optimizer = T.optim.SGD(net.parameters(),
    lr=lrn_rate)
  max_epochs = 500
  ep_log_interval = 100

  print("Loss function: " + str(loss_func))
  print("Optimizer: " + str(optimizer.__class__.__name__))
  print("Learn rate: " + "%0.3f" % lrn_rate)
  print("Batch size: " + str(bat_size))
  print("Max epochs: " + str(max_epochs))

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    epoch_loss = 0.0            # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]             # [bs,8]  inputs
      Y = batch[1]             # [bs,1]  targets
      oupt = net(X)            # [bs,1]  computeds 

      loss_val = loss_func(oupt, Y)   # a tensor
      epoch_loss += loss_val.item()  # accumulate
      optimizer.zero_grad() # reset all gradients
      loss_val.backward()   # compute new gradients
      optimizer.step()      # update all weights

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %8.4f" % \
        (epoch, epoch_loss))
  print("Done ")

# ----------------------------------------------------------

  # 4. evaluate model
  net.eval()
  metrics_train = metrics(net, train_ds, thresh=0.5)
  print("\nMetrics for train data: ")
  print("accuracy  = %0.4f " % metrics_train[0])
  print("precision = %0.4f " % metrics_train[1])
  print("recall    = %0.4f " % metrics_train[2])
  print("F1        = %0.4f " % metrics_train[3])

  metrics_test = metrics(net, test_ds, thresh=0.5)
  print("\nMetrics for test data: ")
  print("accuracy  = %0.4f " % metrics_test[0])
  print("precision = %0.4f " % metrics_test[1])
  print("recall    = %0.4f " % metrics_test[2])
  print("F1        = %0.4f " % metrics_test[3])

  print("\nComputing confusion matrix ")
  cm = confusion_matrix_bin(net, test_ds, n_classes=2)
  # print(cm) # raw matrix
  show_confusion(cm)

  print("\nComputing metrics from confusion ")
  (acc, pre, rec, f1) = metrics_from_confusion_bin(cm)
  print("acc = %0.4f pre = %0.4f rec = %0.4f f1 = %0.4f " % \
    (acc, pre, rec, f1))

  # 5. save model
  print("\nSaving trained model state_dict ")
  # path = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), path)

  # 6. make a prediction 
  print("\nSetting age = 30  Oklahoma  $40,000  moderate")
  inpt = np.array([[0.30, 0,0,1, 0.40, 0,1,0]],
    dtype=np.float32)
  inpt = T.tensor(inpt, dtype=T.float32).to(device)

  net.eval()
  with T.no_grad():
    oupt = net(inpt)    # a Tensor
  pred_prob = oupt.item()  # scalar, [0.0, 1.0]
  print("Computed output: ", end="")
  print("%0.4f" % pred_prob)

  if pred_prob < 0.5:
    print("Prediction = male")
  else:
    print("Prediction = female")

  print("\nEnd People binary demo ")

if __name__== "__main__":
  main()