PyTorch Binary Classification Using Two Output Nodes Instead of One Output Node

Quite some time ago, before the rise of the PyTorch library as the de facto standard for neural networks, I did some experiments that explored binary classification using two output nodes instead of the usual one output node. My experiments used the Keras library and from-scratch Python code. Those experiments were not conclusive, but they suggested that there is no significant difference between the effectiveness of the two techniques. I figured I’d take another look at the idea, using PyTorch.

To remind you: Neural network binary classification almost always uses a network with a single output node. The output node activation is logistic sigmoid so that the value of the output node is between 0.0 and 1.0. A computed output value that is less than 0.5 is a prediction of class 0, and an output value greater than 0.5 is a prediction of class 1. The loss function used during training is either mean squared error (MSE) or binary cross entropy error (BCE), with BCE being more common.
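The one-node arithmetic described above can be sketched in plain Python, without PyTorch. The raw output values and targets below are made up for illustration, and treating an output of exactly 0.5 as class 1 is an arbitrary choice, because the description above only covers values below and above 0.5.

```python
import math

def sigmoid(z):
    # logistic sigmoid: maps any real value into (0.0, 1.0)
    return 1.0 / (1.0 + math.exp(-z))

def predict_class(p, thresh=0.5):
    # below the threshold -> class 0, otherwise class 1
    # (exactly 0.5 going to class 1 is an arbitrary choice)
    return 0 if p < thresh else 1

def bce(p, target):
    # binary cross entropy error for one item; target is 0 or 1
    return -(target * math.log(p) + (1 - target) * math.log(1.0 - p))

def mse(p, target):
    # squared error for one item
    return (target - p) ** 2

raw_outputs = [-1.2, 0.3, 2.0]  # hypothetical pre-activation values
targets = [0, 1, 1]             # hypothetical correct classes

for z, t in zip(raw_outputs, targets):
    p = sigmoid(z)
    print("z = %5.2f  p = %0.4f  predicted = %d  target = %d  "
          "BCE = %0.4f  MSE = %0.4f" %
          (z, p, predict_class(p), t, bce(p, t), mse(p, t)))
```

Both error functions shrink as the computed output approaches the target, which is why either can drive training.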

So, suppose you want to predict the sex of a person using their age, State of residence, income, and political leaning (conservative, moderate, liberal). Using the standard one-node output technique, the data might look like:

1, 0.24, 1, 0, 0, 0.2950, 0, 0, 1
0, 0.39, 0, 0, 1, 0.5120, 0, 1, 0
1, 0.63, 0, 1, 0, 0.7580, 1, 0, 0
0, 0.36, 1, 0, 0, 0.4450, 0, 1, 0
1, 0.27, 0, 1, 0, 0.2860, 0, 0, 1
. . .

where sex is encoded as 0 = male and 1 = female, age is normalized by dividing by 100, State is one-hot encoded as Michigan = 100, Nebraska = 010, Oklahoma = 001, income is normalized by dividing by 100,000, and political leaning is one-hot encoded as conservative = 100, moderate = 010, liberal = 001. The key PyTorch statements that compute output (using tanh hidden node activation) are:

def forward(self, x):
  z = T.tanh(self.hid1(x))
  z = T.tanh(self.hid2(z))
  z = T.sigmoid(self.oupt(z))
  return z
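The encoding scheme for the one-node data can be sketched as a small helper. This is only an illustration: the function names and the raw-record layout (sex as a string, age in years, income in dollars) are assumptions, not part of the demo.

```python
# category orders follow the one-hot encodings described above
STATES = ["michigan", "nebraska", "oklahoma"]
POLITICS = ["conservative", "moderate", "liberal"]

def one_hot(value, categories):
    # e.g. "nebraska" -> [0, 1, 0]
    if value not in categories:
        raise ValueError("unknown category: " + str(value))
    return [1 if value == c else 0 for c in categories]

def encode_person(sex, age, state, income, politics):
    # hypothetical helper: raw record -> one-node data row
    if sex not in ("male", "female"):
        raise ValueError("unknown sex: " + str(sex))
    label = 0 if sex == "male" else 1       # 0 = male, 1 = female
    row = [label, age / 100.0]              # age divided by 100
    row += one_hot(state, STATES)
    row += [income / 100000.0]              # income divided by 100,000
    row += one_hot(politics, POLITICS)
    return row

# a raw record that corresponds to the first data line shown above
print(encode_person("female", 24, "michigan", 29500, "liberal"))
```

For the two-node format, only the label changes: the single 0 or 1 becomes the pair 1, 0 (male) or 0, 1 (female).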

For the two-node output technique, the data might look like:

0, 1, 0.24, 1, 0, 0, 0.2950, 0, 0, 1
1, 0, 0.39, 0, 0, 1, 0.5120, 0, 1, 0
0, 1, 0.63, 0, 1, 0, 0.7580, 1, 0, 0
1, 0, 0.36, 1, 0, 0, 0.4450, 0, 1, 0
0, 1, 0.27, 0, 1, 0, 0.2860, 0, 0, 1
. . .

where male is encoded as 1 0 and female is encoded as 0 1. The key PyTorch statements that compute output for this approach are:

def forward(self, x):
  z = T.tanh(self.hid1(x))
  z = T.tanh(self.hid2(z))
  z = T.nn.functional.softmax(self.oupt(z), dim=1)
  return z

The output has two nodes like [0.6500, 0.3500] where the softmax activation forces the values in the two nodes to sum to 1. The index of the largest pseudo-probability value is the predicted class. For example, for [0.6500, 0.3500] the largest pseudo-probability is at [0] so the predicted class is male.
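A plain-Python sketch of the softmax and argmax steps is below; the two raw output values are made up so that the result lands near [0.65, 0.35]. The sketch also checks a standard identity: with exactly two nodes, the softmax value for class 1 equals the logistic sigmoid of the difference between the two raw values. That identity may be part of why the two techniques behave so similarly, though the article's experiments don't test that explanation directly.

```python
import math

def softmax(values):
    # subtract the max before exponentiating for numerical stability
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def argmax(values):
    # index of the largest value
    return max(range(len(values)), key=lambda i: values[i])

logits = [0.8, 0.181]   # hypothetical raw output-node values
probs = softmax(logits)
pred = argmax(probs)    # 0 = male, 1 = female

print("pseudo-probabilities:", ["%0.4f" % p for p in probs])
print("predicted class:", pred)

# two-node softmax for class 1 vs. sigmoid of the raw difference
print("softmax[1]       = %0.6f" % probs[1])
print("sigmoid(z1 - z0) = %0.6f" % sigmoid(logits[1] - logits[0]))
```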

The idea behind why using two output nodes instead of one might be useful is a bit tricky to explain. Briefly, if there are two output nodes, there are twice as many hidden-to-output weights. The neural network will be more difficult to train, but the trained network might have better predictive accuracy, maybe. That’s what I was looking at.
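To put numbers on "twice as many hidden-to-output weights", here is a small counting sketch. The 8-(10-10)-2 architecture is the one used in the demo; the 8-(10-10)-1 one-node counterpart is an assumption made for comparison.

```python
def count_params(layer_sizes):
    # weights plus biases for a fully connected network
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out + n_out
    return total

one_node = [8, 10, 10, 1]  # assumed one-node counterpart
two_node = [8, 10, 10, 2]  # architecture used in the demo

for name, sizes in (("one-node", one_node), ("two-node", two_node)):
    hid_to_out = sizes[-2] * sizes[-1]
    print("%s: hidden-to-output weights = %d, total parameters = %d" %
          (name, hid_to_out, count_params(sizes)))
```

Under these assumptions the hidden-to-output weights double from 10 to 20, but the total parameter count only rises from 211 to 222, so the extra capacity is small relative to the whole network.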

So I put together a demo using PyTorch. The bottom line is that the two-node output approach gave results similar to the one-node output approach. Because the two-node approach is a bit more complicated than the one-node approach, and the one-node approach is much more common than the two-node approach, there’s no compelling reason to use the two-node approach for binary classification.

Other than my few experiments, not much is known about using two output nodes for neural network binary prediction.



One of my favorite illustrators of the late 1950s and early 1960s is Edwin Georgi (1896-1964). He was studying to be an engineer at Princeton when World War I broke out in 1914. The U.S. entered the war in 1917, and Georgi joined the U.S. Army Air Corps as a pilot. Other than that, little is known about him.


Demo code.

# people_gender_two_node.py
# binary classification using 2-node softmax output
# PyTorch 2.1.2-CPU Anaconda3-2023.09-0  Python 3.11.5
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

class PeopleDataset(T.utils.data.Dataset):
  # sex   age   state  income  politics
  #  1 0  0.27  0 1 0  0.7610  0 0 1
  #  0 1  0.19  0 0 1  0.6550  1 0 0
  # sex: 10 = male, 01 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_data = np.loadtxt(src_file, usecols=range(0,10),
      delimiter=",", comments="#", dtype=np.float32) 

    self.x_data = T.tensor(all_data[:,2:10],
      dtype=T.float32).to(device)
    self.y_data = T.tensor(all_data[:,0:2],
      dtype=T.float32).to(device)  # 2D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx,:]  # idx row, all 8 cols
    sex = self.y_data[idx,:]    # idx row, 2 cols
    return preds, sex  # as a Tuple

# ---------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(8, 10)  # 8-(10-10)-2
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 2)

    T.nn.init.xavier_uniform_(self.hid1.weight) 
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight) 
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight) 
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    # z = T.sigmoid(self.oupt(z))
    z = T.nn.functional.softmax(self.oupt(z), dim=1)
    return z

# ---------------------------------------------------------

def metrics(model, ds, thresh=0.5):
  # thresh is unused: predicted class is the argmax of 2 outputs
  # note: N = total number of items = TP + FP + TN + FN
  # accuracy  = (TP + TN)  / N
  # precision = TP / (TP + FP)
  # recall    = TP / (TP + FN)
  # F1        = 2 / [(1 / precision) + (1 / recall)]

  tp = 0; tn = 0; fp = 0; fn = 0
  for i in range(len(ds)):
    inpts = ds[i][0].reshape(1,-1)    # item i is a (preds, sex) tuple
    targets = ds[i][1]        # float32  [1 0] or [0 1]
    target_idx = T.argmax(targets)  # 0 or 1
    with T.no_grad():
      preds = model(inpts)       # between 0.0 and 1.0

    pred_idx = T.argmax(preds).item()

    if target_idx == 1 and pred_idx == 1:    # TP
      tp += 1
    elif target_idx == 1 and pred_idx == 0:   # FN
      fn += 1
    elif target_idx == 0 and pred_idx == 0:   # TN
      tn += 1
    elif target_idx == 0 and pred_idx == 1:  # FP
      fp += 1

  N = tp + fp + tn + fn
  if N != len(ds):
    print("FATAL LOGIC ERROR in metrics()")

  accuracy = (tp + tn) / (N * 1.0)
  precision = (1.0 * tp) / (tp + fp)  # tp + fp != 0
  recall = (1.0 * tp) / (tp + fn)     # tp + fn != 0
  f1 = 2.0 / ((1.0 / precision) + (1.0 / recall))
  return (accuracy, precision, recall, f1)  # as a Tuple

# ---------------------------------------------------------

def main():
  # 0. get started
  print("\nPeople gender using PyTorch (two-node) ")
  T.manual_seed(1)
  np.random.seed(1)

  # 1. create Dataset and DataLoader objects
  print("\nCreating People train and test Datasets ")

  train_file = ".\\DataTwoNode\\people_train_two_node.txt"
  test_file = ".\\DataTwoNode\\people_test_two_node.txt"

  train_ds = PeopleDataset(train_file)  # 200 rows
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create neural network
  print("\nCreating 8-(10-10)-2 NN classifier \n")
  net = Net().to(device)
  net.train()  # set into training mode

  # 3. train network
  lrn_rate = 0.02
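  # loss_func = T.nn.BCELoss()  # binary cross entropy
  # loss_func = T.nn.MSELoss()
  # note: CrossEntropyLoss applies log-softmax internally and
  # expects raw logits, but forward() already applies softmax
  loss_func = T.nn.CrossEntropyLoss()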
  optimizer = T.optim.SGD(net.parameters(),
    lr=lrn_rate)
  max_epochs = 1000
  ep_log_interval = 200

  print("Loss function: " + str(loss_func))
  print("Optimizer: " + str(optimizer.__class__.__name__))
  print("Learn rate: " + "%0.3f" % lrn_rate)
  print("Batch size: " + str(bat_size))
  print("Max epochs: " + str(max_epochs))

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    epoch_loss = 0.0            # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]             # [bs,8]  inputs
      Y = batch[1]             # [bs,2]  targets
      oupt = net(X)            # [bs,2]  computeds 

      loss_val = loss_func(oupt, Y)   # a tensor
      epoch_loss += loss_val.item()  # accumulate
      optimizer.zero_grad() # reset all gradients
      loss_val.backward()   # compute new gradients
      optimizer.step()      # update all weights

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %8.4f" % \
        (epoch, epoch_loss))
  print("Done ")

# ---------------------------------------------------------

  # 4. evaluate model
  net.eval()
  metrics_train = metrics(net, train_ds, thresh=0.5)
  print("\nMetrics for train data: ")
  print("accuracy  = %0.4f " % metrics_train[0])
  print("precision = %0.4f " % metrics_train[1])
  print("recall    = %0.4f " % metrics_train[2])
  print("F1        = %0.4f " % metrics_train[3])

  metrics_test = metrics(net, test_ds, thresh=0.5)
  print("\nMetrics for test data: ")
  print("accuracy  = %0.4f " % metrics_test[0])
  print("precision = %0.4f " % metrics_test[1])
  print("recall    = %0.4f " % metrics_test[2])
  print("F1        = %0.4f " % metrics_test[3])

  # 5. save model
  print("\nSaving trained model state_dict ")
  net.eval()
  # path = ".\\Models\\people_gender_model.pt"
  # T.save(net.state_dict(), path)

  # 6. make a prediction 
  print("\nSetting age = 30  Oklahoma  $40,000  moderate ")
  inpt = np.array([[0.30, 0,0,1, 0.4000, 0,1,0]],
    dtype=np.float32)
  inpt = T.tensor(inpt, dtype=T.float32).to(device)

  net.eval()
  with T.no_grad():
    oupt = net(inpt)    # a Tensor
  print("\nComputed output: ")
  print(oupt)
  pred_sex = T.argmax(oupt).item()
  print("\nPredicted class (0 = male, 1 = female) ")
  print(pred_sex)

  print("\nEnd People binary classification demo ")

if __name__ == "__main__":
  main()
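One detail of the demo worth a closer look: forward() applies softmax, and CrossEntropyLoss applies its own log-softmax to whatever it receives, because it expects raw output values (logits). The plain-Python sketch below mimics that calculation for a single item to show the effect. The logit values are made up, and the sketch does not measure whether this changes the demo's accuracy.

```python
import math

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_from_logits(logits, target_probs):
    # one-item version of a logit-based cross entropy:
    # -sum(target * log(softmax(logits)))
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(target_probs, probs))

target = [1.0, 0.0]     # male, one-hot
logits = [6.0, -6.0]    # hypothetical confident, correct raw outputs

# loss when the loss function is given raw logits
loss_logits = cross_entropy_from_logits(logits, target)

# loss when the loss function is given softmax outputs (as in the demo)
loss_probs = cross_entropy_from_logits(softmax(logits), target)

# softmax outputs differ by at most 1.0, so the second loss cannot
# drop below log(1 + e^-1) no matter how confident the network is
floor = math.log(1.0 + math.exp(-1.0))

print("loss from raw logits     = %0.6f" % loss_logits)
print("loss from softmax output = %0.6f" % loss_probs)
print("lower bound in that case = %0.6f" % floor)
```

If this matters in practice, one option is to return self.oupt(z) directly from forward() during training and apply softmax only when displaying pseudo-probabilities; another is to keep softmax in forward() and use MSELoss. The argmax prediction is the same either way, since softmax preserves the ordering of the two values.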

Training data:

# people_train_two_node.txt
# sex (10 = male 01 = female) - dependent variable
# age,
# state (michigan nebraska oklahoma),
# income,
# politics type (conservative moderate liberal)
#
0, 1, 0.24, 1, 0, 0, 0.2950, 0, 0, 1
1, 0, 0.39, 0, 0, 1, 0.5120, 0, 1, 0
0, 1, 0.63, 0, 1, 0, 0.7580, 1, 0, 0
1, 0, 0.36, 1, 0, 0, 0.4450, 0, 1, 0
0, 1, 0.27, 0, 1, 0, 0.2860, 0, 0, 1
0, 1, 0.50, 0, 1, 0, 0.5650, 0, 1, 0
0, 1, 0.50, 0, 0, 1, 0.5500, 0, 1, 0
1, 0, 0.19, 0, 0, 1, 0.3270, 1, 0, 0
0, 1, 0.22, 0, 1, 0, 0.2770, 0, 1, 0
1, 0, 0.39, 0, 0, 1, 0.4710, 0, 0, 1
0, 1, 0.34, 1, 0, 0, 0.3940, 0, 1, 0
1, 0, 0.22, 1, 0, 0, 0.3350, 1, 0, 0
0, 1, 0.35, 0, 0, 1, 0.3520, 0, 0, 1
1, 0, 0.33, 0, 1, 0, 0.4640, 0, 1, 0
0, 1, 0.45, 0, 1, 0, 0.5410, 0, 1, 0
0, 1, 0.42, 0, 1, 0, 0.5070, 0, 1, 0
1, 0, 0.33, 0, 1, 0, 0.4680, 0, 1, 0
0, 1, 0.25, 0, 0, 1, 0.3000, 0, 1, 0
1, 0, 0.31, 0, 1, 0, 0.4640, 1, 0, 0
0, 1, 0.27, 1, 0, 0, 0.3250, 0, 0, 1
0, 1, 0.48, 1, 0, 0, 0.5400, 0, 1, 0
1, 0, 0.64, 0, 1, 0, 0.7130, 0, 0, 1
0, 1, 0.61, 0, 1, 0, 0.7240, 1, 0, 0
0, 1, 0.54, 0, 0, 1, 0.6100, 1, 0, 0
0, 1, 0.29, 1, 0, 0, 0.3630, 1, 0, 0
0, 1, 0.50, 0, 0, 1, 0.5500, 0, 1, 0
0, 1, 0.55, 0, 0, 1, 0.6250, 1, 0, 0
0, 1, 0.40, 1, 0, 0, 0.5240, 1, 0, 0
0, 1, 0.22, 1, 0, 0, 0.2360, 0, 0, 1
0, 1, 0.68, 0, 1, 0, 0.7840, 1, 0, 0
1, 0, 0.60, 1, 0, 0, 0.7170, 0, 0, 1
1, 0, 0.34, 0, 0, 1, 0.4650, 0, 1, 0
1, 0, 0.25, 0, 0, 1, 0.3710, 1, 0, 0
1, 0, 0.31, 0, 1, 0, 0.4890, 0, 1, 0
0, 1, 0.43, 0, 0, 1, 0.4800, 0, 1, 0
0, 1, 0.58, 0, 1, 0, 0.6540, 0, 0, 1
1, 0, 0.55, 0, 1, 0, 0.6070, 0, 0, 1
1, 0, 0.43, 0, 1, 0, 0.5110, 0, 1, 0
1, 0, 0.43, 0, 0, 1, 0.5320, 0, 1, 0
1, 0, 0.21, 1, 0, 0, 0.3720, 1, 0, 0
0, 1, 0.55, 0, 0, 1, 0.6460, 1, 0, 0
0, 1, 0.64, 0, 1, 0, 0.7480, 1, 0, 0
1, 0, 0.41, 1, 0, 0, 0.5880, 0, 1, 0
0, 1, 0.64, 0, 0, 1, 0.7270, 1, 0, 0
1, 0, 0.56, 0, 0, 1, 0.6660, 0, 0, 1
0, 1, 0.31, 0, 0, 1, 0.3600, 0, 1, 0
1, 0, 0.65, 0, 0, 1, 0.7010, 0, 0, 1
0, 1, 0.55, 0, 0, 1, 0.6430, 1, 0, 0
1, 0, 0.25, 1, 0, 0, 0.4030, 1, 0, 0
0, 1, 0.46, 0, 0, 1, 0.5100, 0, 1, 0
1, 0, 0.36, 1, 0, 0, 0.5350, 1, 0, 0
0, 1, 0.52, 0, 1, 0, 0.5810, 0, 1, 0
0, 1, 0.61, 0, 0, 1, 0.6790, 1, 0, 0
0, 1, 0.57, 0, 0, 1, 0.6570, 1, 0, 0
1, 0, 0.46, 0, 1, 0, 0.5260, 0, 1, 0
1, 0, 0.62, 1, 0, 0, 0.6680, 0, 0, 1
0, 1, 0.55, 0, 0, 1, 0.6270, 1, 0, 0
1, 0, 0.22, 0, 0, 1, 0.2770, 0, 1, 0
1, 0, 0.50, 1, 0, 0, 0.6290, 1, 0, 0
1, 0, 0.32, 0, 1, 0, 0.4180, 0, 1, 0
1, 0, 0.21, 0, 0, 1, 0.3560, 1, 0, 0
0, 1, 0.44, 0, 1, 0, 0.5200, 0, 1, 0
0, 1, 0.46, 0, 1, 0, 0.5170, 0, 1, 0
0, 1, 0.62, 0, 1, 0, 0.6970, 1, 0, 0
0, 1, 0.57, 0, 1, 0, 0.6640, 1, 0, 0
1, 0, 0.67, 0, 0, 1, 0.7580, 0, 0, 1
0, 1, 0.29, 1, 0, 0, 0.3430, 0, 0, 1
0, 1, 0.53, 1, 0, 0, 0.6010, 1, 0, 0
1, 0, 0.44, 1, 0, 0, 0.5480, 0, 1, 0
0, 1, 0.46, 0, 1, 0, 0.5230, 0, 1, 0
1, 0, 0.20, 0, 1, 0, 0.3010, 0, 1, 0
1, 0, 0.38, 1, 0, 0, 0.5350, 0, 1, 0
0, 1, 0.50, 0, 1, 0, 0.5860, 0, 1, 0
0, 1, 0.33, 0, 1, 0, 0.4250, 0, 1, 0
1, 0, 0.33, 0, 1, 0, 0.3930, 0, 1, 0
0, 1, 0.26, 0, 1, 0, 0.4040, 1, 0, 0
0, 1, 0.58, 1, 0, 0, 0.7070, 1, 0, 0
0, 1, 0.43, 0, 0, 1, 0.4800, 0, 1, 0
1, 0, 0.46, 1, 0, 0, 0.6440, 1, 0, 0
0, 1, 0.60, 1, 0, 0, 0.7170, 1, 0, 0
1, 0, 0.42, 1, 0, 0, 0.4890, 0, 1, 0
1, 0, 0.56, 0, 0, 1, 0.5640, 0, 0, 1
1, 0, 0.62, 0, 1, 0, 0.6630, 0, 0, 1
1, 0, 0.50, 1, 0, 0, 0.6480, 0, 1, 0
0, 1, 0.47, 0, 0, 1, 0.5200, 0, 1, 0
1, 0, 0.67, 0, 1, 0, 0.8040, 0, 0, 1
1, 0, 0.40, 0, 0, 1, 0.5040, 0, 1, 0
0, 1, 0.42, 0, 1, 0, 0.4840, 0, 1, 0
0, 1, 0.64, 1, 0, 0, 0.7200, 1, 0, 0
1, 0, 0.47, 1, 0, 0, 0.5870, 0, 0, 1
0, 1, 0.45, 0, 1, 0, 0.5280, 0, 1, 0
1, 0, 0.25, 0, 0, 1, 0.4090, 1, 0, 0
0, 1, 0.38, 1, 0, 0, 0.4840, 1, 0, 0
0, 1, 0.55, 0, 0, 1, 0.6000, 0, 1, 0
1, 0, 0.44, 1, 0, 0, 0.6060, 0, 1, 0
0, 1, 0.33, 1, 0, 0, 0.4100, 0, 1, 0
0, 1, 0.34, 0, 0, 1, 0.3900, 0, 1, 0
0, 1, 0.27, 0, 1, 0, 0.3370, 0, 0, 1
0, 1, 0.32, 0, 1, 0, 0.4070, 0, 1, 0
0, 1, 0.42, 0, 0, 1, 0.4700, 0, 1, 0
1, 0, 0.24, 0, 0, 1, 0.4030, 1, 0, 0
0, 1, 0.42, 0, 1, 0, 0.5030, 0, 1, 0
0, 1, 0.25, 0, 0, 1, 0.2800, 0, 0, 1
0, 1, 0.51, 0, 1, 0, 0.5800, 0, 1, 0
1, 0, 0.55, 0, 1, 0, 0.6350, 0, 0, 1
0, 1, 0.44, 1, 0, 0, 0.4780, 0, 0, 1
1, 0, 0.18, 1, 0, 0, 0.3980, 1, 0, 0
1, 0, 0.67, 0, 1, 0, 0.7160, 0, 0, 1
0, 1, 0.45, 0, 0, 1, 0.5000, 0, 1, 0
0, 1, 0.48, 1, 0, 0, 0.5580, 0, 1, 0
1, 0, 0.25, 0, 1, 0, 0.3900, 0, 1, 0
1, 0, 0.67, 1, 0, 0, 0.7830, 0, 1, 0
0, 1, 0.37, 0, 0, 1, 0.4200, 0, 1, 0
1, 0, 0.32, 1, 0, 0, 0.4270, 0, 1, 0
0, 1, 0.48, 1, 0, 0, 0.5700, 0, 1, 0
1, 0, 0.66, 0, 0, 1, 0.7500, 0, 0, 1
0, 1, 0.61, 1, 0, 0, 0.7000, 1, 0, 0
1, 0, 0.58, 0, 0, 1, 0.6890, 0, 1, 0
0, 1, 0.19, 1, 0, 0, 0.2400, 0, 0, 1
0, 1, 0.38, 0, 0, 1, 0.4300, 0, 1, 0
1, 0, 0.27, 1, 0, 0, 0.3640, 0, 1, 0
0, 1, 0.42, 1, 0, 0, 0.4800, 0, 1, 0
0, 1, 0.60, 1, 0, 0, 0.7130, 1, 0, 0
1, 0, 0.27, 0, 0, 1, 0.3480, 1, 0, 0
0, 1, 0.29, 0, 1, 0, 0.3710, 1, 0, 0
1, 0, 0.43, 1, 0, 0, 0.5670, 0, 1, 0
0, 1, 0.48, 1, 0, 0, 0.5670, 0, 1, 0
0, 1, 0.27, 0, 0, 1, 0.2940, 0, 0, 1
1, 0, 0.44, 1, 0, 0, 0.5520, 1, 0, 0
0, 1, 0.23, 0, 1, 0, 0.2630, 0, 0, 1
1, 0, 0.36, 0, 1, 0, 0.5300, 0, 0, 1
0, 1, 0.64, 0, 0, 1, 0.7250, 1, 0, 0
0, 1, 0.29, 0, 0, 1, 0.3000, 0, 0, 1
1, 0, 0.33, 1, 0, 0, 0.4930, 0, 1, 0
1, 0, 0.66, 0, 1, 0, 0.7500, 0, 0, 1
1, 0, 0.21, 0, 0, 1, 0.3430, 1, 0, 0
0, 1, 0.27, 1, 0, 0, 0.3270, 0, 0, 1
0, 1, 0.29, 1, 0, 0, 0.3180, 0, 0, 1
1, 0, 0.31, 1, 0, 0, 0.4860, 0, 1, 0
0, 1, 0.36, 0, 0, 1, 0.4100, 0, 1, 0
0, 1, 0.49, 0, 1, 0, 0.5570, 0, 1, 0
1, 0, 0.28, 1, 0, 0, 0.3840, 1, 0, 0
1, 0, 0.43, 0, 0, 1, 0.5660, 0, 1, 0
1, 0, 0.46, 0, 1, 0, 0.5880, 0, 1, 0
0, 1, 0.57, 1, 0, 0, 0.6980, 1, 0, 0
1, 0, 0.52, 0, 0, 1, 0.5940, 0, 1, 0
1, 0, 0.31, 0, 0, 1, 0.4350, 0, 1, 0
1, 0, 0.55, 1, 0, 0, 0.6200, 0, 0, 1
0, 1, 0.50, 1, 0, 0, 0.5640, 0, 1, 0
0, 1, 0.48, 0, 1, 0, 0.5590, 0, 1, 0
1, 0, 0.22, 0, 0, 1, 0.3450, 1, 0, 0
0, 1, 0.59, 0, 0, 1, 0.6670, 1, 0, 0
0, 1, 0.34, 1, 0, 0, 0.4280, 0, 0, 1
1, 0, 0.64, 1, 0, 0, 0.7720, 0, 0, 1
0, 1, 0.29, 0, 0, 1, 0.3350, 0, 0, 1
1, 0, 0.34, 0, 1, 0, 0.4320, 0, 1, 0
1, 0, 0.61, 1, 0, 0, 0.7500, 0, 0, 1
0, 1, 0.64, 0, 0, 1, 0.7110, 1, 0, 0
1, 0, 0.29, 1, 0, 0, 0.4130, 1, 0, 0
0, 1, 0.63, 0, 1, 0, 0.7060, 1, 0, 0
1, 0, 0.29, 0, 1, 0, 0.4000, 1, 0, 0
1, 0, 0.51, 1, 0, 0, 0.6270, 0, 1, 0
1, 0, 0.24, 0, 0, 1, 0.3770, 1, 0, 0
0, 1, 0.48, 0, 1, 0, 0.5750, 0, 1, 0
0, 1, 0.18, 1, 0, 0, 0.2740, 1, 0, 0
0, 1, 0.18, 1, 0, 0, 0.2030, 0, 0, 1
0, 1, 0.33, 0, 1, 0, 0.3820, 0, 0, 1
1, 0, 0.20, 0, 0, 1, 0.3480, 1, 0, 0
0, 1, 0.29, 0, 0, 1, 0.3300, 0, 0, 1
1, 0, 0.44, 0, 0, 1, 0.6300, 1, 0, 0
1, 0, 0.65, 0, 0, 1, 0.8180, 1, 0, 0
1, 0, 0.56, 1, 0, 0, 0.6370, 0, 0, 1
1, 0, 0.52, 0, 0, 1, 0.5840, 0, 1, 0
1, 0, 0.29, 0, 1, 0, 0.4860, 1, 0, 0
1, 0, 0.47, 0, 1, 0, 0.5890, 0, 1, 0
0, 1, 0.68, 1, 0, 0, 0.7260, 0, 0, 1
0, 1, 0.31, 0, 0, 1, 0.3600, 0, 1, 0
0, 1, 0.61, 0, 1, 0, 0.6250, 0, 0, 1
0, 1, 0.19, 0, 1, 0, 0.2150, 0, 0, 1
0, 1, 0.38, 0, 0, 1, 0.4300, 0, 1, 0
1, 0, 0.26, 1, 0, 0, 0.4230, 1, 0, 0
0, 1, 0.61, 0, 1, 0, 0.6740, 1, 0, 0
0, 1, 0.40, 1, 0, 0, 0.4650, 0, 1, 0
1, 0, 0.49, 1, 0, 0, 0.6520, 0, 1, 0
0, 1, 0.56, 1, 0, 0, 0.6750, 1, 0, 0
1, 0, 0.48, 0, 1, 0, 0.6600, 0, 1, 0
0, 1, 0.52, 1, 0, 0, 0.5630, 0, 0, 1
1, 0, 0.18, 1, 0, 0, 0.2980, 1, 0, 0
1, 0, 0.56, 0, 0, 1, 0.5930, 0, 0, 1
1, 0, 0.52, 0, 1, 0, 0.6440, 0, 1, 0
1, 0, 0.18, 0, 1, 0, 0.2860, 0, 1, 0
1, 0, 0.58, 1, 0, 0, 0.6620, 0, 0, 1
1, 0, 0.39, 0, 1, 0, 0.5510, 0, 1, 0
1, 0, 0.46, 1, 0, 0, 0.6290, 0, 1, 0
1, 0, 0.40, 0, 1, 0, 0.4620, 0, 1, 0
1, 0, 0.60, 1, 0, 0, 0.7270, 0, 0, 1
0, 1, 0.36, 0, 1, 0, 0.4070, 0, 0, 1
0, 1, 0.44, 1, 0, 0, 0.5230, 0, 1, 0
0, 1, 0.28, 1, 0, 0, 0.3130, 0, 0, 1
0, 1, 0.54, 0, 0, 1, 0.6260, 1, 0, 0

Test data:

# people_test_two_node.txt
#
1, 0, 0.51, 1, 0, 0, 0.6120, 0, 1, 0
1, 0, 0.32, 0, 1, 0, 0.4610, 0, 1, 0
0, 1, 0.55, 1, 0, 0, 0.6270, 1, 0, 0
0, 1, 0.25, 0, 0, 1, 0.2620, 0, 0, 1
0, 1, 0.33, 0, 0, 1, 0.3730, 0, 0, 1
1, 0, 0.29, 0, 1, 0, 0.4620, 1, 0, 0
0, 1, 0.65, 1, 0, 0, 0.7270, 1, 0, 0
1, 0, 0.43, 0, 1, 0, 0.5140, 0, 1, 0
1, 0, 0.54, 0, 1, 0, 0.6480, 0, 0, 1
0, 1, 0.61, 0, 1, 0, 0.7270, 1, 0, 0
0, 1, 0.52, 0, 1, 0, 0.6360, 1, 0, 0
0, 1, 0.30, 0, 1, 0, 0.3350, 0, 0, 1
0, 1, 0.29, 1, 0, 0, 0.3140, 0, 0, 1
1, 0, 0.47, 0, 0, 1, 0.5940, 0, 1, 0
0, 1, 0.39, 0, 1, 0, 0.4780, 0, 1, 0
0, 1, 0.47, 0, 0, 1, 0.5200, 0, 1, 0
1, 0, 0.49, 1, 0, 0, 0.5860, 0, 1, 0
1, 0, 0.63, 0, 0, 1, 0.6740, 0, 0, 1
1, 0, 0.30, 1, 0, 0, 0.3920, 1, 0, 0
1, 0, 0.61, 0, 0, 1, 0.6960, 0, 0, 1
1, 0, 0.47, 0, 0, 1, 0.5870, 0, 1, 0
0, 1, 0.30, 0, 0, 1, 0.3450, 0, 0, 1
1, 0, 0.51, 0, 0, 1, 0.5800, 0, 1, 0
1, 0, 0.24, 1, 0, 0, 0.3880, 0, 1, 0
1, 0, 0.49, 1, 0, 0, 0.6450, 0, 1, 0
0, 1, 0.66, 0, 0, 1, 0.7450, 1, 0, 0
1, 0, 0.65, 1, 0, 0, 0.7690, 1, 0, 0
1, 0, 0.46, 0, 1, 0, 0.5800, 1, 0, 0
1, 0, 0.45, 0, 0, 1, 0.5180, 0, 1, 0
1, 0, 0.47, 1, 0, 0, 0.6360, 1, 0, 0
1, 0, 0.29, 1, 0, 0, 0.4480, 1, 0, 0
1, 0, 0.57, 0, 0, 1, 0.6930, 0, 0, 1
1, 0, 0.20, 1, 0, 0, 0.2870, 0, 0, 1
1, 0, 0.35, 1, 0, 0, 0.4340, 0, 1, 0
1, 0, 0.61, 0, 0, 1, 0.6700, 0, 0, 1
1, 0, 0.31, 0, 0, 1, 0.3730, 0, 1, 0
0, 1, 0.18, 1, 0, 0, 0.2080, 0, 0, 1
0, 1, 0.26, 0, 0, 1, 0.2920, 0, 0, 1
1, 0, 0.28, 1, 0, 0, 0.3640, 0, 0, 1
1, 0, 0.59, 0, 0, 1, 0.6940, 0, 0, 1
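
To run the same data through a one-node network for comparison, the two-column sex label has to collapse to a single 0 or 1. Because male is 1 0 and female is 0 1, the second column is already the one-node label. A minimal sketch, with hypothetical function names:

```python
def two_node_to_one_node(line):
    # "1, 0, 0.51, ..." (male) -> "0, 0.51, ..."
    # "0, 1, 0.55, ..." (female) -> "1, 0.55, ..."
    fields = [f.strip() for f in line.split(",")]
    if len(fields) != 10:
        raise ValueError("expected 10 fields, got %d" % len(fields))
    if (fields[0], fields[1]) not in (("1", "0"), ("0", "1")):
        raise ValueError("bad sex encoding: " + line)
    # keep the second label column, drop the first
    return ", ".join([fields[1]] + fields[2:])

def convert_lines(lines):
    # skip blank lines and "#" comment lines, convert the rest
    result = []
    for line in lines:
        s = line.strip()
        if s == "" or s.startswith("#"):
            continue
        result.append(two_node_to_one_node(s))
    return result

sample = [
    "# people_test_two_node.txt",
    "1, 0, 0.51, 1, 0, 0, 0.6120, 0, 1, 0",
    "0, 1, 0.55, 1, 0, 0, 0.6270, 1, 0, 0",
]
for row in convert_lines(sample):
    print(row)
```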

1 Response to PyTorch Binary Classification Using Two Output Nodes Instead of One Output Node

  1. Thorsten Kleppe says:

    My intuition would be to use 2 outputs in the hope that better predictions can be made. However, for efficiency reasons, I would only use one output. But for multiclass classification, I could just expand 2 outputs to 3 outputs. So 2 to 1 for two output neurons.

    If there is only one output neuron, I would probably use tanh instead of sigmoid, my experience with sigmoid was not so good.
