I’m in the process of preparing PyTorch machine learning training classes for employees at my company. One of my standard examples is binary classification. I use a set of synthetic Employee data that looks like:
1 0.24 1 0 0 0.2950 0 0 1 0 0.39 0 0 1 0.5120 0 1 0 1 0.63 0 1 0 0.7580 1 0 0 0 0.36 1 0 0 0.4450 0 1 0 . . .
Each line of data represents an employee. The first column is sex (male = 0, female = 1) — the value to predict. The second through ninth columns are age (normalized by dividing by 100), city (anaheim = 100, boulder = 010, concord = 001), income (divided by 100,000), and job type (mgmt = 100, supp = 010, tech = 001).
The demo neural network has an 8-(10-10)-1 architecture with tanh() hidden activation and sigmoid() output activation. Training uses stochastic gradient descent with a fixed learning rate of 0.01 and a batch size of 10. The loss function is BCELoss() (binary cross entropy), but MSELoss() (mean squared error) could have been used instead.
One weird quirk occurred in the Dataset definition. I load all the data into a 2D array then peel off the class binary labels in column 0:
all_data = np.loadtxt(src_file, max_rows=num_rows, usecols=range(0,9), delimiter="\t", skiprows=0, comments="#", dtype=np.float32) # all 9 columns self.x_data = T.tensor(all_data[:,1:9], dtype=T.float32).to(device) self.y_data = T.tensor(all_data[:,0], dtype=T.float32).to(device).reshape(-1,1) # tricky
Without the reshape(-1,1) on the labels, I got a runtime error. In other words, for binary classification with BCELoss(), the labels had to be in a 2D array. But for multi-class classification with CrossEntropyLoss(), the labels must be in a 1D array. Luckily, I had seen this type of error many times before and so I was able to correct it quickly. But without my background of seeing many shape-related errors, it probably would have taken me a long time to track down and fix the error.

In my opinion, some movies are enhanced by the use of black and white rather than color. Black and white can convey a sense of dread and fear. One example is “Oliver Twist” (1948) based on the novel of the same name by Charles Dickens. Left: Fagin, Oliver and the Artful Dodger in Fagin’s lair in London. Right: Oliver in a London alley at night.
Demo code. Replace “lt”, “gte”, etc. with Boolean symbol operators — my weak blog editor chokes on symbols.
# employee_sex.py
# predict sex from age, city, income, job_type
# PyTorch 1.10.0-CPU Anaconda3-2020.02 Python 3.7.6
# Windows 10/11
import numpy as np
import torch as T
# Global compute device for the demo; everything is moved here via .to(device).
# CPU-only in this demo (applies to both Tensors and Modules).
device = T.device('cpu') # apply to Tensor or Module
# -----------------------------------------------------------
class EmployeeDataset(T.utils.data.Dataset):
  """In-memory Dataset for the synthetic Employee data.

  Each tab-separated row holds 9 columns, e.g.:
    # sex  age   city   income  job
    # 0    0.27  0 1 0  0.7610  0 0 1
    # 1    0.19  0 0 1  0.6550  1 0 0
  sex: 0 = male, 1 = female (the value to predict)
  city: anaheim, boulder, concord (one-hot)
  job: mgmt, supp, tech (one-hot)
  """

  def __init__(self, src_file, num_rows=None):
    raw = np.loadtxt(src_file, usecols=range(0, 9),
      delimiter="\t", comments="#", skiprows=0,
      max_rows=num_rows, dtype=np.float32)  # all 9 columns
    # predictors are columns 1..8
    self.x_data = T.tensor(raw[:, 1:9],
      dtype=T.float32).to(device)
    # labels must be 2D ([n,1]) for BCELoss() -- hence the reshape
    self.y_data = T.tensor(raw[:, 0],
      dtype=T.float32).to(device).reshape(-1, 1)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    # returns a (predictors, label) tuple for row(s) idx
    return (self.x_data[idx, :], self.y_data[idx, :])
# ----------------------------------------------------------
class Net(T.nn.Module):
  """8-(10-10)-1 binary classifier: tanh hidden layers, sigmoid output."""

  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(8, 10)   # 8-(10-10)-1
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 1)
    # explicit Glorot-uniform weights / zero biases, in layer order,
    # so results are reproducible given a fixed seed
    for layer in (self.hid1, self.hid2, self.oupt):
      T.nn.init.xavier_uniform_(layer.weight)
      T.nn.init.zeros_(layer.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    # sigmoid output pairs with BCELoss() during training
    return T.sigmoid(self.oupt(z))
# -----------------------------------------------------------
def accuracy(model, ds):
  """Compute classification accuracy, item by item.

  model : network whose output is a single value in [0.0, 1.0]
  ds    : an indexable Dataset of (predictors, label) tensor pairs,
          each label being [0.0] or [1.0]
  Returns the fraction of items classified correctly.
  The blog placeholders "lt"/"gte" are restored to < and >= here.
  """
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    inpts = ds[i][0]
    target = ds[i][1]  # float32 [0.0] or [1.0]
    with T.no_grad():
      oupt = model(inpts)
    # threshold at 0.5; avoids brittle exact test 'target == 1.0'
    if target < 0.5 and oupt < 0.5:  # .item() not needed
      n_correct += 1
    elif target >= 0.5 and oupt >= 0.5:
      n_correct += 1
    else:
      n_wrong += 1
  return (n_correct * 1.0) / (n_correct + n_wrong)
# ---------------------------------------------------------
def acc_coarse(model, ds):
  """Compute accuracy over the whole Dataset in one vectorized pass.

  Requires ds to support full-slice indexing: ds[:] yields
  (all predictors, all labels). Faster but coarser than the
  item-by-item accuracy(). The blog placeholder "gte" is >= here.
  """
  inpts = ds[:][0]    # all rows of predictors
  targets = ds[:][1]  # all targets, 0s and 1s
  with T.no_grad():
    oupts = model(inpts)  # all computed outputs
  pred_y = oupts >= 0.5   # Boolean tensor of predicted classes
  num_correct = T.sum(targets == pred_y)
  acc = (num_correct.item() * 1.0 / len(ds))  # ordinary scalar
  return acc
# ----------------------------------------------------------
def my_bce(model, batch):
  """Mean binary cross-entropy error for one batch (somewhat slow).

  batch : a (predictors, targets) pair of tensors
  Returns -mean(t*log(p) + (1-t)*log(1-p)) as a Tensor.
  NOTE(review): nothing here clamps p away from 0 or 1, so log(0)
  (-inf) is possible; the >= 0.5 test only avoids exact-equality
  comparisons on the float targets. Placeholder "gte" restored to >=.
  """
  total = 0.0  # renamed from 'sum' to avoid shadowing the builtin
  inpts = batch[0]
  targets = batch[1]
  with T.no_grad():
    oupts = model(inpts)
  for i in range(len(inpts)):
    oupt = oupts[i]
    if targets[i] >= 0.5:  # avoiding 'targets[i] == 1.0'
      total += T.log(oupt)
    else:
      total += T.log(1 - oupt)
  return -total / len(inpts)
# ----------------------------------------------------------
def main():
  """Run the full demo: load data, build the net, train, evaluate,
  save (commented out), and predict one unseen employee."""
  # 0. get started
  print("\nBegin employee gender using PyTorch ")
  T.manual_seed(1)
  np.random.seed(1)

  # 1. create Dataset and DataLoader objects
  print("\nCreating Employee train and test Datasets ")
  train_file = ".\\Data\\employee_train.txt"
  test_file = ".\\Data\\employee_test.txt"
  train_ds = EmployeeDataset(train_file)  # all 200 rows
  test_ds = EmployeeDataset(test_file)
  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)
  # test_ldr not used -- test accuracy is computed over test_ds directly

  # 2. create neural network
  print("Creating 8-(10-10)-1 sigmoid NN classifier ")
  net = Net().to(device)

  # 3. train network
  print("\nPreparing training")
  net.train()  # set training mode
  lrn_rate = 0.01
  loss_func = T.nn.BCELoss()  # binary cross entropy
  # loss_func = T.nn.MSELoss()  # alternative
  optimizer = T.optim.SGD(net.parameters(),
    lr=lrn_rate)
  max_epochs = 1000
  ep_log_interval = 100
  print("Loss function: " + str(loss_func))
  print("Optimizer: SGD")
  print("Learn rate: 0.01")
  print("Batch size: 10")
  print("Max epochs: " + str(max_epochs))

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    epoch_loss = 0.0  # accumulated over one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # [10,8] inputs (was mis-commented [10,4])
      Y = batch[1]  # [10,1] targets
      oupt = net(X)  # [10,1] computed outputs
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate for logging
      optimizer.zero_grad()  # reset all gradients
      loss_val.backward()    # compute all gradients
      optimizer.step()       # update all weights
    if epoch % ep_log_interval == 0:
      print("epoch = %4d | loss = %9.4f" % \
        (epoch, epoch_loss))
  print("Done ")

  # 4. evaluate model
  net.eval()
  acc_train = accuracy(net, train_ds)
  print("\nAccuracy on train data = %0.2f%%" % \
    (acc_train * 100))
  acc_test = acc_coarse(net, test_ds)
  print("Accuracy on test data = %0.2f%%" % \
    (acc_test * 100))

  # 5. save model (disabled in the demo)
  print("\nSaving trained model state_dict \n")
  path = ".\\Models\\employee_model.pt"
  # T.save(net.state_dict(), path)

  # 6. make a prediction for one unseen employee
  print("Predicting for 30 concord $40,000 tech")
  inpt = np.array([[0.30, 0,0,1, 0.40, 0,1,0]],
    dtype=np.float32)
  inpt = T.tensor(inpt, dtype=T.float32).to(device)
  with T.no_grad():
    oupt = net(inpt)  # a Tensor
  pred_prob = oupt.item()  # scalar in [0.0, 1.0]
  print("Computed output: ", end="")
  print("%0.4f" % pred_prob)
  # restore the blog placeholder "less_than" to the real operator
  if pred_prob < 0.5:
    print("Prediction = male")
  else:
    print("Prediction = female")
  print("\nEnd Employee binary classification demo")
# script entry point
if __name__ == "__main__":
  main()
Training data:
# employee_train.txt # # sex (0 = male, 1 = female), age / 100, # city (anaheim = 100, boulder = 010, concord = 001), # income / 100_000, # job type (mgmt = 100, supp = 010, tech = 001) # 1 0.24 1 0 0 0.2950 0 0 1 0 0.39 0 0 1 0.5120 0 1 0 1 0.63 0 1 0 0.7580 1 0 0 0 0.36 1 0 0 0.4450 0 1 0 1 0.27 0 1 0 0.2860 0 0 1 1 0.50 0 1 0 0.5650 0 1 0 1 0.50 0 0 1 0.5500 0 1 0 0 0.19 0 0 1 0.3270 1 0 0 1 0.22 0 1 0 0.2770 0 1 0 0 0.39 0 0 1 0.4710 0 0 1 1 0.34 1 0 0 0.3940 0 1 0 0 0.22 1 0 0 0.3350 1 0 0 1 0.35 0 0 1 0.3520 0 0 1 0 0.33 0 1 0 0.4640 0 1 0 1 0.45 0 1 0 0.5410 0 1 0 1 0.42 0 1 0 0.5070 0 1 0 0 0.33 0 1 0 0.4680 0 1 0 1 0.25 0 0 1 0.3000 0 1 0 0 0.31 0 1 0 0.4640 1 0 0 1 0.27 1 0 0 0.3250 0 0 1 1 0.48 1 0 0 0.5400 0 1 0 0 0.64 0 1 0 0.7130 0 0 1 1 0.61 0 1 0 0.7240 1 0 0 1 0.54 0 0 1 0.6100 1 0 0 1 0.29 1 0 0 0.3630 1 0 0 1 0.50 0 0 1 0.5500 0 1 0 1 0.55 0 0 1 0.6250 1 0 0 1 0.40 1 0 0 0.5240 1 0 0 1 0.22 1 0 0 0.2360 0 0 1 1 0.68 0 1 0 0.7840 1 0 0 0 0.60 1 0 0 0.7170 0 0 1 0 0.34 0 0 1 0.4650 0 1 0 0 0.25 0 0 1 0.3710 1 0 0 0 0.31 0 1 0 0.4890 0 1 0 1 0.43 0 0 1 0.4800 0 1 0 1 0.58 0 1 0 0.6540 0 0 1 0 0.55 0 1 0 0.6070 0 0 1 0 0.43 0 1 0 0.5110 0 1 0 0 0.43 0 0 1 0.5320 0 1 0 0 0.21 1 0 0 0.3720 1 0 0 1 0.55 0 0 1 0.6460 1 0 0 1 0.64 0 1 0 0.7480 1 0 0 0 0.41 1 0 0 0.5880 0 1 0 1 0.64 0 0 1 0.7270 1 0 0 0 0.56 0 0 1 0.6660 0 0 1 1 0.31 0 0 1 0.3600 0 1 0 0 0.65 0 0 1 0.7010 0 0 1 1 0.55 0 0 1 0.6430 1 0 0 0 0.25 1 0 0 0.4030 1 0 0 1 0.46 0 0 1 0.5100 0 1 0 0 0.36 1 0 0 0.5350 1 0 0 1 0.52 0 1 0 0.5810 0 1 0 1 0.61 0 0 1 0.6790 1 0 0 1 0.57 0 0 1 0.6570 1 0 0 0 0.46 0 1 0 0.5260 0 1 0 0 0.62 1 0 0 0.6680 0 0 1 1 0.55 0 0 1 0.6270 1 0 0 0 0.22 0 0 1 0.2770 0 1 0 0 0.50 1 0 0 0.6290 1 0 0 0 0.32 0 1 0 0.4180 0 1 0 0 0.21 0 0 1 0.3560 1 0 0 1 0.44 0 1 0 0.5200 0 1 0 1 0.46 0 1 0 0.5170 0 1 0 1 0.62 0 1 0 0.6970 1 0 0 1 0.57 0 1 0 0.6640 1 0 0 0 0.67 0 0 1 0.7580 0 0 1 1 0.29 1 0 0 0.3430 0 0 1 1 0.53 1 0 0 0.6010 1 0 0 0 0.44 1 0 0 0.5480 0 1 0 1 0.46 0 1 0 
0.5230 0 1 0 0 0.20 0 1 0 0.3010 0 1 0 0 0.38 1 0 0 0.5350 0 1 0 1 0.50 0 1 0 0.5860 0 1 0 1 0.33 0 1 0 0.4250 0 1 0 0 0.33 0 1 0 0.3930 0 1 0 1 0.26 0 1 0 0.4040 1 0 0 1 0.58 1 0 0 0.7070 1 0 0 1 0.43 0 0 1 0.4800 0 1 0 0 0.46 1 0 0 0.6440 1 0 0 1 0.60 1 0 0 0.7170 1 0 0 0 0.42 1 0 0 0.4890 0 1 0 0 0.56 0 0 1 0.5640 0 0 1 0 0.62 0 1 0 0.6630 0 0 1 0 0.50 1 0 0 0.6480 0 1 0 1 0.47 0 0 1 0.5200 0 1 0 0 0.67 0 1 0 0.8040 0 0 1 0 0.40 0 0 1 0.5040 0 1 0 1 0.42 0 1 0 0.4840 0 1 0 1 0.64 1 0 0 0.7200 1 0 0 0 0.47 1 0 0 0.5870 0 0 1 1 0.45 0 1 0 0.5280 0 1 0 0 0.25 0 0 1 0.4090 1 0 0 1 0.38 1 0 0 0.4840 1 0 0 1 0.55 0 0 1 0.6000 0 1 0 0 0.44 1 0 0 0.6060 0 1 0 1 0.33 1 0 0 0.4100 0 1 0 1 0.34 0 0 1 0.3900 0 1 0 1 0.27 0 1 0 0.3370 0 0 1 1 0.32 0 1 0 0.4070 0 1 0 1 0.42 0 0 1 0.4700 0 1 0 0 0.24 0 0 1 0.4030 1 0 0 1 0.42 0 1 0 0.5030 0 1 0 1 0.25 0 0 1 0.2800 0 0 1 1 0.51 0 1 0 0.5800 0 1 0 0 0.55 0 1 0 0.6350 0 0 1 1 0.44 1 0 0 0.4780 0 0 1 0 0.18 1 0 0 0.3980 1 0 0 0 0.67 0 1 0 0.7160 0 0 1 1 0.45 0 0 1 0.5000 0 1 0 1 0.48 1 0 0 0.5580 0 1 0 0 0.25 0 1 0 0.3900 0 1 0 0 0.67 1 0 0 0.7830 0 1 0 1 0.37 0 0 1 0.4200 0 1 0 0 0.32 1 0 0 0.4270 0 1 0 1 0.48 1 0 0 0.5700 0 1 0 0 0.66 0 0 1 0.7500 0 0 1 1 0.61 1 0 0 0.7000 1 0 0 0 0.58 0 0 1 0.6890 0 1 0 1 0.19 1 0 0 0.2400 0 0 1 1 0.38 0 0 1 0.4300 0 1 0 0 0.27 1 0 0 0.3640 0 1 0 1 0.42 1 0 0 0.4800 0 1 0 1 0.60 1 0 0 0.7130 1 0 0 0 0.27 0 0 1 0.3480 1 0 0 1 0.29 0 1 0 0.3710 1 0 0 0 0.43 1 0 0 0.5670 0 1 0 1 0.48 1 0 0 0.5670 0 1 0 1 0.27 0 0 1 0.2940 0 0 1 0 0.44 1 0 0 0.5520 1 0 0 1 0.23 0 1 0 0.2630 0 0 1 0 0.36 0 1 0 0.5300 0 0 1 1 0.64 0 0 1 0.7250 1 0 0 1 0.29 0 0 1 0.3000 0 0 1 0 0.33 1 0 0 0.4930 0 1 0 0 0.66 0 1 0 0.7500 0 0 1 0 0.21 0 0 1 0.3430 1 0 0 1 0.27 1 0 0 0.3270 0 0 1 1 0.29 1 0 0 0.3180 0 0 1 0 0.31 1 0 0 0.4860 0 1 0 1 0.36 0 0 1 0.4100 0 1 0 1 0.49 0 1 0 0.5570 0 1 0 0 0.28 1 0 0 0.3840 1 0 0 0 0.43 0 0 1 0.5660 0 1 0 0 0.46 0 1 0 0.5880 0 1 0 1 0.57 1 0 0 0.6980 1 0 0 0 0.52 0 0 1 0.5940 0 1 0 0 0.31 0 0 
1 0.4350 0 1 0 0 0.55 1 0 0 0.6200 0 0 1 1 0.50 1 0 0 0.5640 0 1 0 1 0.48 0 1 0 0.5590 0 1 0 0 0.22 0 0 1 0.3450 1 0 0 1 0.59 0 0 1 0.6670 1 0 0 1 0.34 1 0 0 0.4280 0 0 1 0 0.64 1 0 0 0.7720 0 0 1 1 0.29 0 0 1 0.3350 0 0 1 0 0.34 0 1 0 0.4320 0 1 0 0 0.61 1 0 0 0.7500 0 0 1 1 0.64 0 0 1 0.7110 1 0 0 0 0.29 1 0 0 0.4130 1 0 0 1 0.63 0 1 0 0.7060 1 0 0 0 0.29 0 1 0 0.4000 1 0 0 0 0.51 1 0 0 0.6270 0 1 0 0 0.24 0 0 1 0.3770 1 0 0 1 0.48 0 1 0 0.5750 0 1 0 1 0.18 1 0 0 0.2740 1 0 0 1 0.18 1 0 0 0.2030 0 0 1 1 0.33 0 1 0 0.3820 0 0 1 0 0.20 0 0 1 0.3480 1 0 0 1 0.29 0 0 1 0.3300 0 0 1 0 0.44 0 0 1 0.6300 1 0 0 0 0.65 0 0 1 0.8180 1 0 0 0 0.56 1 0 0 0.6370 0 0 1 0 0.52 0 0 1 0.5840 0 1 0 0 0.29 0 1 0 0.4860 1 0 0 0 0.47 0 1 0 0.5890 0 1 0 1 0.68 1 0 0 0.7260 0 0 1 1 0.31 0 0 1 0.3600 0 1 0 1 0.61 0 1 0 0.6250 0 0 1 1 0.19 0 1 0 0.2150 0 0 1 1 0.38 0 0 1 0.4300 0 1 0 0 0.26 1 0 0 0.4230 1 0 0 1 0.61 0 1 0 0.6740 1 0 0 1 0.40 1 0 0 0.4650 0 1 0 0 0.49 1 0 0 0.6520 0 1 0 1 0.56 1 0 0 0.6750 1 0 0 0 0.48 0 1 0 0.6600 0 1 0 1 0.52 1 0 0 0.5630 0 0 1 0 0.18 1 0 0 0.2980 1 0 0 0 0.56 0 0 1 0.5930 0 0 1 0 0.52 0 1 0 0.6440 0 1 0 0 0.18 0 1 0 0.2860 0 1 0 0 0.58 1 0 0 0.6620 0 0 1 0 0.39 0 1 0 0.5510 0 1 0 0 0.46 1 0 0 0.6290 0 1 0 0 0.40 0 1 0 0.4620 0 1 0 0 0.60 1 0 0 0.7270 0 0 1 1 0.36 0 1 0 0.4070 0 0 1 1 0.44 1 0 0 0.5230 0 1 0 1 0.28 1 0 0 0.3130 0 0 1 1 0.54 0 0 1 0.6260 1 0 0
Test data:
# employee_test.txt # 0 0.51 1 0 0 0.6120 0 1 0 0 0.32 0 1 0 0.4610 0 1 0 1 0.55 1 0 0 0.6270 1 0 0 1 0.25 0 0 1 0.2620 0 0 1 1 0.33 0 0 1 0.3730 0 0 1 0 0.29 0 1 0 0.4620 1 0 0 1 0.65 1 0 0 0.7270 1 0 0 0 0.43 0 1 0 0.5140 0 1 0 0 0.54 0 1 0 0.6480 0 0 1 1 0.61 0 1 0 0.7270 1 0 0 1 0.52 0 1 0 0.6360 1 0 0 1 0.30 0 1 0 0.3350 0 0 1 1 0.29 1 0 0 0.3140 0 0 1 0 0.47 0 0 1 0.5940 0 1 0 1 0.39 0 1 0 0.4780 0 1 0 1 0.47 0 0 1 0.5200 0 1 0 0 0.49 1 0 0 0.5860 0 1 0 0 0.63 0 0 1 0.6740 0 0 1 0 0.30 1 0 0 0.3920 1 0 0 0 0.61 0 0 1 0.6960 0 0 1 0 0.47 0 0 1 0.5870 0 1 0 1 0.30 0 0 1 0.3450 0 0 1 0 0.51 0 0 1 0.5800 0 1 0 0 0.24 1 0 0 0.3880 0 1 0 0 0.49 1 0 0 0.6450 0 1 0 1 0.66 0 0 1 0.7450 1 0 0 0 0.65 1 0 0 0.7690 1 0 0 0 0.46 0 1 0 0.5800 1 0 0 0 0.45 0 0 1 0.5180 0 1 0 0 0.47 1 0 0 0.6360 1 0 0 0 0.29 1 0 0 0.4480 1 0 0 0 0.57 0 0 1 0.6930 0 0 1 0 0.20 1 0 0 0.2870 0 0 1 0 0.35 1 0 0 0.4340 0 1 0 0 0.61 0 0 1 0.6700 0 0 1 0 0.31 0 0 1 0.3730 0 1 0 1 0.18 1 0 0 0.2080 0 0 1 1 0.26 0 0 1 0.2920 0 0 1 0 0.28 1 0 0 0.3640 0 0 1 0 0.59 0 0 1 0.6940 0 0 1

.NET Test Automation Recipes
Software Testing
SciPy Programming Succinctly
Keras Succinctly
R Programming
2026 Visual Studio Live
2025 Summer MLADS Conference
2026 DevIntersection Conference
2025 Machine Learning Week
2025 Ai4 Conference
2026 G2E Conference
2026 iSC West Conference
You must be logged in to post a comment.