How PyTorch Computes BCE Loss

I was exploring binary classification using the PyTorch neural network library. During training you must compute a loss value (also called error). By far the most common loss function for binary classification is binary cross entropy (BCE). The loss value determines how the weight values are updated during training, and you almost always print the BCE value during training so you can tell whether training is working.

The key code for training looks like:

import torch as T

# assumes net, train_ldr, and lrn_rate are defined elsewhere
loss_func = T.nn.BCELoss()  # binary cross entropy
optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

epoch_loss = 0.0  # must be initialized before accumulating
for (batch_idx, batch) in enumerate(train_ldr):
  X = batch['predictors']  # shape [10,4]: input values
  Y = batch['target']      # shape [10,1]: 0 or 1

  optimizer.zero_grad()
  oupt = net(X)                  # computed outputs, like 0.68
  loss_obj = loss_func(oupt, Y)  # a tensor
  epoch_loss += loss_obj.item()  # accumulate as a Python float
  loss_obj.backward()
  optimizer.step()
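The training loop depends on a network, a DataLoader, and a learning rate that are defined elsewhere. Here is a minimal, self-contained sketch for experimenting, assuming a hypothetical synthetic dataset (TinyDataset: 40 random items with 4 predictors) whose items are dictionaries with 'predictors' and 'target' keys, plus a hypothetical 4-8-1 network with a sigmoid output. PyTorch's default DataLoader collation stacks those dictionaries into batches shaped like the ones the loop expects.

```python
import torch as T

T.manual_seed(1)

class TinyDataset(T.utils.data.Dataset):
  # hypothetical synthetic data: n items, 4 random predictors, 0/1 target
  def __init__(self, n=40):
    self.x = T.rand(n, 4)
    self.y = (self.x.sum(dim=1, keepdim=True) > 2.0).float()  # shape [n,1]

  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    # each item is a dict, matching the keys used by the training loop
    return {'predictors': self.x[idx], 'target': self.y[idx]}

# hypothetical 4-8-1 network; the final sigmoid keeps outputs in (0, 1) for BCELoss
net = T.nn.Sequential(
  T.nn.Linear(4, 8), T.nn.Tanh(),
  T.nn.Linear(8, 1), T.nn.Sigmoid())
train_ldr = T.utils.data.DataLoader(TinyDataset(), batch_size=10, shuffle=True)
lrn_rate = 0.1

loss_func = T.nn.BCELoss()  # binary cross entropy
optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

for epoch in range(5):
  epoch_loss = 0.0  # sum of the four per-batch loss values in this epoch
  for (batch_idx, batch) in enumerate(train_ldr):
    X = batch['predictors']  # [10,4]
    Y = batch['target']      # [10,1]
    optimizer.zero_grad()
    oupt = net(X)
    loss_obj = loss_func(oupt, Y)
    epoch_loss += loss_obj.item()
    loss_obj.backward()
    optimizer.step()
  print(f"epoch {epoch}: accumulated loss = {epoch_loss:.4f}")
```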

This code assumes that batches of data are served up by a DataLoader object, which is a complex topic in itself. The code above computes the error for a batch of training items, but it's not clear exactly what's going on behind the scenes. For example, is the BCE loss value the total loss for all items in the batch, or is it the average loss per item? So I decided to code up a custom, from-scratch implementation of BCE loss. The idea is that if I could replicate the results of the built-in PyTorch BCELoss() function, I'd be sure I completely understood what's happening.
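A quick way to probe the total-versus-average question is the reduction parameter of BCELoss(), whose documented default is 'mean'; 'sum' returns the total and 'none' returns one loss per item. A minimal sketch with a hypothetical random batch of 10 outputs and targets:

```python
import torch as T

T.manual_seed(0)
# hypothetical batch: 10 computed outputs strictly between 0 and 1, and 0/1 targets
oupt = T.sigmoid(T.randn(10, 1))
Y = T.randint(0, 2, (10, 1)).float()

mean_loss = T.nn.BCELoss()(oupt, Y)                 # default: reduction='mean'
sum_loss = T.nn.BCELoss(reduction='sum')(oupt, Y)   # total over the batch
per_item = T.nn.BCELoss(reduction='none')(oupt, Y)  # one loss per item, shape [10,1]

# the mean equals the total divided by 10, up to floating-point rounding
print(mean_loss.item(), sum_loss.item() / 10)
```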

After a few hours of coding, I succeeded in writing a function called my_bce() that got the exact same results as the library BCELoss() function. My code is:

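import torch as T

def my_bce(model, batch):
  # mean binary cross entropy error, computed one item at a time
  sum_bce = 0.0  # avoids shadowing the built-in sum()
  inpts = batch['predictors']
  targets = batch['target']
  oupts = model(inpts)
  for i in range(len(inpts)):
    oupt = oupts[i]
    if targets[i] >= 0.5:  # target is 1 (avoids testing == 1.0)
      sum_bce += T.log(oupt)
    else:                  # target is 0
      sum_bce += T.log(1 - oupt)

  return -sum_bce / len(inpts)  # negate and average over the batch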
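The per-item if/else can also be written as a single vectorized expression, y * log(p) + (1 - y) * log(1 - p), because for a target of exactly 0 or 1 only one of the two terms survives. Below is a sketch of that variant (my_bce_vec is a hypothetical name, not from the original) checked against BCELoss() using a hypothetical random model and batch. Like my_bce(), it has no log(0) guard, and it also assumes outputs strictly between 0 and 1: with an output of exactly 0 or 1, the zero-weighted term becomes 0 * -inf, which is nan.

```python
import torch as T

def my_bce_vec(model, batch):
  # hypothetical vectorized variant of my_bce(); assumes targets are exactly 0 or 1
  # and outputs strictly between 0 and 1 (no log(0) guard)
  inpts = batch['predictors']
  targets = batch['target']
  oupts = model(inpts)
  # for each item, only one of the two terms survives
  terms = targets * T.log(oupts) + (1 - targets) * T.log(1 - oupts)
  return -T.mean(terms)  # negate and average over the batch

T.manual_seed(0)
# hypothetical model and batch, only for the comparison
model = T.nn.Sequential(T.nn.Linear(4, 1), T.nn.Sigmoid())
batch = {'predictors': T.rand(10, 4),
         'target': T.randint(0, 2, (10, 1)).float()}

with T.no_grad():
  mine = my_bce_vec(model, batch)
  lib = T.nn.BCELoss()(model(batch['predictors']), batch['target'])

print(mine.item(), lib.item())  # equal up to floating-point rounding
```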

The code is short but quite tricky. To understand it, you have to understand how binary cross entropy error is computed. Suppose you have three items with these inputs, targets, and computed outputs:

inputs              target  output
----------------------------------
(0.23, 0.34, 0.45)    1      0.60
(0.32, 0.87, 0.72)    0      0.20
(0.48, 0.92, 0.55)    1      0.70

The total binary cross entropy and average BCE are:

 -1 * log(0.60)      = 0.51
 -1 * log(1 - 0.20)  = 0.22
 -1 * log(0.70)      = 0.36
                    --------
           total BCE = 1.09

mean BCE = 1.09 / 3 = 0.3633
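The hand arithmetic can be checked in plain Python with math.log, which is the natural log. The unrounded mean is about 0.3635; the 0.3633 above comes from summing terms that were first rounded to two decimals.

```python
import math

targets = [1, 0, 1]
outputs = [0.60, 0.20, 0.70]

terms = []
for y, p in zip(targets, outputs):
  # target 1: -log(output); target 0: -log(1 - output)
  terms.append(-math.log(p) if y == 1 else -math.log(1 - p))

total = sum(terms)
mean = total / len(terms)
print([round(t, 4) for t in terms], round(total, 4), round(mean, 4))
# → [0.5108, 0.2231, 0.3567] 1.0906 0.3635
```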

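In words, for each item, if the target is 1, the binary cross entropy is minus the log of the computed output. If the target is 0, the binary cross entropy is minus the log of (1 - computed output). Here, log means the natural logarithm. The total binary cross entropy is the sum of the per-item terms, and the mean BCE is the total divided by the number of items. Because my_bce() returns the mean and matches BCELoss(), the built-in function returns the average loss over the batch by default, not the total.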
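The same rule can be written as one formula for n items, where y_i is the 0/1 target, p_i is the computed output, and log is the natural log. Because y_i is either 0 or 1, only one of the two bracketed terms contributes for each item; dropping the 1/n factor gives the total BCE instead of the mean.

```latex
\mathrm{BCE} = -\frac{1}{n} \sum_{i=1}^{n} \Big[ y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \Big]
```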

Notice that in my experimental my_bce() function, I don't check for the case of computing log(0), which is negative infinity.
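The PyTorch documentation says BCELoss() handles this case by clamping its log function outputs to be greater than or equal to -100, so the loss stays finite. A minimal sketch of a guarded variant under that assumption; the name my_bce_clamped and the saturated test values are hypothetical, and unlike my_bce() it takes outputs and targets directly rather than a model and batch. PyTorch may compute the log terms slightly differently internally, so agreement is expected only up to floating-point rounding.

```python
import torch as T

def my_bce_clamped(oupts, targets):
  # hypothetical guarded BCE: floor each log term at -100, as described for BCELoss
  log_p = T.clamp(T.log(oupts), min=-100.0)
  log_1mp = T.clamp(T.log(1 - oupts), min=-100.0)
  return -T.mean(targets * log_p + (1 - targets) * log_1mp)

oupts = T.tensor([[1.0], [0.5]])    # first output saturated at exactly 1.0
targets = T.tensor([[0.0], [1.0]])  # but its target is 0

unguarded = -T.log(1 - oupts[0])    # -log(0) is +inf without a guard
mine = my_bce_clamped(oupts, targets)
lib = T.nn.BCELoss()(oupts, targets)
print(unguarded.item(), mine.item(), lib.item())
```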

Anyway, dissecting the PyTorch binary cross entropy function was interesting and a lot of fun.



Three interesting illustrations by artist Casimir Lee. I like the bright colors and mixed media. Unlike neural networks, I don’t think there’s any way to quantify error for art.

This entry was posted in PyTorch.