Computing Calibration Error for a PyTorch Multi-Class Classifier

Suppose you have a PyTorch multi-class classification model where the goal is to predict the political leaning (conservative = 0, moderate = 1, liberal = 2) of a person based on predictor variables such as sex, age, state, and income.

The output of the model is a vector of three log-softmax values (loosely called logits here), where the index of the largest value indicates the predicted class. Because the last step of the network is log_softmax(), applying the exp() function to the outputs gives a vector of three pseudo-probabilities that sum to 1.0, and these can loosely be interpreted as the probabilities of each class.
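
As a quick sanity check of that claim, here is a minimal sketch in plain Python (no PyTorch), assuming the network ends with log-softmax as the demo program does. The helper name log_softmax and the raw scores are made up for illustration.

```python
import math

def log_softmax(scores):
  # hypothetical helper: log of softmax, computed in a stable way
  m = max(scores)
  log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
  return [s - log_sum for s in scores]

raw = [1.5, 3.0, 0.5]                  # made-up raw output scores
log_pps = log_softmax(raw)             # what the network would emit
pps = [math.exp(v) for v in log_pps]   # pseudo-probabilities

print(["%0.4f" % v for v in log_pps])  # all values are negative
print(["%0.4f" % p for p in pps])
print("sum of PPs = %0.4f" % sum(pps))
print("predicted class =", pps.index(max(pps)))
```

The index of the largest log value and the index of the largest PP are always the same, because exp() is an increasing function.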

Output pseudo-probability values are sometimes called confidence values or just probabilities. I’ll use the term pseudo-probabilities (PPs).

A machine learning multi-class classification model is well-calibrated if the largest output pseudo-probabilities closely reflect the model accuracies. In other words, if the output pseudo-probabilities for a person are (0.10, 0.75, 0.15), strongly indicating class [1] = moderate, then you’d like there to be roughly a 75% chance that the model is correct. Or if the output pseudo-probabilities are (0.3332, 0.3332, 0.3336), weakly indicating class [2] = liberal, then you’d like there to be roughly a 33.36% chance that the model is correct.

The first step in dealing with model calibration is measuring it, using a calibration error (CE) metric. Small values of CE indicate a model that is well-calibrated; larger values of CE indicate a model that is less well-calibrated. There are several ways to measure multi-class classification model calibration. Here I show one version that I use, along with a PyTorch implementation. To the best of my knowledge, it’s the simplest possible approach.

First, you create 10 bins for the PP values: [0.0 to 0.1), [0.1 to 0.2), . . . [0.9 to 1.0]. The number of bins is arbitrary but 10 works well in practice.
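
The binning rule can be sketched as a small function. This is a minimal sketch, not the code used in the demo program; the name bin_index and the n_bins parameter are my own for illustration. Every bin is closed on the left and open on the right, except the last one, which also includes 1.0.

```python
def bin_index(pp, n_bins=10):
  # hypothetical helper: map a PP in [0.0, 1.0] to a bin 0..n_bins-1
  if pp < 0.0 or pp > 1.0:
    raise ValueError("pp must be in [0.0, 1.0]")
  for b in range(n_bins - 1):
    if pp < (b + 1) / n_bins:  # upper edge of bin b
      return b
  return n_bins - 1            # last bin includes 1.0

for pp in [0.3336, 0.44, 0.75, 0.90, 1.00]:
  print("PP = %0.4f  bin = [%d]" % (pp, bin_index(pp)))
```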

Now suppose that you have 100 training items. Each item will generate a largest pseudo-probability that determines the predicted class. And suppose that 6 of the 100 items emit a largest PP between 0.4 and 0.5, so they fall into bin [4]. Each of those items is either a correct prediction or not, like so:

Item   PPs                 PP        Correct?
=============================================
[65]  (0.44, 0.40, 0.16)   0.44      correct
[ 8]  (0.20, 0.32, 0.48)   0.48      correct
[42]  (0.32, 0.42, 0.26)   0.42      wrong
[90]  (0.28, 0.44, 0.28)   0.44      wrong
[36]  (0.45, 0.39, 0.16)   0.45      correct
[21]  (0.13, 0.40, 0.47)   0.47      correct

The average PP for bin [4] is (0.44 + 0.48 + . . . + 0.47) / 6 = 0.45 and the accuracy for the bin is 4 / 6 = 0.67.
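
The arithmetic for bin [4] can be checked with a few lines of plain Python. This is just a sketch that copies the six (largest PP, correct?) pairs from the table above; the variable names are made up.

```python
# (largest PP, prediction correct?) for the six items in bin [4]
items = [(0.44, True), (0.48, True), (0.42, False),
         (0.44, False), (0.45, True), (0.47, True)]

avg_pp = sum(pp for (pp, _) in items) / len(items)
acc = sum(1 for (_, ok) in items if ok) / len(items)

print("avg PP   = %0.2f" % avg_pp)              # 0.45
print("accuracy = %0.2f" % acc)                 # 0.67
print("abs diff = %0.2f" % abs(avg_pp - acc))   # 0.22
```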

You repeat this process for each of the 10 bins. Then you compute the absolute value of the difference between the average PP and the accuracy for each bin. Suppose you get:

Bin             Count     PP     Accuracy   abs(PP – Acc)
==========================================================
0  0.0 to 0.1     0               0.00       0.00
1  0.1 to 0.2     0               0.00       0.00
2  0.2 to 0.3     0               0.00       0.00
3  0.3 to 0.4    15      0.36     0.73       0.37 
4  0.4 to 0.5     6      0.45     0.67       0.22
5  0.5 to 0.6    18      0.52     0.72       0.20
6  0.6 to 0.7    20      0.66     0.60       0.06
7  0.7 to 0.8     8      0.70     0.75       0.05
8  0.8 to 0.9    22      0.83     0.64       0.19
9  0.9 to 1.0    11      0.91     0.55       0.36
                ---
                100

Notice that for this example, with 3 classes, the largest PP for an item can never be smaller than 1/3 (about 0.3333), and so bins [0], [1], [2] will not have any items.
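
The same reasoning works for any number of classes: with k classes the largest PP is at least 1/k. Here is a minimal sketch, assuming equal-width bins; the function name lowest_reachable_bin is made up for illustration.

```python
def lowest_reachable_bin(n_classes, n_bins=10):
  # hypothetical helper: the largest PP is at least 1 / n_classes,
  # so it can never land in a bin below floor(n_bins / n_classes)
  return n_bins // n_classes

for k in [2, 3, 5, 10]:
  print("classes = %2d  lowest reachable bin = [%d]" %
    (k, lowest_reachable_bin(k)))
```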

And now to compute calibration error for the model, you compute the weighted average of the values in the last column:

CE = [(0 * 0.00) + (0 * 0.00) + (0 * 0.00) + (15 * 0.37) +
      (6 * 0.22) + (18 * 0.20) + (20 * 0.06) + (8 * 0.05) +
      (22 * 0.19) + (11 * 0.36)] / 100
   = 0.2021
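
The weighted average can be verified directly from the table. This sketch copies the Count and abs(PP – Acc) columns; the variable names are my own.

```python
# (count, abs(PP - accuracy)) per bin, copied from the table above
bins = [(0, 0.00), (0, 0.00), (0, 0.00), (15, 0.37), (6, 0.22),
        (18, 0.20), (20, 0.06), (8, 0.05), (22, 0.19), (11, 0.36)]

n_items = sum(count for (count, _) in bins)
ce = sum(count * diff for (count, diff) in bins) / n_items

print("items = %d" % n_items)   # 100
print("CE = %0.4f" % ce)        # 0.2021
```

Empty bins contribute nothing because their weight (count) is zero.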

This is an overall measure of how well the PP values match the model accuracy. This example CE of 0.2021 is moderate, meaning the model isn’t well calibrated, but it’s not terrible.

I implemented this version of calibration error for a PyTorch multi-class neural network. The code and data are below.

Good fun!



A miscalibrated binary classification prediction system is one thing. A miscalibrated military aircraft is something entirely different.

Left: On February 23, 2008, a U.S. Air Force B-2 bomber crashed during take-off at Andersen Air Force Base in Guam. Water condensation from heavy rain entered skin-flush air-data sensors that are used to calculate airspeed and altitude. Moisture! Both pilots ejected safely. To date, it is the only B-2 that has crashed. (This is a simulation image).

Right: On May 15, 2020, a U.S. Air Force F-22 fighter crashed shortly after take-off from Eglin Air Force Base in Florida. The day before the crash, maintenance men washed the plane but accidentally left a piece of tape over a data sensor that is used for in-flight calibration. A piece of tape! The pilot ejected safely. To date, it is one of just six F-22s that have crashed. (This is the plane that crashed in a photo taken earlier, TN 06-4109).


Demo program. If you see “lt” (less than), “gt”, “lte”, or “gte” in the code, replace them with the corresponding comparison operator symbols. My blog editor often chokes on symbols.

# people_politics_calibration_error.py
# predict politics type from sex, age, state, income
# PyTorch 2.2.1-CPU Anaconda3-2023.09  Python 3.11.5
# Windows 10/11 

# compute accuracy by class

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter=",", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]     # 1-D

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # logits form

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def accuracy_quick(model, dataset):
  # assumes model.eval()
  X = dataset[0:len(dataset)][0]
  # Y = T.flatten(dataset[0:len(dataset)][1])
  Y = dataset[0:len(dataset)][1]
  with T.no_grad():
    oupt = model(X)  #  [40,3]  logits

  # (_, arg_maxs) = T.max(oupt, dim=1)
  arg_maxs = T.argmax(oupt, dim=1)  # argmax() is new
  num_correct = T.sum(Y==arg_maxs)
  acc = (num_correct * 1.0 / len(dataset))
  return acc.item()

# -----------------------------------------------------------

def acc_by_class(model, dataset, n_classes):
  n_corrects = np.zeros(n_classes, dtype=np.int64)
  n_wrongs = np.zeros(n_classes, dtype=np.int64)
  counts = np.zeros(n_classes, dtype=np.int64)

  for i in range(len(dataset)):
    X = dataset[i][0].reshape(1,-1)  # make it a batch
    Y = dataset[i][1].reshape(1)  # 0 1 or 2, 1D
    counts[Y.item()] += 1
    with T.no_grad():
      oupt = model(X)  # logits form

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_corrects[Y.item()] += 1
    else:
      n_wrongs[Y.item()] += 1

  print("Counts     : ", end="")
  for c in range(n_classes): 
    print("%8d" % counts[c], end="")
  print("")

  print("Correct    : ", end="")
  for c in range(n_classes): 
    print("%8d" % n_corrects[c], end="")
  print("")

  print("Wrong      : ", end="")
  for c in range(n_classes): 
    print("%8d" % n_wrongs[c], end="")
  print("")

  accuracies = n_corrects / counts
  print("Accuracies : ", end="")
  for c in range(n_classes): 
    print("%8.4f" % accuracies[c], end="")
  print("")
    
# -----------------------------------------------------------

def calibration_error(model, ds):
  counts = np.zeros(10, dtype=np.int64)  # of PPs each bin
  sums = np.zeros(10, dtype=np.float32)  # of PPs each bin
  n_corrects = np.zeros(10, dtype=np.int64)  # for each bin
  n_wrongs = np.zeros(10, dtype=np.int64)  # not needed
  accuracies = np.zeros(10, dtype=np.float32)  # each bin
  avg_pps = np.zeros(10, dtype=np.float32)
  abs_diffs = np.zeros(10, dtype=np.float32)

  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # log-probs like (-1.39, -0.51, -1.90)

    probs = T.exp(oupt) # probs like (0.25, 0.60, 0.15)
    p_max = T.max(probs)  # largest PP like tensor[0.60]
    pp = p_max.item()  # scalar like 0.60

    correct = False
    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      correct = True
 
    if pp >= 0.0 and pp < 0.1: bin = 0
    elif pp >= 0.1 and pp < 0.2: bin = 1
    elif pp >= 0.2 and pp < 0.3: bin = 2
    elif pp >= 0.3 and pp < 0.4: bin = 3
    elif pp >= 0.4 and pp < 0.5: bin = 4
    elif pp >= 0.5 and pp < 0.6: bin = 5
    elif pp >= 0.6 and pp < 0.7: bin = 6
    elif pp >= 0.7 and pp < 0.8: bin = 7
    elif pp >= 0.8 and pp < 0.9: bin = 8
    elif pp >= 0.9 and pp <= 1.0: bin = 9

    counts[bin] += 1
    sums[bin] += pp
    if correct == True: n_corrects[bin] += 1
    elif correct == False: n_wrongs[bin] += 1  # check

  for bin in range(10):
    if counts[bin] == 0: accuracies[bin] = 0.0
    else: accuracies[bin] = n_corrects[bin] / counts[bin]

  for bin in range(10):
    if counts[bin] == 0: avg_pps[bin] = 0.0
    else: avg_pps[bin] = sums[bin] / counts[bin]

  for bin in range(10):
    abs_diffs[bin] = \
      np.abs(avg_pps[bin] - accuracies[bin]) 

  cal_err = 0.0
  for bin in range(10):
    cal_err += counts[bin] * abs_diffs[bin]  # weighted
  cal_err /= len(ds)
  return cal_err

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin People predict politics type ")
  T.manual_seed(1)
  np.random.seed(1)
  
  # 1. create DataLoader objects
  print("\nCreating People Datasets ")

  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

  # 2. create network
  print("\nCreating 6-(10-10)-3 neural network ")
  net = Net().to(device)
  net.train()  # set mode

# -----------------------------------------------------------

  # 3. train model
  max_epochs = 1000
  ep_log_interval = 200
  lrn_rate = 0.01

  loss_func = T.nn.NLLLoss()  # assumes log_softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    # T.manual_seed(epoch+1)  # checkpoint reproducibility
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # inputs
      Y = batch[1]  # correct class/label/politics

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %5d  |  loss = %10.4f" % \
        (epoch, epoch_loss))

  print("Training done ")

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nEvaluating model ")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  print("\nAccuracy on test data by class: ")
  acc_by_class(net, test_ds, 3)

  ce_train = calibration_error(net, train_ds)
  print("\nCalibration error on train data = %0.4f " % ce_train)
  ce_test = calibration_error(net, test_ds)
  print("Calibration error on test data = %0.4f " % ce_test)

  # 5. make a prediction
  print("\nPredicting politics for M  30  oklahoma  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    logits = net(X)  # do not sum to 1.0
  probs = T.exp(logits)  # sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

  # 6. save model (state_dict approach)
  print("\nSaving trained model state")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)

  # saved_model = Net()  # requires class definition
  # saved_model.load_state_dict(T.load(fn))
  # use saved_model to make prediction(s)

  print("\nEnd People predict politics demo")

if __name__ == "__main__":
  main()
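
The demo's calibration_error() function needs a trained model and a Dataset. To spot-check the logic on hand-made outputs, here is a framework-free sketch of the same computation in plain Python. The function name calibration_error_from_outputs and its parameters are hypothetical and not part of the demo program. It assigns bins with int(pp * n_bins) rather than an if-elif chain, which could in principle differ from the demo for a PP sitting exactly on a bin edge.

```python
import math

def calibration_error_from_outputs(log_probs, targets, n_bins=10):
  # hypothetical helper, not part of the demo program
  # log_probs: list of log-softmax vectors, one per item
  # targets: list of correct class indices, one per item
  counts = [0] * n_bins
  sums = [0.0] * n_bins
  n_corrects = [0] * n_bins

  for (lp, y) in zip(log_probs, targets):
    pps = [math.exp(v) for v in lp]         # pseudo-probabilities
    pp = max(pps)                           # largest PP
    b = min(int(pp * n_bins), n_bins - 1)   # 1.0 goes in last bin
    counts[b] += 1
    sums[b] += pp
    if pps.index(pp) == y:                  # predicted class correct?
      n_corrects[b] += 1

  cal_err = 0.0
  for b in range(n_bins):
    if counts[b] == 0: continue             # empty bin adds nothing
    avg_pp = sums[b] / counts[b]
    acc = n_corrects[b] / counts[b]
    cal_err += counts[b] * abs(avg_pp - acc)  # weighted
  return cal_err / len(targets)

# four made-up items: two with largest PP 0.85, two with 0.65
outputs = [[math.log(p) for p in row] for row in
  [(0.85, 0.10, 0.05), (0.85, 0.10, 0.05),
   (0.20, 0.65, 0.15), (0.20, 0.65, 0.15)]]
targets = [0, 1, 1, 1]
ce_check = calibration_error_from_outputs(outputs, targets)
print("CE = %0.4f" % ce_check)
```

For this made-up data, the 0.85 bin has accuracy 0.50 and the 0.65 bin has accuracy 1.00, so both bins are off by 0.35 and the CE is 0.35.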

Training data (200 items):

# people_train.txt
# sex (M=-1, F=1)  age  state (michigan, 
# nebraska, oklahoma) income
# politics (conservative, moderate, liberal)
#
 1, 0.24, 1, 0, 0, 0.2950, 2
-1, 0.39, 0, 0, 1, 0.5120, 1
 1, 0.63, 0, 1, 0, 0.7580, 0
-1, 0.36, 1, 0, 0, 0.4450, 1
 1, 0.27, 0, 1, 0, 0.2860, 2
 1, 0.50, 0, 1, 0, 0.5650, 1
 1, 0.50, 0, 0, 1, 0.5500, 1
-1, 0.19, 0, 0, 1, 0.3270, 0
 1, 0.22, 0, 1, 0, 0.2770, 1
-1, 0.39, 0, 0, 1, 0.4710, 2
 1, 0.34, 1, 0, 0, 0.3940, 1
-1, 0.22, 1, 0, 0, 0.3350, 0
 1, 0.35, 0, 0, 1, 0.3520, 2
-1, 0.33, 0, 1, 0, 0.4640, 1
 1, 0.45, 0, 1, 0, 0.5410, 1
 1, 0.42, 0, 1, 0, 0.5070, 1
-1, 0.33, 0, 1, 0, 0.4680, 1
 1, 0.25, 0, 0, 1, 0.3000, 1
-1, 0.31, 0, 1, 0, 0.4640, 0
 1, 0.27, 1, 0, 0, 0.3250, 2
 1, 0.48, 1, 0, 0, 0.5400, 1
-1, 0.64, 0, 1, 0, 0.7130, 2
 1, 0.61, 0, 1, 0, 0.7240, 0
 1, 0.54, 0, 0, 1, 0.6100, 0
 1, 0.29, 1, 0, 0, 0.3630, 0
 1, 0.50, 0, 0, 1, 0.5500, 1
 1, 0.55, 0, 0, 1, 0.6250, 0
 1, 0.40, 1, 0, 0, 0.5240, 0
 1, 0.22, 1, 0, 0, 0.2360, 2
 1, 0.68, 0, 1, 0, 0.7840, 0
-1, 0.60, 1, 0, 0, 0.7170, 2
-1, 0.34, 0, 0, 1, 0.4650, 1
-1, 0.25, 0, 0, 1, 0.3710, 0
-1, 0.31, 0, 1, 0, 0.4890, 1
 1, 0.43, 0, 0, 1, 0.4800, 1
 1, 0.58, 0, 1, 0, 0.6540, 2
-1, 0.55, 0, 1, 0, 0.6070, 2
-1, 0.43, 0, 1, 0, 0.5110, 1
-1, 0.43, 0, 0, 1, 0.5320, 1
-1, 0.21, 1, 0, 0, 0.3720, 0
 1, 0.55, 0, 0, 1, 0.6460, 0
 1, 0.64, 0, 1, 0, 0.7480, 0
-1, 0.41, 1, 0, 0, 0.5880, 1
 1, 0.64, 0, 0, 1, 0.7270, 0
-1, 0.56, 0, 0, 1, 0.6660, 2
 1, 0.31, 0, 0, 1, 0.3600, 1
-1, 0.65, 0, 0, 1, 0.7010, 2
 1, 0.55, 0, 0, 1, 0.6430, 0
-1, 0.25, 1, 0, 0, 0.4030, 0
 1, 0.46, 0, 0, 1, 0.5100, 1
-1, 0.36, 1, 0, 0, 0.5350, 0
 1, 0.52, 0, 1, 0, 0.5810, 1
 1, 0.61, 0, 0, 1, 0.6790, 0
 1, 0.57, 0, 0, 1, 0.6570, 0
-1, 0.46, 0, 1, 0, 0.5260, 1
-1, 0.62, 1, 0, 0, 0.6680, 2
 1, 0.55, 0, 0, 1, 0.6270, 0
-1, 0.22, 0, 0, 1, 0.2770, 1
-1, 0.50, 1, 0, 0, 0.6290, 0
-1, 0.32, 0, 1, 0, 0.4180, 1
-1, 0.21, 0, 0, 1, 0.3560, 0
 1, 0.44, 0, 1, 0, 0.5200, 1
 1, 0.46, 0, 1, 0, 0.5170, 1
 1, 0.62, 0, 1, 0, 0.6970, 0
 1, 0.57, 0, 1, 0, 0.6640, 0
-1, 0.67, 0, 0, 1, 0.7580, 2
 1, 0.29, 1, 0, 0, 0.3430, 2
 1, 0.53, 1, 0, 0, 0.6010, 0
-1, 0.44, 1, 0, 0, 0.5480, 1
 1, 0.46, 0, 1, 0, 0.5230, 1
-1, 0.20, 0, 1, 0, 0.3010, 1
-1, 0.38, 1, 0, 0, 0.5350, 1
 1, 0.50, 0, 1, 0, 0.5860, 1
 1, 0.33, 0, 1, 0, 0.4250, 1
-1, 0.33, 0, 1, 0, 0.3930, 1
 1, 0.26, 0, 1, 0, 0.4040, 0
 1, 0.58, 1, 0, 0, 0.7070, 0
 1, 0.43, 0, 0, 1, 0.4800, 1
-1, 0.46, 1, 0, 0, 0.6440, 0
 1, 0.60, 1, 0, 0, 0.7170, 0
-1, 0.42, 1, 0, 0, 0.4890, 1
-1, 0.56, 0, 0, 1, 0.5640, 2
-1, 0.62, 0, 1, 0, 0.6630, 2
-1, 0.50, 1, 0, 0, 0.6480, 1
 1, 0.47, 0, 0, 1, 0.5200, 1
-1, 0.67, 0, 1, 0, 0.8040, 2
-1, 0.40, 0, 0, 1, 0.5040, 1
 1, 0.42, 0, 1, 0, 0.4840, 1
 1, 0.64, 1, 0, 0, 0.7200, 0
-1, 0.47, 1, 0, 0, 0.5870, 2
 1, 0.45, 0, 1, 0, 0.5280, 1
-1, 0.25, 0, 0, 1, 0.4090, 0
 1, 0.38, 1, 0, 0, 0.4840, 0
 1, 0.55, 0, 0, 1, 0.6000, 1
-1, 0.44, 1, 0, 0, 0.6060, 1
 1, 0.33, 1, 0, 0, 0.4100, 1
 1, 0.34, 0, 0, 1, 0.3900, 1
 1, 0.27, 0, 1, 0, 0.3370, 2
 1, 0.32, 0, 1, 0, 0.4070, 1
 1, 0.42, 0, 0, 1, 0.4700, 1
-1, 0.24, 0, 0, 1, 0.4030, 0
 1, 0.42, 0, 1, 0, 0.5030, 1
 1, 0.25, 0, 0, 1, 0.2800, 2
 1, 0.51, 0, 1, 0, 0.5800, 1
-1, 0.55, 0, 1, 0, 0.6350, 2
 1, 0.44, 1, 0, 0, 0.4780, 2
-1, 0.18, 1, 0, 0, 0.3980, 0
-1, 0.67, 0, 1, 0, 0.7160, 2
 1, 0.45, 0, 0, 1, 0.5000, 1
 1, 0.48, 1, 0, 0, 0.5580, 1
-1, 0.25, 0, 1, 0, 0.3900, 1
-1, 0.67, 1, 0, 0, 0.7830, 1
 1, 0.37, 0, 0, 1, 0.4200, 1
-1, 0.32, 1, 0, 0, 0.4270, 1
 1, 0.48, 1, 0, 0, 0.5700, 1
-1, 0.66, 0, 0, 1, 0.7500, 2
 1, 0.61, 1, 0, 0, 0.7000, 0
-1, 0.58, 0, 0, 1, 0.6890, 1
 1, 0.19, 1, 0, 0, 0.2400, 2
 1, 0.38, 0, 0, 1, 0.4300, 1
-1, 0.27, 1, 0, 0, 0.3640, 1
 1, 0.42, 1, 0, 0, 0.4800, 1
 1, 0.60, 1, 0, 0, 0.7130, 0
-1, 0.27, 0, 0, 1, 0.3480, 0
 1, 0.29, 0, 1, 0, 0.3710, 0
-1, 0.43, 1, 0, 0, 0.5670, 1
 1, 0.48, 1, 0, 0, 0.5670, 1
 1, 0.27, 0, 0, 1, 0.2940, 2
-1, 0.44, 1, 0, 0, 0.5520, 0
 1, 0.23, 0, 1, 0, 0.2630, 2
-1, 0.36, 0, 1, 0, 0.5300, 2
 1, 0.64, 0, 0, 1, 0.7250, 0
 1, 0.29, 0, 0, 1, 0.3000, 2
-1, 0.33, 1, 0, 0, 0.4930, 1
-1, 0.66, 0, 1, 0, 0.7500, 2
-1, 0.21, 0, 0, 1, 0.3430, 0
 1, 0.27, 1, 0, 0, 0.3270, 2
 1, 0.29, 1, 0, 0, 0.3180, 2
-1, 0.31, 1, 0, 0, 0.4860, 1
 1, 0.36, 0, 0, 1, 0.4100, 1
 1, 0.49, 0, 1, 0, 0.5570, 1
-1, 0.28, 1, 0, 0, 0.3840, 0
-1, 0.43, 0, 0, 1, 0.5660, 1
-1, 0.46, 0, 1, 0, 0.5880, 1
 1, 0.57, 1, 0, 0, 0.6980, 0
-1, 0.52, 0, 0, 1, 0.5940, 1
-1, 0.31, 0, 0, 1, 0.4350, 1
-1, 0.55, 1, 0, 0, 0.6200, 2
 1, 0.50, 1, 0, 0, 0.5640, 1
 1, 0.48, 0, 1, 0, 0.5590, 1
-1, 0.22, 0, 0, 1, 0.3450, 0
 1, 0.59, 0, 0, 1, 0.6670, 0
 1, 0.34, 1, 0, 0, 0.4280, 2
-1, 0.64, 1, 0, 0, 0.7720, 2
 1, 0.29, 0, 0, 1, 0.3350, 2
-1, 0.34, 0, 1, 0, 0.4320, 1
-1, 0.61, 1, 0, 0, 0.7500, 2
 1, 0.64, 0, 0, 1, 0.7110, 0
-1, 0.29, 1, 0, 0, 0.4130, 0
 1, 0.63, 0, 1, 0, 0.7060, 0
-1, 0.29, 0, 1, 0, 0.4000, 0
-1, 0.51, 1, 0, 0, 0.6270, 1
-1, 0.24, 0, 0, 1, 0.3770, 0
 1, 0.48, 0, 1, 0, 0.5750, 1
 1, 0.18, 1, 0, 0, 0.2740, 0
 1, 0.18, 1, 0, 0, 0.2030, 2
 1, 0.33, 0, 1, 0, 0.3820, 2
-1, 0.20, 0, 0, 1, 0.3480, 0
 1, 0.29, 0, 0, 1, 0.3300, 2
-1, 0.44, 0, 0, 1, 0.6300, 0
-1, 0.65, 0, 0, 1, 0.8180, 0
-1, 0.56, 1, 0, 0, 0.6370, 2
-1, 0.52, 0, 0, 1, 0.5840, 1
-1, 0.29, 0, 1, 0, 0.4860, 0
-1, 0.47, 0, 1, 0, 0.5890, 1
 1, 0.68, 1, 0, 0, 0.7260, 2
 1, 0.31, 0, 0, 1, 0.3600, 1
 1, 0.61, 0, 1, 0, 0.6250, 2
 1, 0.19, 0, 1, 0, 0.2150, 2
 1, 0.38, 0, 0, 1, 0.4300, 1
-1, 0.26, 1, 0, 0, 0.4230, 0
 1, 0.61, 0, 1, 0, 0.6740, 0
 1, 0.40, 1, 0, 0, 0.4650, 1
-1, 0.49, 1, 0, 0, 0.6520, 1
 1, 0.56, 1, 0, 0, 0.6750, 0
-1, 0.48, 0, 1, 0, 0.6600, 1
 1, 0.52, 1, 0, 0, 0.5630, 2
-1, 0.18, 1, 0, 0, 0.2980, 0
-1, 0.56, 0, 0, 1, 0.5930, 2
-1, 0.52, 0, 1, 0, 0.6440, 1
-1, 0.18, 0, 1, 0, 0.2860, 1
-1, 0.58, 1, 0, 0, 0.6620, 2
-1, 0.39, 0, 1, 0, 0.5510, 1
-1, 0.46, 1, 0, 0, 0.6290, 1
-1, 0.40, 0, 1, 0, 0.4620, 1
-1, 0.60, 1, 0, 0, 0.7270, 2
 1, 0.36, 0, 1, 0, 0.4070, 2
 1, 0.44, 1, 0, 0, 0.5230, 1
 1, 0.28, 1, 0, 0, 0.3130, 2
 1, 0.54, 0, 0, 1, 0.6260, 0

Test data (40 items):

# people_test.txt
#
-1, 0.51, 1, 0, 0, 0.6120, 1
-1, 0.32, 0, 1, 0, 0.4610, 1
 1, 0.55, 1, 0, 0, 0.6270, 0
 1, 0.25, 0, 0, 1, 0.2620, 2
 1, 0.33, 0, 0, 1, 0.3730, 2
-1, 0.29, 0, 1, 0, 0.4620, 0
 1, 0.65, 1, 0, 0, 0.7270, 0
-1, 0.43, 0, 1, 0, 0.5140, 1
-1, 0.54, 0, 1, 0, 0.6480, 2
 1, 0.61, 0, 1, 0, 0.7270, 0
 1, 0.52, 0, 1, 0, 0.6360, 0
 1, 0.30, 0, 1, 0, 0.3350, 2
 1, 0.29, 1, 0, 0, 0.3140, 2
-1, 0.47, 0, 0, 1, 0.5940, 1
 1, 0.39, 0, 1, 0, 0.4780, 1
 1, 0.47, 0, 0, 1, 0.5200, 1
-1, 0.49, 1, 0, 0, 0.5860, 1
-1, 0.63, 0, 0, 1, 0.6740, 2
-1, 0.30, 1, 0, 0, 0.3920, 0
-1, 0.61, 0, 0, 1, 0.6960, 2
-1, 0.47, 0, 0, 1, 0.5870, 1
 1, 0.30, 0, 0, 1, 0.3450, 2
-1, 0.51, 0, 0, 1, 0.5800, 1
-1, 0.24, 1, 0, 0, 0.3880, 1
-1, 0.49, 1, 0, 0, 0.6450, 1
 1, 0.66, 0, 0, 1, 0.7450, 0
-1, 0.65, 1, 0, 0, 0.7690, 0
-1, 0.46, 0, 1, 0, 0.5800, 0
-1, 0.45, 0, 0, 1, 0.5180, 1
-1, 0.47, 1, 0, 0, 0.6360, 0
-1, 0.29, 1, 0, 0, 0.4480, 0
-1, 0.57, 0, 0, 1, 0.6930, 2
-1, 0.20, 1, 0, 0, 0.2870, 2
-1, 0.35, 1, 0, 0, 0.4340, 1
-1, 0.61, 0, 0, 1, 0.6700, 2
-1, 0.31, 0, 0, 1, 0.3730, 1
 1, 0.18, 1, 0, 0, 0.2080, 2
 1, 0.26, 0, 0, 1, 0.2920, 2
-1, 0.28, 1, 0, 0, 0.3640, 2
-1, 0.59, 0, 0, 1, 0.6940, 2