Calibrating a PyTorch Binary Classification Model Using Particle Swarm Optimization

If you have a binary or multi-class classification model, calibration error measures the difference between the model's computed output pseudo-probabilities and its actual accuracy. For example, suppose you are trying to predict the sex of a person (male = 0, female = 1) based on age, State of residence, income, and political leaning.

The pseudo-probability (pp) output of a trained model is a number between 0.0 and 1.0, where a pp less than 0.5 is a prediction of class 0 = male, and a pp of 0.5 or greater is a prediction of class 1 = female. So, if the output pp is 0.51, the predicted sex is (just barely) female. But if the output pp is 0.93, the predicted sex is (strongly) female. You'd like the output pseudo-probabilities of your prediction model to match the model accuracy. In other words, if pp = 0.93, you'd like to be able to say there's about a 93% chance that the prediction of female is correct.
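The decision rule just described can be sketched in a few lines. The sketch also computes the confidence in the predicted class (pp for female, 1 - pp for male), which is the quantity that calibration compares against accuracy. The function name interpret() is hypothetical and is not part of the demo program.

```python
def interpret(pp: float) -> tuple[str, float]:
    """Map an output pseudo-probability to (predicted sex, confidence in that prediction)."""
    if pp >= 0.5:              # the demo treats pp >= 0.5 as class 1 = female
        return "female", pp
    return "male", 1.0 - pp    # confidence in class 0 = male is 1 - pp

for pp in [0.51, 0.93, 0.20]:
    label, conf = interpret(pp)
    print(f"pp = {pp:.2f} -> {label}, confidence = {conf:.2f}")
```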

Computing calibration error is not too difficult. But if you have a prediction model with a large calibration error (larger than about 0.45), can you modify the trained model to improve the calibration error without harming the prediction accuracy?
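Here is a minimal sketch of one way to compute binned calibration error on plain arrays, assuming the same definition that the calibration_error() function in the demo program uses: ten equal-width bins, bins below 0.5 compared against 1 - avg_pp, and each bin's absolute difference weighted by its item count. The name binned_calibration_error() is hypothetical, and items that fall exactly on a bin edge may be binned slightly differently than in the demo because of floating point.

```python
import numpy as np

def binned_calibration_error(pps, targets, n_bins=10):
    """Weighted binned calibration error for a binary classifier.

    pps: 1-D array of output pseudo-probabilities in [0.0, 1.0]
    targets: 1-D array of 0/1 class labels
    Assumes an even n_bins so that 0.5 falls on a bin edge.
    """
    pps = np.asarray(pps, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)
    preds = (pps >= 0.5).astype(np.int64)          # pp >= 0.5 means class 1
    correct = preds == targets
    # bins [0.0, 0.1), [0.1, 0.2), ..., [0.9, 1.0]; pp = 1.0 goes in the last bin
    bins = np.minimum((pps * n_bins).astype(np.int64), n_bins - 1)
    total = 0.0
    for b in range(n_bins):
        mask = bins == b
        count = mask.sum()
        if count == 0:
            continue                                # empty bins contribute nothing
        avg_pp = pps[mask].mean()
        acc = correct[mask].mean()
        # bins below 0.5 predict class 0, so their confidence is 1 - avg_pp
        avg_conf = avg_pp if b >= n_bins // 2 else 1.0 - avg_pp
        total += count * abs(avg_conf - acc)        # weight by bin size
    return total / len(pps)

print(binned_calibration_error([0.93, 0.93, 0.20, 0.20], [1, 0, 0, 0]))
```

In the small example, the two items at 0.93 are only 50% correct, so that bin contributes most of the error.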

I set out to implement a proof-of-concept demo that improves the calibration error of a trained PyTorch binary classification model using particle swarm optimization (PSO). Because calibration error is not calculus-differentiable, you can't use standard machine learning techniques such as stochastic gradient descent. So, a possible approach is to use an optimization technique that doesn't need gradients, such as PSO.
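To illustrate why PSO is a reasonable fit, here is a minimal sketch of PSO minimizing a made-up piecewise-constant objective whose gradient is zero almost everywhere, so gradient descent would get no signal. It uses the same inertia, cognitive, and social weights as the demo's calibrate() function and, like calibrate(), starts every particle near a given start position. Unlike the demo, which checks for a new swarm best after each particle moves, this sketch updates the swarm best once per iteration. The names pso_minimize() and step_objective() are hypothetical.

```python
import numpy as np

def pso_minimize(objective, start, n_particles=10, max_iter=50,
                 lo=-0.10, hi=0.10, seed=0):
    """Minimal PSO that starts each particle at start + uniform(lo, hi) noise."""
    rnd = np.random.default_rng(seed)
    w, c1, c2 = 0.729, 1.49445, 1.49445   # inertia, cognitive, social weights
    start = np.asarray(start, dtype=np.float64)
    dim = len(start)
    pos = start + rnd.uniform(lo, hi, size=(n_particles, dim))
    vel = rnd.uniform(lo, hi, size=(n_particles, dim))
    err = np.array([objective(p) for p in pos])
    best_pos, best_err = pos.copy(), err.copy()            # per-particle bests
    g = int(np.argmin(best_err))
    g_pos, g_err = best_pos[g].copy(), float(best_err[g])  # swarm best
    history = [g_err]
    for _ in range(max_iter):
        r1 = rnd.random((n_particles, dim))
        r2 = rnd.random((n_particles, dim))
        vel = w * vel + c1 * r1 * (best_pos - pos) + c2 * r2 * (g_pos - pos)
        pos = pos + vel
        err = np.array([objective(p) for p in pos])
        improved = err < best_err                          # new particle bests
        best_pos[improved] = pos[improved]
        best_err[improved] = err[improved]
        g = int(np.argmin(best_err))
        if best_err[g] < g_err:                            # new swarm best
            g_pos, g_err = best_pos[g].copy(), float(best_err[g])
        history.append(g_err)
    return g_pos, g_err, history

def step_objective(x):
    # piecewise constant: zero gradient almost everywhere, minimum 0 near x = 0.3
    return float(np.sum(np.floor(10.0 * np.abs(x - 0.3))))

pos, err, hist = pso_minimize(step_objective, np.zeros(4))
print(err, hist[0])
```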

My experiment took quite a bit longer than expected. For the demo, I used synthetic data that looks like:

1, 0.24, 1, 0, 0, 0.2950, 0, 0, 1
0, 0.39, 0, 0, 1, 0.5120, 0, 1, 0
1, 0.63, 0, 1, 0, 0.7580, 1, 0, 0
0, 0.36, 1, 0, 0, 0.4450, 0, 1, 0
. . .

Each line represents a person. The fields are sex (male = 0, female = 1), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by 100,000), political leaning (conservative = 100, moderate = 010, liberal = 001). The goal is to predict sex from age, State, income, and political leaning. There are just 200 training items and 40 test items, which is barely enough for a usable neural network.
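The encoding just described can be sketched as a small helper that turns raw values into one data line. The names encode_person(), one_hot(), STATES, and POLITICS are hypothetical; the demo itself reads already-encoded files.

```python
STATES = ["michigan", "nebraska", "oklahoma"]
POLITICS = ["conservative", "moderate", "liberal"]

def one_hot(value, categories):
    # e.g. "nebraska" -> [0, 1, 0]
    return [1 if value == c else 0 for c in categories]

def encode_person(sex, age, state, income, politics):
    """Return [sex, age/100, state (3 values), income/100,000, politics (3 values)]."""
    return ([1 if sex == "female" else 0, round(age / 100, 2)]
            + one_hot(state, STATES)
            + [round(income / 100_000, 4)]
            + one_hot(politics, POLITICS))

print(encode_person("female", 24, "michigan", 29_500, "liberal"))
# [1, 0.24, 1, 0, 0, 0.295, 0, 0, 1]
```

The output matches the first data line shown above. Dropping the leading sex value gives the 8 predictor values the network expects.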

The key calling statements in the demo are:

. . .
  net = Net(8,10,1).to(device)  # create network
  net.train()  # set into training mode
  lrn_rate = 0.01
  max_epochs = 800
  log_interval = 200
  train(net, train_ds, bat_size, lrn_rate,
    max_epochs, log_interval)  # train binary classifier
. . .
  n_particles = 10
  max_iter = 10
  start_wts = net.get_weights()
  calibrated_wts = calibrate(net, train_ds, start_wts,
    n_particles, max_iter, seed=0) # calibrate using PSO
  net.set_weights(calibrated_wts)
. . .

The demo creates an 8-10-1 binary classifier and trains it using standard SGD back-propagation with mean squared error loss. The resulting weights and biases are fetched and used as the starting point for particle swarm optimization tuning to improve calibration error, as opposed to starting the calibration from scratch. Each particle starts at the trained weights plus small random noise in [-0.10, +0.10]. The PSO calibration function uses 10 particles and 10 tuning iterations. The resulting tuned weights and biases are then placed back into the neural network.
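PSO works on a single flat vector, so the network's weights and biases must be flattened and restored in a fixed order. The demo's get_weights() stores input-to-hidden weights, hidden biases, hidden-to-output weights, then output biases, looping over columns in the outer loop and rows in the inner loop. Here is a NumPy sketch of that ordering, assuming weight matrices shaped (n_out, n_in) like PyTorch's nn.Linear weights. The helper names are hypothetical.

```python
import numpy as np

def flatten_layer(weight, bias):
    """Flatten one layer in the demo's order.

    The outer loop over columns (inputs) and inner loop over rows (outputs)
    is column-major order, which equals weight.T.ravel().
    """
    return np.concatenate([weight.T.ravel(), bias])

def unflatten_layer(flat, n_in, n_out):
    """Inverse of flatten_layer(); returns (weight, bias, number of values used)."""
    n_w = n_in * n_out
    weight = flat[:n_w].reshape(n_in, n_out).T   # undo the column-major flattening
    bias = flat[n_w:n_w + n_out]
    return weight, bias, n_w + n_out

rng = np.random.default_rng(0)
w1, b1 = rng.normal(size=(10, 8)), rng.normal(size=10)   # 8-10 hidden layer
w2, b2 = rng.normal(size=(1, 10)), rng.normal(size=1)    # 10-1 output layer
flat = np.concatenate([flatten_layer(w1, b1), flatten_layer(w2, b2)])
print(flat.shape)   # (101,)
```

The 101 values are (8 * 10) + 10 + (10 * 1) + 1, which is the length of each particle's position vector in the demo.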

The output of my demo is:

Binary classification calibration with PyTorch

Creating People train and test Datasets
Creating 8-(10)-1 binary NN classifier

Setting learn rate: 0.010
Setting batch size: 10
Setting max epochs: 800

Starting training
epoch =    0  |  loss = 4.9969
epoch =  200  |  loss = 4.9747
epoch =  400  |  loss = 4.9639
epoch =  600  |  loss = 4.9506
Done

Accuracy train = 0.5400
Accuracy test = 0.4750

Calibration error train = 0.0040
Calibration error test = 0.1792

Setting n_particles = 10
Setting max_iter = 10

Starting calibration using swarm optimization
iter =    0  best calibration error = 0.0046
iter =    2  best calibration error = 0.0017
iter =    4  best calibration error = 0.0017
iter =    6  best calibration error = 0.0017
iter =    8  best calibration error = 0.0017
Done

New accuracy train = 0.5400
New accuracy test = 0.4750

New calibration error train = 0.0011
New calibration error test = 0.1783

Setting age = 30  Oklahoma  $40,000  moderate
Computed output: 0.5285
Prediction = female

End calibration demo

Let me emphasize that this is just an experimental proof of concept, and I futzed around quite a bit with the many parameters to get a representative (but not optimal) demo. The basic model scores 54.00% accuracy on the training data (108 out of 200 correct), which is only slightly better than guessing, and it already has an excellent calibration error of just 0.0040, so there's no real need to calibrate the model. But I went ahead and did so anyway.

The calibrated model also scores 54.00% accuracy on the training data, and the training calibration error has been reduced to just 0.0011. On the test data, the calibration error barely changed (0.1792 before, 0.1783 after).

Well, this was an interesting experiment, and the results suggest that using PSO to calibrate a trained model is a potentially useful technique. But the idea still needs a lot of work.



When I was a teenager, I worked at an amusement center in Anaheim, California. One of my responsibilities was to go into the arcade room before opening time and calibrate all the arcade machines, mostly pinball machines but also Skee-Ball and crane games. I liked going inside the machines, replacing parts, and adjusting the mechanisms.

The Merrivale Old Penny Arcade is located at Merrivale Model Village in Great Yarmouth, Norfolk, UK, about 120 miles northeast of London. The Arcade has some interesting old automata. I'm sure they must require constant calibration.

Left: The Haunted Crypt. Doors open, crypt lids open, skeletons appear, and so on. A big part of the fun is trying to guess what part or person will move next.

Right: The Haunted House. Very much the same idea as the Haunted Crypt. Whenever old automata had a window, you could be sure something would appear there.


Demo program. Experimental only. Has lots of de-optimized code for clarity, and likely has a few bugs too. My blog editor often chokes on comparison symbols, so if you see “lt” (less than), “lte”, “gt”, or “gte” in the code, replace it with the corresponding comparison operator.

# people_gender_calibrate_swarm.py
# binary classification calibration using PSO
# PyTorch 2.3.1-CPU Anaconda3-2023.09  Python 3.11.5
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex age   state  income  politics
  #  0  0.27  0 1 0  0.7610  0 0 1
  #  1  0.19  0 0 1  0.6550  1 0 0
  # sex: 0 = male, 1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_data = np.loadtxt(src_file, usecols=range(0,9),
      delimiter=",", comments="#", dtype=np.float32) 

    self.x_data = T.tensor(all_data[:,1:9],
      dtype=T.float32).to(device)
    self.y_data = T.tensor(all_data[:,0],
      dtype=T.float32).to(device)  # float32 required

    self.y_data = self.y_data.reshape(-1,1)  # 2-D required

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    feats = self.x_data[idx,:]  # idx row, all 8 cols
    sex = self.y_data[idx,:]    # idx row, the only col
    return feats, sex  # as a Tuple

# -----------------------------------------------------------

class Particle():
  def __init__(self, solnLen):
    super(Particle, self).__init__()
    self.position = T.zeros(solnLen, 
      dtype=T.float32).to(device)
    self.error = T.finfo(T.float32).max
    self.velocity = T.zeros(solnLen, 
      dtype=T.float32).to(device)
    self.best_position = T.zeros(solnLen,
      dtype=T.float32).to(device)
    self.best_error = T.finfo(T.float32).max

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self, ni, nh, no):
    super(Net, self).__init__()
    self.ni = ni; self.nh = nh; self.no = no
    self.hid1 = T.nn.Linear(ni, nh)  # 8-(10)-1
    self.oupt = T.nn.Linear(nh, no)

    # T.nn.init.xavier_uniform_(self.hid1.weight) 
    # T.nn.init.zeros_(self.hid1.bias)
    # T.nn.init.xavier_uniform_(self.oupt.weight) 
    # T.nn.init.zeros_(self.oupt.bias)

    T.nn.init.uniform_(self.hid1.weight, -0.10, +0.10) 
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.uniform_(self.oupt.weight, -0.10, +0.10) 
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.sigmoid(self.oupt(z))  # for BCELoss()
    return z

  def get_weights(self):
    # ih weights, h biases, ho weights, o biases
    n_wts = (self.ni * self.nh) + self.nh + \
      (self.nh * self.no) + self.no
    result = T.zeros(n_wts, dtype=T.float32).to(device)
    ptr = 0

    for col in range(self.ni):
      for row in range(self.nh):
        result[ptr] = self.hid1.weight.data[row][col]
        ptr += 1
    for j in range(self.nh):
      result[ptr] = self.hid1.bias.data[j]
      ptr += 1
    for col in range(self.nh):
      for row in range(self.no):
        result[ptr] = self.oupt.weight.data[row][col]
        ptr += 1 
    for k in range(self.no):
      result[ptr] = self.oupt.bias.data[k]
      ptr += 1
    return result

  def set_weights(self, wts):
    ptr = 0
    for col in range(self.ni):
      for row in range(self.nh):
        self.hid1.weight.data[row][col] = wts[ptr]
        ptr += 1
    for j in range(self.nh):
      self.hid1.bias.data[j] = wts[ptr]
      ptr += 1
    for col in range(self.nh):
      for row in range(self.no):
        self.oupt.weight.data[row][col] = wts[ptr]
        ptr += 1 
    for k in range(self.no):
      self.oupt.bias.data[k] = wts[ptr]
      ptr += 1

# -----------------------------------------------------------

def calibration_error(model, ds):
  counts = np.zeros(10, dtype=np.int64)  # of pseudo-probs
  sums = np.zeros(10, dtype=np.float32)  # of pseudo-probs
  n_corrects = np.zeros(10, dtype=np.int64)
  n_wrongs = np.zeros(10, dtype=np.int64)  # not needed
  accuracies = np.zeros(10, dtype=np.float32)
  avg_pps = np.zeros(10, dtype=np.float32)
  abs_diffs = np.zeros(10, dtype=np.float32)

  for i in range(len(ds)):
    inpts = ds[i][0]         # dictionary style
    target = ds[i][1]        # float32  [0.0] or [1.0]
    target = target.int()    # int 0 or 1
    with T.no_grad():
      p = model(inpts)       # between 0.0 and 1.0
    pp = p.item()

    correct = False
    if target == 1 and p >= 0.5: correct = True
    elif target == 0 and p < 0.5: correct = True
 
    if pp >= 0.0 and pp < 0.1: bin = 0
    elif pp >= 0.1 and pp < 0.2: bin = 1
    elif pp >= 0.2 and pp < 0.3: bin = 2
    elif pp >= 0.3 and pp < 0.4: bin = 3
    elif pp >= 0.4 and pp < 0.5: bin = 4
    elif pp >= 0.5 and pp < 0.6: bin = 5
    elif pp >= 0.6 and pp < 0.7: bin = 6
    elif pp >= 0.7 and pp < 0.8: bin = 7
    elif pp >= 0.8 and pp < 0.9: bin = 8
    elif pp >= 0.9 and pp <= 1.0: bin = 9

    counts[bin] += 1
    sums[bin] += pp
    if correct == True: n_corrects[bin] += 1
    elif correct == False: n_wrongs[bin] += 1  # check

  for bin in range(10):
    if counts[bin] == 0: accuracies[bin] = 0.0
    else: accuracies[bin] = n_corrects[bin] / counts[bin]

  for bin in range(10):
    if counts[bin] == 0: avg_pps[bin] = 0.0
    else: avg_pps[bin] = sums[bin] / counts[bin]

  for bin in range(10):
    if bin <= 4:  # bins for class 0
      abs_diffs[bin] = \
      np.abs((1 - avg_pps[bin]) - accuracies[bin])
    elif bin >= 5:  # class 1
      abs_diffs[bin] = \
      np.abs(avg_pps[bin] - accuracies[bin])  

  cal_err = 0.0
  for bin in range(10):
    cal_err += counts[bin] * abs_diffs[bin]  # weighted
  cal_err /= len(ds)
  return cal_err

# -----------------------------------------------------------

def accuracy(model, ds):
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    inpts = ds[i][0]           # dictionary style
    actu_y = ds[i][1].int()    # 0 or 1
    with T.no_grad():
      pred_y = model(inpts)    # between 0.0 and 1.0
    if actu_y == 0 and pred_y < 0.5: n_correct += 1
    elif actu_y == 1 and pred_y >= 0.5: n_correct += 1
    else: n_wrong += 1
  return (n_correct * 1.0) / (n_correct + n_wrong)

# -----------------------------------------------------------

def train(model, ds, bs, lr, me, le):
  # dataset, bat_size, lrn_rate, max_epochs, log interval
  train_ldr = T.utils.data.DataLoader(ds, batch_size=bs,
    shuffle=True)
  # loss_func = T.nn.BCELoss()
  loss_func = T.nn.MSELoss()
  optimizer = T.optim.SGD(model.parameters(), lr=lr)

  for epoch in range(0, me):
    epoch_loss = 0.0  # for one full epoch
    for (b_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # predictors
      y = batch[1]  # targets
      optimizer.zero_grad()
      oupt = model(X)
      loss_val = loss_func(oupt, y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute gradients
      optimizer.step()     # update weights

    if epoch % le == 0:
      print("epoch = %4d  |  loss = %0.4f" % \
        (epoch, epoch_loss))  

# -----------------------------------------------------------
# -----------------------------------------------------------

def calibrate(model, ds, start_wts, n_particles,
  max_iter, seed):
  # 0. prepare
  rnd = np.random.RandomState(seed)
  soln_len = len(start_wts)
  lo = -0.10; hi = +0.10
  global_best_pos = T.zeros(soln_len, 
    dtype=T.float32).to(device)
  global_best_error = T.finfo(T.float32).max
  w = 0.729     # inertia weight
  c1 = 1.49445  # cognitive weight
  c2 = 1.49445  # social weight

  # 1. create a swarm of particles based on start weights
  swarm = []  # a list of Particle objects
  for i in range(n_particles):
    p = Particle(soln_len)
    for j in range(soln_len):
      p.position[j] = start_wts[j] + ((hi - lo) * \
        rnd.random() + lo)
      p.velocity[j] = (hi - lo) * rnd.random() + lo
      p.best_position[j] = p.position[j]
    saved_weights = model.get_weights()
    model.set_weights(p.position)
    p.error = calibration_error(model, ds)
    model.set_weights(saved_weights)
    p.best_error = p.error
    swarm.append(p)
    
  # 2. set global bests
  for i in range(n_particles):
    if swarm[i].error < global_best_error:
      global_best_error = swarm[i].error
      for j in range(soln_len):
        global_best_pos[j] = swarm[i].position[j]

  # 3. main processing loop
  for iter in range(max_iter):
    for i in range(n_particles):
      curr_p = swarm[i]
      for j in range(soln_len):  # update velocity
        r1 = rnd.random()
        r2 = rnd.random()
        curr_p.velocity[j] = (w * curr_p.velocity[j]) + \
          (c1 * r1 * (curr_p.best_position[j] - \
          curr_p.position[j])) + \
          (c2 * r2 * (global_best_pos[j] - \
          curr_p.position[j]))
      for j in range(soln_len):  # update position 
        curr_p.position[j] = curr_p.position[j] + \
        curr_p.velocity[j]

      # update calibration error
      saved_weights = model.get_weights()
      model.set_weights(curr_p.position)
      curr_p.error = calibration_error(model, ds)
      model.set_weights(saved_weights)

      # check if new particle best
      if curr_p.error < curr_p.best_error:
        curr_p.best_error = curr_p.error
        for j in range(soln_len):
          curr_p.best_position[j] = curr_p.position[j]

      # check if new global best found
      if curr_p.error < global_best_error:
        global_best_error = curr_p.error
        for j in range(soln_len):
          global_best_pos[j] = curr_p.position[j]

    # display progress
    if iter % max(1, max_iter // 5) == 0:  # max() avoids modulo by zero
      print("iter = %4d " % iter, end="")
      print(" best calibration error = %0.4f " % \
        global_best_error)

  # return best position found by any particle
  result = T.zeros(soln_len, dtype=T.float32)
  for j in range(soln_len):
    result[j] = global_best_pos[j]
  return result

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBinary classification calibration with PyTorch ")
  T.manual_seed(0)
  np.random.seed(0)
  np.set_printoptions(suppress=True, precision=4,
    floatmode='fixed', sign = ' ')

  # 1. create Dataset and DataLoader objects
  print("\nCreating People train and test Datasets ")

  train_file = ".\\Data\\people_train.txt"
  test_file = ".\\Data\\people_test.txt"

  train_ds = PeopleDataset(train_file)  # 200 rows
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create neural network
  print("\nCreating 8-(10)-1 binary NN classifier ")
  net = Net(8,10,1).to(device)

  # 3. train network
  net.train()  # set into training mode
  lrn_rate = 0.01
  max_epochs = 800
  log_interval = 200
  print("\nSetting learn rate: " + "%0.3f" % lrn_rate)
  print("Setting batch size: " + str(bat_size))
  print("Setting max epochs: " + str(max_epochs))
  print("\nStarting training")
  train(net, train_ds, bat_size, lrn_rate,
    max_epochs, log_interval)
  print("Done ")

  # 4. evaluate trained model
  net.eval()
  acc_train = accuracy(net, train_ds)
  acc_test = accuracy(net, test_ds)
  print("\nAccuracy train = %0.4f " % acc_train)
  print("Accuracy test = %0.4f " % acc_test)

  ce_train = calibration_error(net, train_ds)
  ce_test = calibration_error(net, test_ds)
  print("\nCalibration error train = %0.4f " % ce_train)
  print("Calibration error test = %0.4f " % ce_test)

  # 5. calibrate model
  n_particles = 10
  max_iter = 10
  print("\nSetting n_particles = " + str(n_particles))
  print("Setting max_iter = " + str(max_iter))
  print("\nStarting calibration using swarm optimization ")
  start_wts = net.get_weights()
  calibrated_wts = calibrate(net, train_ds, start_wts,
    n_particles, max_iter, seed=0)
  net.set_weights(calibrated_wts)
  print("Done ")

  # 6. evaluate calibrated model
  net.eval()
  acc_train = accuracy(net, train_ds)
  acc_test = accuracy(net, test_ds)
  print("\nNew accuracy train = %0.4f " % acc_train)
  print("New accuracy test = %0.4f " % acc_test)

  ce_train = calibration_error(net, train_ds)
  ce_test = calibration_error(net, test_ds)
  print("\nNew calibration error train = %0.4f " % ce_train)
  print("New calibration error test = %0.4f " % ce_test)

  # 7. save model
  # print("\nSaving trained model state_dict ")
  # net.eval()
  # path = ".\\Models\\people_gender_model.pt"
  # T.save(net.state_dict(), path)

  # 8. make a prediction 
  print("\nSetting age = 30  Oklahoma  $40,000  moderate ")
  inpt = np.array([[0.30, 0,0,1, 0.4000, 0,1,0]],
    dtype=np.float32)
  inpt = T.tensor(inpt, dtype=T.float32).to(device)

  net.eval()
  with T.no_grad():
    oupt = net(inpt)    # a Tensor
  pred_prob = oupt.item()  # scalar, [0.0, 1.0]
  print("Computed output: ", end="")
  print("%0.4f" % pred_prob)

  if pred_prob < 0.5:
    print("Prediction = male")
  else:
    print("Prediction = female")

  print("\nEnd calibration demo ")

if __name__== "__main__":
  main()

Training data:

# people_train.txt
# sex (0 = male, 1 = female) - dependent variable
# age, state (michigan, nebraska, oklahoma), income,
# politics type (conservative, moderate, liberal)
#
1, 0.24, 1, 0, 0, 0.2950, 0, 0, 1
0, 0.39, 0, 0, 1, 0.5120, 0, 1, 0
1, 0.63, 0, 1, 0, 0.7580, 1, 0, 0
0, 0.36, 1, 0, 0, 0.4450, 0, 1, 0
1, 0.27, 0, 1, 0, 0.2860, 0, 0, 1
1, 0.50, 0, 1, 0, 0.5650, 0, 1, 0
1, 0.50, 0, 0, 1, 0.5500, 0, 1, 0
0, 0.19, 0, 0, 1, 0.3270, 1, 0, 0
1, 0.22, 0, 1, 0, 0.2770, 0, 1, 0
0, 0.39, 0, 0, 1, 0.4710, 0, 0, 1
1, 0.34, 1, 0, 0, 0.3940, 0, 1, 0
0, 0.22, 1, 0, 0, 0.3350, 1, 0, 0
1, 0.35, 0, 0, 1, 0.3520, 0, 0, 1
0, 0.33, 0, 1, 0, 0.4640, 0, 1, 0
1, 0.45, 0, 1, 0, 0.5410, 0, 1, 0
1, 0.42, 0, 1, 0, 0.5070, 0, 1, 0
0, 0.33, 0, 1, 0, 0.4680, 0, 1, 0
1, 0.25, 0, 0, 1, 0.3000, 0, 1, 0
0, 0.31, 0, 1, 0, 0.4640, 1, 0, 0
1, 0.27, 1, 0, 0, 0.3250, 0, 0, 1
1, 0.48, 1, 0, 0, 0.5400, 0, 1, 0
0, 0.64, 0, 1, 0, 0.7130, 0, 0, 1
1, 0.61, 0, 1, 0, 0.7240, 1, 0, 0
1, 0.54, 0, 0, 1, 0.6100, 1, 0, 0
1, 0.29, 1, 0, 0, 0.3630, 1, 0, 0
1, 0.50, 0, 0, 1, 0.5500, 0, 1, 0
1, 0.55, 0, 0, 1, 0.6250, 1, 0, 0
1, 0.40, 1, 0, 0, 0.5240, 1, 0, 0
1, 0.22, 1, 0, 0, 0.2360, 0, 0, 1
1, 0.68, 0, 1, 0, 0.7840, 1, 0, 0
0, 0.60, 1, 0, 0, 0.7170, 0, 0, 1
0, 0.34, 0, 0, 1, 0.4650, 0, 1, 0
0, 0.25, 0, 0, 1, 0.3710, 1, 0, 0
0, 0.31, 0, 1, 0, 0.4890, 0, 1, 0
1, 0.43, 0, 0, 1, 0.4800, 0, 1, 0
1, 0.58, 0, 1, 0, 0.6540, 0, 0, 1
0, 0.55, 0, 1, 0, 0.6070, 0, 0, 1
0, 0.43, 0, 1, 0, 0.5110, 0, 1, 0
0, 0.43, 0, 0, 1, 0.5320, 0, 1, 0
0, 0.21, 1, 0, 0, 0.3720, 1, 0, 0
1, 0.55, 0, 0, 1, 0.6460, 1, 0, 0
1, 0.64, 0, 1, 0, 0.7480, 1, 0, 0
0, 0.41, 1, 0, 0, 0.5880, 0, 1, 0
1, 0.64, 0, 0, 1, 0.7270, 1, 0, 0
0, 0.56, 0, 0, 1, 0.6660, 0, 0, 1
1, 0.31, 0, 0, 1, 0.3600, 0, 1, 0
0, 0.65, 0, 0, 1, 0.7010, 0, 0, 1
1, 0.55, 0, 0, 1, 0.6430, 1, 0, 0
0, 0.25, 1, 0, 0, 0.4030, 1, 0, 0
1, 0.46, 0, 0, 1, 0.5100, 0, 1, 0
0, 0.36, 1, 0, 0, 0.5350, 1, 0, 0
1, 0.52, 0, 1, 0, 0.5810, 0, 1, 0
1, 0.61, 0, 0, 1, 0.6790, 1, 0, 0
1, 0.57, 0, 0, 1, 0.6570, 1, 0, 0
0, 0.46, 0, 1, 0, 0.5260, 0, 1, 0
0, 0.62, 1, 0, 0, 0.6680, 0, 0, 1
1, 0.55, 0, 0, 1, 0.6270, 1, 0, 0
0, 0.22, 0, 0, 1, 0.2770, 0, 1, 0
0, 0.50, 1, 0, 0, 0.6290, 1, 0, 0
0, 0.32, 0, 1, 0, 0.4180, 0, 1, 0
0, 0.21, 0, 0, 1, 0.3560, 1, 0, 0
1, 0.44, 0, 1, 0, 0.5200, 0, 1, 0
1, 0.46, 0, 1, 0, 0.5170, 0, 1, 0
1, 0.62, 0, 1, 0, 0.6970, 1, 0, 0
1, 0.57, 0, 1, 0, 0.6640, 1, 0, 0
0, 0.67, 0, 0, 1, 0.7580, 0, 0, 1
1, 0.29, 1, 0, 0, 0.3430, 0, 0, 1
1, 0.53, 1, 0, 0, 0.6010, 1, 0, 0
0, 0.44, 1, 0, 0, 0.5480, 0, 1, 0
1, 0.46, 0, 1, 0, 0.5230, 0, 1, 0
0, 0.20, 0, 1, 0, 0.3010, 0, 1, 0
0, 0.38, 1, 0, 0, 0.5350, 0, 1, 0
1, 0.50, 0, 1, 0, 0.5860, 0, 1, 0
1, 0.33, 0, 1, 0, 0.4250, 0, 1, 0
0, 0.33, 0, 1, 0, 0.3930, 0, 1, 0
1, 0.26, 0, 1, 0, 0.4040, 1, 0, 0
1, 0.58, 1, 0, 0, 0.7070, 1, 0, 0
1, 0.43, 0, 0, 1, 0.4800, 0, 1, 0
0, 0.46, 1, 0, 0, 0.6440, 1, 0, 0
1, 0.60, 1, 0, 0, 0.7170, 1, 0, 0
0, 0.42, 1, 0, 0, 0.4890, 0, 1, 0
0, 0.56, 0, 0, 1, 0.5640, 0, 0, 1
0, 0.62, 0, 1, 0, 0.6630, 0, 0, 1
0, 0.50, 1, 0, 0, 0.6480, 0, 1, 0
1, 0.47, 0, 0, 1, 0.5200, 0, 1, 0
0, 0.67, 0, 1, 0, 0.8040, 0, 0, 1
0, 0.40, 0, 0, 1, 0.5040, 0, 1, 0
1, 0.42, 0, 1, 0, 0.4840, 0, 1, 0
1, 0.64, 1, 0, 0, 0.7200, 1, 0, 0
0, 0.47, 1, 0, 0, 0.5870, 0, 0, 1
1, 0.45, 0, 1, 0, 0.5280, 0, 1, 0
0, 0.25, 0, 0, 1, 0.4090, 1, 0, 0
1, 0.38, 1, 0, 0, 0.4840, 1, 0, 0
1, 0.55, 0, 0, 1, 0.6000, 0, 1, 0
0, 0.44, 1, 0, 0, 0.6060, 0, 1, 0
1, 0.33, 1, 0, 0, 0.4100, 0, 1, 0
1, 0.34, 0, 0, 1, 0.3900, 0, 1, 0
1, 0.27, 0, 1, 0, 0.3370, 0, 0, 1
1, 0.32, 0, 1, 0, 0.4070, 0, 1, 0
1, 0.42, 0, 0, 1, 0.4700, 0, 1, 0
0, 0.24, 0, 0, 1, 0.4030, 1, 0, 0
1, 0.42, 0, 1, 0, 0.5030, 0, 1, 0
1, 0.25, 0, 0, 1, 0.2800, 0, 0, 1
1, 0.51, 0, 1, 0, 0.5800, 0, 1, 0
0, 0.55, 0, 1, 0, 0.6350, 0, 0, 1
1, 0.44, 1, 0, 0, 0.4780, 0, 0, 1
0, 0.18, 1, 0, 0, 0.3980, 1, 0, 0
0, 0.67, 0, 1, 0, 0.7160, 0, 0, 1
1, 0.45, 0, 0, 1, 0.5000, 0, 1, 0
1, 0.48, 1, 0, 0, 0.5580, 0, 1, 0
0, 0.25, 0, 1, 0, 0.3900, 0, 1, 0
0, 0.67, 1, 0, 0, 0.7830, 0, 1, 0
1, 0.37, 0, 0, 1, 0.4200, 0, 1, 0
0, 0.32, 1, 0, 0, 0.4270, 0, 1, 0
1, 0.48, 1, 0, 0, 0.5700, 0, 1, 0
0, 0.66, 0, 0, 1, 0.7500, 0, 0, 1
1, 0.61, 1, 0, 0, 0.7000, 1, 0, 0
0, 0.58, 0, 0, 1, 0.6890, 0, 1, 0
1, 0.19, 1, 0, 0, 0.2400, 0, 0, 1
1, 0.38, 0, 0, 1, 0.4300, 0, 1, 0
0, 0.27, 1, 0, 0, 0.3640, 0, 1, 0
1, 0.42, 1, 0, 0, 0.4800, 0, 1, 0
1, 0.60, 1, 0, 0, 0.7130, 1, 0, 0
0, 0.27, 0, 0, 1, 0.3480, 1, 0, 0
1, 0.29, 0, 1, 0, 0.3710, 1, 0, 0
0, 0.43, 1, 0, 0, 0.5670, 0, 1, 0
1, 0.48, 1, 0, 0, 0.5670, 0, 1, 0
1, 0.27, 0, 0, 1, 0.2940, 0, 0, 1
0, 0.44, 1, 0, 0, 0.5520, 1, 0, 0
1, 0.23, 0, 1, 0, 0.2630, 0, 0, 1
0, 0.36, 0, 1, 0, 0.5300, 0, 0, 1
1, 0.64, 0, 0, 1, 0.7250, 1, 0, 0
1, 0.29, 0, 0, 1, 0.3000, 0, 0, 1
0, 0.33, 1, 0, 0, 0.4930, 0, 1, 0
0, 0.66, 0, 1, 0, 0.7500, 0, 0, 1
0, 0.21, 0, 0, 1, 0.3430, 1, 0, 0
1, 0.27, 1, 0, 0, 0.3270, 0, 0, 1
1, 0.29, 1, 0, 0, 0.3180, 0, 0, 1
0, 0.31, 1, 0, 0, 0.4860, 0, 1, 0
1, 0.36, 0, 0, 1, 0.4100, 0, 1, 0
1, 0.49, 0, 1, 0, 0.5570, 0, 1, 0
0, 0.28, 1, 0, 0, 0.3840, 1, 0, 0
0, 0.43, 0, 0, 1, 0.5660, 0, 1, 0
0, 0.46, 0, 1, 0, 0.5880, 0, 1, 0
1, 0.57, 1, 0, 0, 0.6980, 1, 0, 0
0, 0.52, 0, 0, 1, 0.5940, 0, 1, 0
0, 0.31, 0, 0, 1, 0.4350, 0, 1, 0
0, 0.55, 1, 0, 0, 0.6200, 0, 0, 1
1, 0.50, 1, 0, 0, 0.5640, 0, 1, 0
1, 0.48, 0, 1, 0, 0.5590, 0, 1, 0
0, 0.22, 0, 0, 1, 0.3450, 1, 0, 0
1, 0.59, 0, 0, 1, 0.6670, 1, 0, 0
1, 0.34, 1, 0, 0, 0.4280, 0, 0, 1
0, 0.64, 1, 0, 0, 0.7720, 0, 0, 1
1, 0.29, 0, 0, 1, 0.3350, 0, 0, 1
0, 0.34, 0, 1, 0, 0.4320, 0, 1, 0
0, 0.61, 1, 0, 0, 0.7500, 0, 0, 1
1, 0.64, 0, 0, 1, 0.7110, 1, 0, 0
0, 0.29, 1, 0, 0, 0.4130, 1, 0, 0
1, 0.63, 0, 1, 0, 0.7060, 1, 0, 0
0, 0.29, 0, 1, 0, 0.4000, 1, 0, 0
0, 0.51, 1, 0, 0, 0.6270, 0, 1, 0
0, 0.24, 0, 0, 1, 0.3770, 1, 0, 0
1, 0.48, 0, 1, 0, 0.5750, 0, 1, 0
1, 0.18, 1, 0, 0, 0.2740, 1, 0, 0
1, 0.18, 1, 0, 0, 0.2030, 0, 0, 1
1, 0.33, 0, 1, 0, 0.3820, 0, 0, 1
0, 0.20, 0, 0, 1, 0.3480, 1, 0, 0
1, 0.29, 0, 0, 1, 0.3300, 0, 0, 1
0, 0.44, 0, 0, 1, 0.6300, 1, 0, 0
0, 0.65, 0, 0, 1, 0.8180, 1, 0, 0
0, 0.56, 1, 0, 0, 0.6370, 0, 0, 1
0, 0.52, 0, 0, 1, 0.5840, 0, 1, 0
0, 0.29, 0, 1, 0, 0.4860, 1, 0, 0
0, 0.47, 0, 1, 0, 0.5890, 0, 1, 0
1, 0.68, 1, 0, 0, 0.7260, 0, 0, 1
1, 0.31, 0, 0, 1, 0.3600, 0, 1, 0
1, 0.61, 0, 1, 0, 0.6250, 0, 0, 1
1, 0.19, 0, 1, 0, 0.2150, 0, 0, 1
1, 0.38, 0, 0, 1, 0.4300, 0, 1, 0
0, 0.26, 1, 0, 0, 0.4230, 1, 0, 0
1, 0.61, 0, 1, 0, 0.6740, 1, 0, 0
1, 0.40, 1, 0, 0, 0.4650, 0, 1, 0
0, 0.49, 1, 0, 0, 0.6520, 0, 1, 0
1, 0.56, 1, 0, 0, 0.6750, 1, 0, 0
0, 0.48, 0, 1, 0, 0.6600, 0, 1, 0
1, 0.52, 1, 0, 0, 0.5630, 0, 0, 1
0, 0.18, 1, 0, 0, 0.2980, 1, 0, 0
0, 0.56, 0, 0, 1, 0.5930, 0, 0, 1
0, 0.52, 0, 1, 0, 0.6440, 0, 1, 0
0, 0.18, 0, 1, 0, 0.2860, 0, 1, 0
0, 0.58, 1, 0, 0, 0.6620, 0, 0, 1
0, 0.39, 0, 1, 0, 0.5510, 0, 1, 0
0, 0.46, 1, 0, 0, 0.6290, 0, 1, 0
0, 0.40, 0, 1, 0, 0.4620, 0, 1, 0
0, 0.60, 1, 0, 0, 0.7270, 0, 0, 1
1, 0.36, 0, 1, 0, 0.4070, 0, 0, 1
1, 0.44, 1, 0, 0, 0.5230, 0, 1, 0
1, 0.28, 1, 0, 0, 0.3130, 0, 0, 1
1, 0.54, 0, 0, 1, 0.6260, 1, 0, 0

Test data:

# people_test.txt
#
0, 0.51, 1, 0, 0, 0.6120, 0, 1, 0
0, 0.32, 0, 1, 0, 0.4610, 0, 1, 0
1, 0.55, 1, 0, 0, 0.6270, 1, 0, 0
1, 0.25, 0, 0, 1, 0.2620, 0, 0, 1
1, 0.33, 0, 0, 1, 0.3730, 0, 0, 1
0, 0.29, 0, 1, 0, 0.4620, 1, 0, 0
1, 0.65, 1, 0, 0, 0.7270, 1, 0, 0
0, 0.43, 0, 1, 0, 0.5140, 0, 1, 0
0, 0.54, 0, 1, 0, 0.6480, 0, 0, 1
1, 0.61, 0, 1, 0, 0.7270, 1, 0, 0
1, 0.52, 0, 1, 0, 0.6360, 1, 0, 0
1, 0.30, 0, 1, 0, 0.3350, 0, 0, 1
1, 0.29, 1, 0, 0, 0.3140, 0, 0, 1
0, 0.47, 0, 0, 1, 0.5940, 0, 1, 0
1, 0.39, 0, 1, 0, 0.4780, 0, 1, 0
1, 0.47, 0, 0, 1, 0.5200, 0, 1, 0
0, 0.49, 1, 0, 0, 0.5860, 0, 1, 0
0, 0.63, 0, 0, 1, 0.6740, 0, 0, 1
0, 0.30, 1, 0, 0, 0.3920, 1, 0, 0
0, 0.61, 0, 0, 1, 0.6960, 0, 0, 1
0, 0.47, 0, 0, 1, 0.5870, 0, 1, 0
1, 0.30, 0, 0, 1, 0.3450, 0, 0, 1
0, 0.51, 0, 0, 1, 0.5800, 0, 1, 0
0, 0.24, 1, 0, 0, 0.3880, 0, 1, 0
0, 0.49, 1, 0, 0, 0.6450, 0, 1, 0
1, 0.66, 0, 0, 1, 0.7450, 1, 0, 0
0, 0.65, 1, 0, 0, 0.7690, 1, 0, 0
0, 0.46, 0, 1, 0, 0.5800, 1, 0, 0
0, 0.45, 0, 0, 1, 0.5180, 0, 1, 0
0, 0.47, 1, 0, 0, 0.6360, 1, 0, 0
0, 0.29, 1, 0, 0, 0.4480, 1, 0, 0
0, 0.57, 0, 0, 1, 0.6930, 0, 0, 1
0, 0.20, 1, 0, 0, 0.2870, 0, 0, 1
0, 0.35, 1, 0, 0, 0.4340, 0, 1, 0
0, 0.61, 0, 0, 1, 0.6700, 0, 0, 1
0, 0.31, 0, 0, 1, 0.3730, 0, 1, 0
1, 0.18, 1, 0, 0, 0.2080, 0, 0, 1
1, 0.26, 0, 0, 1, 0.2920, 0, 0, 1
0, 0.28, 1, 0, 0, 0.3640, 0, 0, 1
0, 0.59, 0, 0, 1, 0.6940, 0, 0, 1
This entry was posted in Machine Learning.

2 Responses to Calibrating a PyTorch Binary Classification Model Using Particle Swarm Optimization

  1. Thorsten Kleppe says:

    PSO may just need DeepSeek’s attention moment. There are also many efforts to push RL right now. The thing that makes me wonder about the demo is the low accuracy of a binary classification example, which is like flipping a coin?

    The top image shows the ensemble line in white and six trained prediction models in different colors. The data point is chosen outside the known data points and could represent an impossible concrete mixture for various reasons. In this case, all models are usually very far apart in their predictions. The image below shows a data point from the training data. Here all the models are very close to each other. Good, but not a great prediction.

    This is probably not the usual method for checking the prediction quality, but it works quite well and the strength of the individual predictors can even be tracked and adjusted.

  2. Thorsten, your visualizations continue to be amazing! JM
