Yet Another PyTorch ResNet Example

A ResNet (“residual network”) is a deep neural network architecture, built from blocks with skip connections, that is most often used for image classification. The architecture is quite complex. I don’t work with image data very much so I’m not too familiar with ResNet. But because ResNet architecture networks give state of the art results, I figured I should explore and make sure I have a reasonable understanding of ResNet.

I found an excellent example buried in the PyTorch documentation at pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/. That example is itself based on the ResNet implementation in the TorchVision library source code (torchvision/models/resnet.py).


I only trained the demo for 3 epochs. Using 80 epochs, the trained model scores about 89% accuracy which is quite good for the CIFAR-10 dataset.
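
For an 80-epoch run, the demo code below divides the learning rate by 3 after every 20 epochs. Here is a minimal sketch of just that schedule, with no training involved. The list name `lr_used` is made up for illustration; the other values mirror the demo.

```python
# learning rate schedule only: no model, no data
num_epochs = 80
lrn_rate = 0.001
curr_lr = lrn_rate
lr_used = []  # learning rate in effect during each epoch

for epoch in range(num_epochs):
  lr_used.append(curr_lr)
  # same rule as the demo: decay every 20 epochs
  if (epoch+1) % 20 == 0:
    curr_lr /= 3.0

for e in (1, 21, 41, 61):
  print("epoch %2d  lr = %0.6f" % (e, lr_used[e-1]))
```

By the last 20 epochs the learning rate has dropped to 0.001 / 27, or roughly 0.000037.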

It’s no secret that the best way to learn about any computer science topic is to get a demo up and running, and then make changes to the demo to see the effect of each change. So I did that.

The demo trains a model to classify the CIFAR-10 image dataset. There are 50,000 training images and 10,000 test images. Each image is 32 x 32 pixels. Because the images are in color, there are three channels (red, green, blue). Each channel-pixel value is an integer between 0 and 255. Therefore, each image is represented by 32 * 32 * 3 = 3,072 values between 0 and 255. There are 10 classes: plane (class 0), car, bird, cat, deer, dog, frog, horse, ship, truck (class 9).
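
As a quick check on the arithmetic above, here is a minimal sketch that uses a randomly generated stand-in image rather than real CIFAR-10 data. The names `fake_img` and `class_names` are made up for illustration. The scaling and channel reordering are meant to mirror what transforms.ToTensor() does in the demo.

```python
import numpy as np

# stand-in for one CIFAR-10 image: 32 x 32 pixels, 3 color channels
# (random values, not real data)
rng = np.random.default_rng(1)
fake_img = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

n_values = fake_img.size  # 32 * 32 * 3
print(n_values)           # 3072

# scale to [0.0, 1.0] and reorder to channels-first (3, 32, 32),
# the layout that PyTorch convolution layers expect
scaled = (fake_img.astype(np.float32) / 255.0).transpose(2, 0, 1)
print(scaled.shape)       # (3, 32, 32)

# class index to class name, in the order listed above
class_names = ["plane", "car", "bird", "cat", "deer",
  "dog", "frog", "horse", "ship", "truck"]
print(class_names[0], class_names[9])  # plane truck
```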

I used the built-in CIFAR-10 dataset from the TorchVision library. This is kind of a cheat. In a non-demo scenario I’d fetch raw CIFAR-10 data from a text file. See https://jamesmccaffreyblog.com/2022/03/10/fetching-cifar-10-data-and-saving-as-a-text-file/.

The demo implements a toy ResNet in the sense that the network/model is a small scaled-down version that has only 3 stages of residual blocks (two blocks per stage, two convolution layers per block) after an initial convolution layer. Perhaps the most commonly used version of ResNet is called ResNet-50 because it has 50 layers. Here’s a diagram of ResNet-50 I found for an input color image of size 224 x 224.
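
To see why the demo code below ends with AvgPool2d(8) and Linear(64, num_classes), it helps to trace the image size through the network. This is a sketch that only does the size arithmetic, using the standard convolution output-size formula; the helper name `conv_out` and the list `trace` are made up for illustration. It assumes a 32 x 32 input and the 3x3 kernel, padding 1 settings used by the demo.

```python
def conv_out(n, kernel=3, stride=1, padding=1):
  # standard convolution output-size formula
  return (n + 2 * padding - kernel) // stride + 1

size = 32              # CIFAR-10 images are 32 x 32
size = conv_out(size)  # initial 3x3 convolution, stride 1
trace = [("conv", 16, size)]

# (name, output channels, stride of the first block in the stage)
stages = [("layer1", 16, 1), ("layer2", 32, 2), ("layer3", 64, 2)]
for name, channels, stride in stages:
  # only the first convolution of a stage can change the size
  size = conv_out(size, stride=stride)
  trace.append((name, channels, size))

size = size // 8       # AvgPool2d(8) on an 8 x 8 map
trace.append(("avg_pool", 64, size))

for name, channels, s in trace:
  print("%-8s -> %2d channels, %2d x %2d" % (name, channels, s, s))

n_features = 64 * size * size  # input width of the final Linear layer
print(n_features)              # 64
```

The two stride-2 stages halve the image from 32 x 32 to 16 x 16 and then to 8 x 8, so the 8 x 8 average pool leaves one value per channel.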

My primary impression is that ResNet architecture networks are extremely complex. There is still quite a bit that I don’t fully understand, especially about how one layer is projected to the next. But, as they say, every journey begins with a single step. My second impression is a vague feeling of unease, in the sense that ResNet networks have so many design alternatives and hyperparameters that it’s impossible to explore even a tiny fraction of the alternatives.

But my third impression was that ResNet systems, like all neural networks, are incredibly fascinating.
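
On the question of how one layer is projected to the next: when a residual block changes the number of channels or the image size, the input x can no longer be added directly to the block output, so the skip path has to reshape x first. The sketch below isolates that one idea using plain PyTorch layers. It follows the demo in using a 3x3 stride-2 convolution plus batch normalization on the skip path (the original ResNet paper describes 1x1 convolutions for this). The names `main` and `shortcut` are made up for illustration.

```python
import torch as T

T.manual_seed(1)
x = T.randn(1, 16, 32, 32)  # batch of one: 16 channels, 32 x 32

# main path: two 3x3 convolutions, the first with stride 2
main = T.nn.Sequential(
  T.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
  T.nn.BatchNorm2d(32),
  T.nn.ReLU(),
  T.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
  T.nn.BatchNorm2d(32))

# skip path: project x to the same shape as the main path output
shortcut = T.nn.Sequential(
  T.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
  T.nn.BatchNorm2d(32))

fx = main(x)      # F(x), shape (1, 32, 16, 16)
sx = shortcut(x)  # projected x, same shape as F(x)
z = T.nn.functional.relu(fx + sx)

print(x.shape)
print(fx.shape)
print(z.shape)
```

Without the projection, adding a (1, 16, 32, 32) tensor to a (1, 32, 16, 16) tensor would fail with a shape mismatch error.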



Three results from an Internet image search for “projection image”. I think images like these would be very difficult for a ResNet system to deal with.


Demo code:

# cifar_resnet.py

# PyTorch 1.10.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

# a refactor of:
# pytorch-tutorial.readthedocs.io/en/latest/tutorial/
#   chapter03_intermediate/3_2_2_cnn_resnet_cifar10/
# which is a refactor of:
# github.com/pytorch/vision/blob/master/torchvision/
#   models/resnet.py
# which implements:
# https://arxiv.org/pdf/1512.03385.pdf

import numpy as np
import torch as T
import torchvision as tv
import torchvision.transforms as transforms

device = T.device('cpu')

# -----------------------------------------------------------

# 3x3 kernel convolution with no bias
def conv3x3(in_channels, out_channels, stride=1):
  return T.nn.Conv2d(in_channels, out_channels,
    kernel_size=3, stride=stride, padding=1, bias=False)

# -----------------------------------------------------------

class ResidualBlock(T.nn.Module):
  def __init__(self, in_channels, out_channels, stride=1,
    downsample=None):
    super(ResidualBlock, self).__init__()
    self.conv1 = conv3x3(in_channels, out_channels, stride)
    self.bn1 = T.nn.BatchNorm2d(out_channels)
    # self.relu = T.nn.ReLU(inplace=True)  # huh?
    self.conv2 = conv3x3(out_channels, out_channels)
    self.bn2 = T.nn.BatchNorm2d(out_channels)
    self.downsample = downsample

  def forward(self, x):
    residual = x
    z = self.conv1(x)
    z = self.bn1(z)
    z = T.nn.functional.relu(z)
    z = self.conv2(z)
    z = self.bn2(z)
    if self.downsample is not None:
      residual = self.downsample(x)
    z += residual
    z = T.nn.functional.relu(z)
    return z

# -----------------------------------------------------------

class ResNet(T.nn.Module):
  def __init__(self, block, layers, num_classes=10):
    super(ResNet, self).__init__()
    self.in_channels = 16
    self.conv = conv3x3(3, 16)
    self.bn = T.nn.BatchNorm2d(16)
    self.layer1 = self.make_layer(block, 16, layers[0])
    self.layer2 = self.make_layer(block, 32, layers[1], 2)
    self.layer3 = self.make_layer(block, 64, layers[2], 2)
    self.avg_pool = T.nn.AvgPool2d(8)
    self.fc = T.nn.Linear(64, num_classes)

  def make_layer(self, block, out_channels, blocks, stride=1):
    downsample = None
    if (stride != 1) or (self.in_channels != out_channels):
      downsample = T.nn.Sequential(
        conv3x3(self.in_channels, out_channels, stride=stride),
        T.nn.BatchNorm2d(out_channels))
    layers = []
    layers.append(block(self.in_channels, out_channels,
      stride, downsample))
    self.in_channels = out_channels
    for i in range(1, blocks):
      layers.append(block(out_channels, out_channels))
    return T.nn.Sequential(*layers)

  def forward(self, x):
    z = self.conv(x)
    z = self.bn(z)
    z = T.nn.functional.relu(z)
    z = self.layer1(z)
    z = self.layer2(z)
    z = self.layer3(z)
    z = self.avg_pool(z)
    z = z.view(z.size(0), -1)
    z = self.fc(z)
    return z

# -----------------------------------------------------------

def accuracy(model, data_ldr):
  model.eval()
  n_correct = 0; n_total = 0
  for images, labels in data_ldr:
    images = images.to(device)
    labels = labels.to(device)
    with T.no_grad():
      outputs = model(images)
    _, predicted = T.max(outputs, 1)  # index of largest logit
    n_total += labels.size(0)
    n_correct += (predicted == labels).sum().item()
  model.train()  
  return (n_correct * 1.0) / n_total

# -----------------------------------------------------------

def update_lr(optimizer, lr):    
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr

# -----------------------------------------------------------

def main():
  # 0. get ready
  print("\nBegin CIFAR-10 ResNet demo ")
  np.random.seed(1)
  T.manual_seed(1)

  # 1. get train and test data
  print("\nFetching torchvision CIFAR-10 train and test data ")
  transform_set = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

  train_ds = tv.datasets.CIFAR10(root=".\\cifar_tv_data\\",
    train=True, transform=transform_set, download=True)
  test_ds = tv.datasets.CIFAR10(root=".\\cifar_tv_data\\",
    train=False, transform=transforms.ToTensor())

  train_ldr = T.utils.data.DataLoader(dataset=train_ds,
    batch_size=100, shuffle=True)
  test_ldr = T.utils.data.DataLoader(dataset=test_ds,
    batch_size=100, shuffle=False)
  
  # 2. create model
  print("\nCreating ResNet model ")
  model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

# -----------------------------------------------------------

  # 3. train model
  print("\nStarting training ")
  num_epochs = 3  # need about 80 epochs for a good result
  lrn_rate = 0.001
  loss_func = T.nn.CrossEntropyLoss()
  optimizer = T.optim.Adam(model.parameters(), lr=lrn_rate)
  num_batches = len(train_ldr)  # 50_000/100 = 500
  curr_lr = lrn_rate

  for epoch in range(num_epochs):
    for bix, (images, labels) in enumerate(train_ldr):
      images = images.to(device)
      labels = labels.to(device)

      oupts = model(images)
      loss_val = loss_func(oupts, labels)
      optimizer.zero_grad()
      loss_val.backward()
      optimizer.step()

      if (bix+1) % 100 == 0:  # 100 batches = 10,000 images
        acc_train = accuracy(model, train_ldr)
        print("epoch [%2d / %2d]  |  " % \
          (epoch+1, num_epochs), end="")
        print("batch [%2d / %2d]  |  " % \
          (bix+1, num_batches), end="")
        print("loss = %0.4f  |  " % loss_val.item(), end="")
        print("acc = %0.4f " % acc_train)

    # decay learning rate every 20 epochs
    if (epoch+1) % 20 == 0:
      curr_lr /= 3.0
      update_lr(optimizer, curr_lr)
  print("Done ")

  # 4. evaluate model
  print("\nComputing model accuracy ")
  acc_test = accuracy(model, test_ldr)
  print("Accuracy on test data = %0.4f" % acc_test)

  # 5. TODO: save model
  # 6. TODO: use model to make prediction

  print("\nEnd CIFAR-10 ResNet demo ")

if __name__ == "__main__":
  main()
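
The two TODO items at the end of main() could be handled along these lines. This is a sketch under stated assumptions, not part of the demo: it uses a tiny stand-in model instead of the trained ResNet, a random tensor instead of a real image, and a made-up file name in the system temp directory. The state-dictionary save and load calls work the same way for any T.nn.Module.

```python
import os
import tempfile
import torch as T

T.manual_seed(1)

def make_model():
  # stand-in for the trained ResNet: maps a (N, 3, 32, 32)
  # batch to (N, 10) logits
  return T.nn.Sequential(T.nn.Flatten(),
    T.nn.Linear(3 * 32 * 32, 10))

model = make_model()

# 5. save model (state dictionary approach)
path = os.path.join(tempfile.gettempdir(),
  "cifar_resnet_sketch.pt")  # hypothetical file name
T.save(model.state_dict(), path)

# later: rebuild the same architecture, then load the weights
model2 = make_model()
model2.load_state_dict(T.load(path))
model2.eval()

# 6. use model to make a prediction for one (random) image
class_names = ["plane", "car", "bird", "cat", "deer",
  "dog", "frog", "horse", "ship", "truck"]
x = T.rand(1, 3, 32, 32)  # values in [0, 1), like ToTensor output
with T.no_grad():
  logits = model2(x)
  probs = T.softmax(logits, dim=1)
pred = T.argmax(probs, dim=1).item()
print("predicted class = %s  (pseudo-prob = %0.4f)" % \
  (class_names[pred], probs[0][pred].item()))
```

Calling eval() before predicting matters for the real ResNet because its batch normalization layers behave differently in training mode.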

1 Response to Yet Another PyTorch ResNet Example

  1. Thorsten Kleppe says:

    Once again an incredibly fascinating blog post and great how honestly you deal with the ultra complicated topic.

    The essential ResNet idea points out this picture:
    https://media.geeksforgeeks.org/wp-content/uploads/20200424011510/Residual-Block.PNG

    The trick with ResNets is allegedly relatively simple: by keeping the input and output dimensions the same, the input can be added to the output of the convolutions just before the final ReLU non-linearity: z = F(x) + x

    Which leads to padding when it comes to the question how to keep the input and output dimensions equal. The traditional ResNets also use batch-norm and dropout.

    But as best as I know, the core is to add the input channel to the output convolution channel. This has allowed the creators of ResNet to build deeper networks with better results, unlike the predecessor CNN architectures which got worse with more layers.

    Hopefully I have understood something right so that I don’t talk nonsense here. The development seems currently highly in motion, after ResNet and ResNext are ConvNext and others, and I would claim there are still some more good ideas that should be considered in this context, e.g. CoordConv.

    The best trick is probably to keep your mind sharp and put this complicated stuff to the test, as you always do so well.

    A nice thought, the paper: “Pooling is neither necessary nor sufficient for appropriate deformation stability in CNNs”

    In abstract the authors say:
    “Together, these findings provide new insights into the role of interleaved pooling and deformation invariance in CNNs, and demonstrate the importance of rigorous empirical testing of even our most basic assumptions about the working of neural networks”

    Applying the trick by doing pooling during the convolution with a stride of 2, I was able to build much more usable CNNs that are not busy all day with calculations. 99,2% on MNIST test with simple training in 5 minutes without GPU armies seems nice. However, a ResNet is still far from what is possible for me, but maybe that will change. 🙂
