Fashion-MNIST Classification Example Using PyTorch

The MNIST (Modified National Institute of Standards and Technology) dataset has 60,000 training and 10,000 test data items. Each item is a crude image of a handwritten digit from ‘0’ to ‘9’. Each image is 28 by 28 pixels, and each pixel is a grayscale value from 0 (white) to 255 (black).
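To make the layout concrete, here is a minimal NumPy sketch. It assumes each image arrives as a flat row of 784 integer values in row-major order, which is the layout the text files used later in this post follow. The function name `to_image` is hypothetical. It scales the pixels from 0-255 to 0.0-1.0 and reshapes them to 28 by 28, so the pixel at row r, column c comes from flat index r * 28 + c.

```python
import numpy as np

def to_image(flat_pixels):
  # flat_pixels: 784 values in 0-255, row-major order (assumed layout)
  arr = np.asarray(flat_pixels, dtype=np.float32)
  if arr.size != 784:
    raise ValueError("expected 784 pixel values")
  arr = arr / 255.0            # scale to [0.0, 1.0]
  return arr.reshape(28, 28)   # pixel (r, c) comes from flat index r*28 + c

flat = np.zeros(784, dtype=np.int64)
flat[2 * 28 + 5] = 255         # one fully dark pixel at row 2, col 5
img = to_image(flat)
print(img.shape)   # (28, 28)
print(img[2, 5])   # 1.0
```

The demo program below does the same divide-by-255 scaling, but keeps an extra channel dimension so each image is shaped 1 x 28 x 28 for the convolution layers.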

Years ago, in the early days of neural networks, creating a classifier for MNIST was a major challenge. But with today’s hardware (fast CPUs) and neural network library software (PyTorch, TensorFlow), MNIST is not a difficult problem. Even a basic CNN can score 98% accuracy.

The Fashion-MNIST dataset was created to be a drop-in replacement for MNIST that is more difficult. It has the same 60,000 training and 10,000 test items, and the same 28 by 28 grayscale image format. Instead of ten digits, Fashion-MNIST images are ten different articles of clothing: T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle-boot.
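In the data files, the ten classes are encoded as integer labels 0 through 9, in the order listed above. This is the same order as the `items` list in the demo program. A small sketch of the label-to-name lookup follows. The names `CLASS_NAMES` and `label_name` are illustrative only, not from any library.

```python
CLASS_NAMES = ["T-shirt", "Trouser", "Pullover", "Dress", "Coat",
  "Sandal", "Shirt", "Sneaker", "Bag", "Ankle-boot"]

def label_name(label):
  # label: integer class index, 0 to 9
  if not 0 <= label < len(CLASS_NAMES):
    raise ValueError("label must be in 0-9")
  return CLASS_NAMES[label]

print(label_name(0))   # T-shirt
print(label_name(9))   # Ankle-boot
```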

As part of an exploration of warm-start training, I decided to create a CNN classifier for Fashion-MNIST data using the same architecture design that I use for my basic MNIST CNN example. The architecture uses two convolution layers and three linear layers.
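The input size of the first linear layer (1024) follows from the convolution and pooling arithmetic. For a layer with no padding, output size = (input - kernel) / stride + 1, rounded down. A 5x5 convolution with stride 1 therefore shrinks each side by 4, and a 2x2 max-pool with stride 2 halves it. Here is a short sketch of that arithmetic, assuming the layer settings in the demo code below (5x5 kernels, 2x2 pooling, 64 output channels from the second convolution). The helper name `conv_out` is hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
  # output-size formula shared by Conv2d and MaxPool2d (no dilation)
  return (size + 2 * padding - kernel) // stride + 1

size = 28
size = conv_out(size, 5)             # conv1: 28 -> 24
size = conv_out(size, 2, stride=2)   # pool1: 24 -> 12
size = conv_out(size, 5)             # conv2: 12 -> 8
size = conv_out(size, 2, stride=2)   # pool2: 8 -> 4
n_flat = 64 * size * size            # 64 channels x 4 x 4
print(size, n_flat)                  # 4 1024
```

These are the same sizes that appear in the shape comments in the demo's forward() method.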

My first step was to get the Fashion-MNIST data into an easy-to-use form by converting the raw binary files to text files. I also used a small 1,000-item training subset and a 100-item test subset to keep things simpler and faster. See https://jamesmccaffreyblog.com/2022/08/22/converting-fashion-mnist-binary-files-to-text-files/.
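The linked post describes my actual conversion. As a rough independent sketch of the idea, the code below reads the raw IDX binary format (a big-endian header with magic number 2051 for image files and 2049 for label files, followed by unsigned bytes) and emits one tab-delimited line per item: 784 pixel values, then the label. That is the format the demo's FMNIST_Dataset class reads. The function names are hypothetical. The code takes the first n_items rather than a random subset, which is an assumption. Synthetic bytes stand in for the real files.

```python
import struct
import numpy as np

def read_idx_images(raw):
  # raw: bytes of an uncompressed IDX3 image file
  magic, n, rows, cols = struct.unpack(">IIII", raw[:16])  # big-endian
  if magic != 2051:
    raise ValueError("not an IDX3 image file")
  pixels = np.frombuffer(raw, dtype=np.uint8,
    count=n * rows * cols, offset=16)
  return pixels.reshape(n, rows * cols)   # one flat row per image

def read_idx_labels(raw):
  # raw: bytes of an uncompressed IDX1 label file
  magic, n = struct.unpack(">II", raw[:8])
  if magic != 2049:
    raise ValueError("not an IDX1 label file")
  return np.frombuffer(raw, dtype=np.uint8, count=n, offset=8)

def to_text_lines(images, labels, n_items):
  # one line per item: pixel values then label, tab-delimited
  lines = []
  for i in range(n_items):
    vals = [str(p) for p in images[i]] + [str(labels[i])]
    lines.append("\t".join(vals))
  return lines

# synthetic 2-item data standing in for the real files
img_raw = (struct.pack(">IIII", 2051, 2, 28, 28)
  + bytes(range(256)) * 6 + bytes(32))   # 1,568 pixel bytes
lbl_raw = struct.pack(">II", 2049, 2) + bytes([9, 0])
images = read_idx_images(img_raw)
labels = read_idx_labels(lbl_raw)
lines = to_text_lines(images, labels, 2)
print(images.shape)                # (2, 784)
print(len(lines[0].split("\t")))   # 785
print(lines[0].split("\t")[-1])    # 9
```

With the real data, the raw bytes could come from something like `gzip.open("train-images-idx3-ubyte.gz", "rb").read()`, assuming the standard download file names. The lines could then be written to a file with `"\n".join(lines)`.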

As expected, the neural network designed for MNIST data worked reasonably well for the Fashion-MNIST data. I now have a baseline set of accuracy results from training a Fashion-MNIST model from scratch. My next step will be to initialize the Fashion-MNIST network with the weight and bias values from a trained MNIST network, to see if that leads to a better Fashion-MNIST model.
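One way to do the warm start is with PyTorch's `state_dict()` and `load_state_dict()`. First save the trained MNIST network's state, then load it into a freshly created Fashion-MNIST network before training. Both datasets use 1 x 28 x 28 inputs and 10 output classes, so every parameter shape matches and the whole state dictionary can be copied. Below is a minimal sketch. To keep it short, it uses a small hypothetical stand-in network (`TinyNet`) rather than the full demo `Net`, and an in-memory buffer instead of a saved .pt file.

```python
import io
import torch as T

class TinyNet(T.nn.Module):
  # hypothetical stand-in with the same input/output shapes as Net
  def __init__(self):
    super().__init__()
    self.conv1 = T.nn.Conv2d(1, 8, 5)   # [bs,1,28,28] -> [bs,8,24,24]
    self.fc1 = T.nn.Linear(8 * 24 * 24, 10)

  def forward(self, x):
    z = T.relu(self.conv1(x))
    return self.fc1(z.reshape(-1, 8 * 24 * 24))

T.manual_seed(1)
mnist_net = TinyNet()                   # pretend this was trained on MNIST
buf = io.BytesIO()
T.save(mnist_net.state_dict(), buf)     # normally a .pt file on disk
buf.seek(0)

T.manual_seed(2)
fmnist_net = TinyNet()                  # fresh, different random weights
fmnist_net.load_state_dict(T.load(buf))   # warm start: copy values
same = all(T.equal(a, b) for (a, b) in
  zip(mnist_net.parameters(), fmnist_net.parameters()))
print(same)   # True
```

After loading, the Fashion-MNIST network would be trained as usual. The only difference is the starting point of the weights.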



In a stunning upset, the small (pop. 34 million) Central Asian country of Uzbekistan won the 2022 Chess Olympiad team title. There were 188 teams, but the strong Russian and Chinese teams did not participate because of politics. The top-seeded United States finished in fifth place. The small (pop. 24 million) West African country of Niger finished in last place.

Left: Uzbekistan (1st place in Chess Olympiad) traditional fashion. Center: Kyrgyzstan (64th place in Olympiad) fashion. Right: Kazakhstan (17th place in Olympiad) fashion.


Demo code. Replace “lt”, “gt”, “lte”, “gte” with Boolean operator symbols.

# fmnist_cnn.py
# PyTorch 1.10.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11

# reads Fashion-MNIST data from text file rather than using
# built-in black box Dataset from torchvision

import numpy as np
import matplotlib.pyplot as plt
import torch as T

device = T.device('cpu')

# -----------------------------------------------------------

class FMNIST_Dataset(T.utils.data.Dataset):
  # 784 tab-delim pixel values (0-255) then label (0-9)
  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(785),
      delimiter="\t", comments="#", dtype=np.float32)

    tmp_x = all_xy[:, 0:784]  # all rows, cols [0,783]
    tmp_x /= 255.0
    tmp_x = tmp_x.reshape(-1, 1, 28, 28)  # bs, chnls, 28x28
    tmp_y = all_xy[:, 784]    # 1-D required

    self.x_data = \
      T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = \
      T.tensor(tmp_y, dtype=T.int64).to(device) 

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    lbl = self.y_data[idx] 
    pixels = self.x_data[idx] 
    return (pixels, lbl)

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()  # older syntax; super().__init__() also works

    self.conv1 = T.nn.Conv2d(1, 32, 5)  # chnl-in, out, krnl
    self.conv2 = T.nn.Conv2d(32, 64, 5)
    self.fc1 = T.nn.Linear(1024, 512)   # [64*4*4, x]
    self.fc2 = T.nn.Linear(512, 256)
    self.fc3 = T.nn.Linear(256, 10)     # 10 classes
    self.pool1 = T.nn.MaxPool2d(2, 2)   # kernel, stride
    self.pool2 = T.nn.MaxPool2d(2, 2)
    self.drop1 = T.nn.Dropout(0.25)
    self.drop2 = T.nn.Dropout(0.50)
    # default weight and bias initialization
  
  def forward(self, x):
    # convolution phase         # x is [bs, 1, 28, 28]
    z = T.relu(self.conv1(x))   # Size([bs, 32, 24, 24])
    z = self.pool1(z)           # Size([bs, 32, 12, 12])
    z = self.drop1(z)
    z = T.relu(self.conv2(z))   # Size([bs, 64, 8, 8])
    z = self.pool2(z)           # Size([bs, 64, 4, 4])
   
    # neural network phase
    z = z.reshape(-1, 1024)     # Size([bs, 1024])
    z = T.relu(self.fc1(z))     # Size([bs, 512])
    z = self.drop2(z)
    z = T.relu(self.fc2(z))     # Size([bs, 256])
    z = self.fc3(z)             # Size([bs, 10])
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  ldr = T.utils.data.DataLoader(ds,
    batch_size=len(ds), shuffle=False)
  n_correct = 0
  for data in ldr:
    (pixels, labels) = data
    with T.no_grad():
      oupts = model(pixels)
    (_, predicteds) = T.max(oupts, 1)
    n_correct += (predicteds == labels).sum().item()

  acc = (n_correct * 1.0) / len(ds)
  return acc

# -----------------------------------------------------------

def main():
  # 0. setup
  print("\nBegin Fashion-MNIST with PyTorch CNN demo ")
  np.random.seed(1)
  T.manual_seed(1)

  # 1. create Dataset
  print("\nCreating 1000-item train Dataset from text file ")
  train_file = ".\\Data\\f-mnist_train_1000.txt"
  train_ds = FMNIST_Dataset(train_file)

  bat_size = 20
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating CNN network with 2 conv and 3 linear ")
  net = Net().to(device)
  
# -----------------------------------------------------------

  # 3. train model
  max_epochs = 30  # 100 gives better results
  ep_log_interval = 5
  lrn_rate = 0.05
  
  loss_func = T.nn.CrossEntropyLoss()  # applies log-softmax + NLL loss
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  # optimizer = T.optim.Adam(net.parameters(), lr=0.005)
  
  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("lrn_rate = %0.3f " % lrn_rate)
  print("max_epochs = %3d " % max_epochs)

  print("\nStarting training")
  net.train()  # set mode
  for epoch in range(0, max_epochs):
    ep_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      (X, y) = batch  # X = pixels, y = target labels
      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, y)  # a tensor
      ep_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute grads
      optimizer.step()     # update weights
    if epoch % ep_log_interval == 0:
      print("epoch = %4d   |  loss = %9.4f" % (epoch, ep_loss))
  print("Done ") 

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # all at once
  print("Accuracy on training data = %0.4f" % acc_train)

  test_file = ".\\Data\\f-mnist_test_100.txt"
  test_ds = FMNIST_Dataset(test_file)
  net.eval()
  acc_test = accuracy(net, test_ds)  # all at once
  print("Accuracy on test data = %0.4f" % acc_test)

# -----------------------------------------------------------

  # 5. use model
  # print("\nMaking prediction for fake image: ")
  # x = np.zeros(shape=(28,28), dtype=np.float32)
  # for row in range(28):
  #   for col in range(28):
  #     x[row][col] = np.random.randint(0,256)
  # x /= 255.0

  items = [ "T-shirt", "Trouser", "Pullover", "Dress",
    "Coat", "Sandal", "Shirt", "Sneaker", "Bag", 
    "Ankle_Boot" ]

  print("\nMaking prediction for first test image: ")
  x = test_ds[0][0].numpy()
  x = x.reshape(28,28)

  plt.tight_layout()
  plt.imshow(x, cmap=plt.get_cmap('gray_r'))
  plt.show()

  print("\nActual image label: ")
  y = test_ds[0][1].item()
  print(items[y])

  x = x.reshape(1, 1, 28, 28)  # 1 image, 1 channel
  x = T.tensor(x, dtype=T.float32).to(device)
  with T.no_grad():
    oupt = net(x)  # 10 logits like [[-0.12, 1.03, . . ]]
  pred_probs = T.softmax(oupt, dim=1)
  print("\nPrediction probabilities: ")
  np.set_printoptions(precision=4, suppress=True)
  print(pred_probs.numpy())


  am = T.argmax(oupt).item()  # int, 0 to 9
  print("\nPredicted item is \'" + items[am] + "\'")

# -----------------------------------------------------------

  # 6. save model
  print("\nSaving trained model state")
  fn = ".\\Models\\fmnist_model.pt"
  T.save(net.state_dict(), fn)  

  print("\nEnd Fashion-MNIST PyTorch CNN demo ")

if __name__ == "__main__":
  main()
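
The evaluation and prediction steps in the demo boil down to simple arithmetic on the network's 10 raw output values (logits). T.nn.CrossEntropyLoss applies log-softmax internally, so forward() returns logits with no softmax layer. Applying softmax afterwards only converts logits to probabilities. Because it preserves ordering, the argmax of the logits and of the probabilities is the same class. Here is a NumPy sketch of that arithmetic. The logit values are made up, and only 3 classes are used for brevity. The names `softmax` and `accuracy_np` are hypothetical.

```python
import numpy as np

def softmax(logits):
  # subtract each row's max for numerical stability; rows sum to 1
  z = logits - logits.max(axis=1, keepdims=True)
  e = np.exp(z)
  return e / e.sum(axis=1, keepdims=True)

def accuracy_np(logits, labels):
  # fraction of rows whose largest logit is at the true label index
  preds = np.argmax(logits, axis=1)
  return (preds == labels).mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0],
                   [1.0, 4.0, 0.0]])   # made-up logits: 3 items, 3 classes
labels = np.array([0, 2, 0])
probs = softmax(logits)
print(probs.sum(axis=1))          # each row sums to 1 (up to rounding)
print(np.argmax(probs, axis=1))   # [0 2 1]
print(accuracy_np(logits, labels))
```

The last line gives 2/3, because the third item's true label is 0 but its largest logit is at index 1.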

This entry was posted in PyTorch.