Example of a PyTorch Multi-Class Classifier Using a Transformer

I’d been experimenting with the idea of using a Transformer module as the core of a multi-class classifier. The idea is rather weird because Transformer systems were designed to accept sequential information, such as a sequence of words, where order matters. After many weeks, I finally got a successful demo up and running.

The key to the Transformer-based classifier was a helper module that creates a numeric equivalent of the standard Embedding layer used for NLP problems. In NLP, each word/token in the input sequence is an integer, like “the” = 5, “boy” = 678, etc. Each integer is mapped to a vector, like 5 = [0.123, -9.876, . . . ]. The number of values in the vector, called the embedding dimension, is usually about 100 or so.
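For comparison, a standard NLP embedding is just a table lookup: each integer token indexes a row of the table. A minimal sketch (the token IDs and dimensions are made up for illustration):

```python
import numpy as np

# hypothetical vocabulary: "the" = 5, "boy" = 678, etc.
# embedding table: one row of embed_dim values per token ID
vocab_size, embed_dim = 1000, 4  # tiny dims for illustration
rng = np.random.default_rng(0)
embed_table = rng.normal(size=(vocab_size, embed_dim))

tokens = np.array([5, 678])    # "the boy"
vectors = embed_table[tokens]  # lookup: each integer indexes a row
print(vectors.shape)           # (2, 4) = seq_len x embed_dim
```

This lookup only works because the inputs are integers, which is exactly why it can't be used directly on floating point predictor values.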

A standard Embedding layer is implemented as a lookup table where the integer acts as an index. But for multi-class classification, all the inputs are floating point values, so I needed to implement a fairly complex PyTorch module that I named SkipLinear because it’s like a neural Linear layer that’s not fully connected — some of the connections/weights are skipped.
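Another way to think about SkipLinear: it behaves like a fully connected Linear layer whose weight matrix has been masked to a block-diagonal pattern, so each scalar input connects only to its own small group of outputs. Here is a rough NumPy sketch of that idea (my illustration, not the module from the demo):

```python
import numpy as np

n_in, n_out = 6, 24  # 6 inputs, pseudo-embedding dim = 4
n = n_out // n_in    # 4 output nodes per input

# mask: output node j is connected to input i only if j // n == i
mask = np.zeros((n_out, n_in))
for i in range(n_in):
  mask[i*n:(i+1)*n, i] = 1.0

rng = np.random.default_rng(1)
W = rng.normal(size=(n_out, n_in)) * mask  # skipped weights = 0
b = np.zeros(n_out)

x = rng.normal(size=(2, n_in))  # batch of 2 input vectors
z = x @ W.T + b                 # shape (2, 24)
print(int(mask.sum()))          # 24 live weights vs 144 in a full layer
```

The demo's SkipLinear gets the same connectivity by giving each input its own tiny 1-to-4 sub-module instead of masking a big weight matrix.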

I used one of my standard synthetic datasets for my demo. The data looks like:

 1   0.24   1  0  0   0.2950   2
-1   0.39   0  0  1   0.5120   1
 1   0.63   0  1  0   0.7580   0
-1   0.36   1  0  0   0.4450   1
. . .

The fields are sex (male = -1, female = +1), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by $100,000), and political leaning (conservative = 0, moderate = 1, liberal = 2). The goal is to predict political leaning from sex, age, State, and income.
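Encoding a raw record into the six numeric predictors can be sketched like this (the encode() helper and its argument conventions are my assumptions for illustration):

```python
# one-hot encoding for the three States
STATES = {"michigan": [1, 0, 0], "nebraska": [0, 1, 0],
          "oklahoma": [0, 0, 1]}

def encode(sex, age, state, income):
  # sex: "M" -> -1, "F" -> +1; age / 100; income / 100,000
  s = -1.0 if sex == "M" else 1.0
  return [s, age / 100.0] + STATES[state] + [income / 100_000.0]

x = encode("F", 24, "michigan", 29500)
print(x)  # [1.0, 0.24, 1, 0, 0, 0.295]
```

The result matches the first line of the data shown above (minus the political leaning label, which is the value to predict).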

My neural architecture is (6–24)-T-10-3, meaning the 6 input values are expanded to 24 values (a numeric pseudo-embedding with dim = 4), fed to a Transformer encoder, then to a hidden layer of 10 nodes, then to an output layer of 3 nodes (one for each possible political leaning value).

class TransformerNet(T.nn.Module):  # (6--24)-T-10-3
  def __init__(self):
    super(TransformerNet, self).__init__()  # old syntax

    # numeric pseudo-embedding, dim=4
    self.embed = SkipLinear(6, 24)  # 6 inputs, each goes to 4 

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.00)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=10, 
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=2)  # default num_layers is 6

    # People dataset has 6 inputs
    self.fc1 = T.nn.Linear(4*6, 10)  # 10 hidden nodes
    self.fc2 = T.nn.Linear(10, 3)    # 3 classes

  def forward(self, x):
    # x = 6 inputs, fixed length
    z = self.embed(x)  # 6 inputs to 24 pseudo-embed values
    z = z.reshape(-1, 6, 4)  # bat seq embed 
    z = self.pos_enc(z) 
    z = self.trans_enc(z) 
    z = z.reshape(-1, 4*6)  # flatten to shape [bs, 24]
    z = T.tanh(self.fc1(z))
    z = T.log_softmax(self.fc2(z), dim=1)  # NLLLoss()
    return z 

This was a very satisfying Transformer architecture experiment.
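Because forward() ends with log_softmax() (so that NLLLoss() can be used during training), applying exp() to the network output recovers pseudo-probabilities that sum to 1. A quick NumPy check of that identity (made-up class scores):

```python
import numpy as np

def log_softmax(v):
  # numerically stable log-softmax for a 1-D vector
  v = v - v.max()
  return v - np.log(np.exp(v).sum())

logits = np.array([1.5, -0.3, 0.8])  # hypothetical raw class scores
lp = log_softmax(logits)             # what the network emits
probs = np.exp(lp)                   # recover probabilities
print(round(probs.sum(), 6))         # 1.0
```

This is why step 5 of the demo program below applies T.exp() to the network output before printing the prediction.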



In 1950s science fiction movies, radiation is often the cause of unfortunate transformations.

Left: “First Man in Space” (1959) – Test pilot Bill Edwards ignores orders and flies the experimental Y-13 into the ionosphere where he gets exposed to cosmic rays. He turns into a weird encrusted being that craves blood. Doesn’t end well for him.

Center: “The H-Man” (1958) – In this Japanese movie, fallout from a hydrogen bomb test turns some people into vaporous beings. The plot confused me a bit — there are gangsters, nightclub singers, scientists, and police. In the end, all the H-Men are destroyed.

Right: “From Hell It Came” (1957) – On a South Pacific island, a native man named Kimo is framed for a murder and executed. He is buried in a hollow tree trunk. Unfortunately, the island is close to the location of several atomic bomb tests. The tree-Kimo eventually meets his end in quicksand.


Demo code below. Data can be found at https://jamesmccaffreyblog.com/2022/09/01/multi-class-classification-using-pytorch-1-12-1-on-windows-10-11/.

# people_transformer.py
# PyTorch 2.0.0-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11

# naive Transformer architecture for People political leaning

import numpy as np
import torch as T

device = T.device('cpu')
T.set_num_threads(1)

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
  # sex: -1 = male, +1 = female
  # state: michigan, nebraska, oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]     # 1-D

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class SkipLinear(T.nn.Module):

  # -----

  class Core(T.nn.Module):
    def __init__(self, n):
      super().__init__()
      # 1 input node to n output nodes, n >= 2
      self.weights = T.nn.Parameter(T.zeros((n,1),
        dtype=T.float32))
      self.biases = T.nn.Parameter(T.zeros(n,
        dtype=T.float32))
      lim = 0.01
      T.nn.init.uniform_(self.weights, -lim, lim)
      T.nn.init.zeros_(self.biases)

    def forward(self, x):
      wx = T.mm(x, self.weights.t())
      v = T.add(wx, self.biases)
      return v

  # -----

  def __init__(self, n_in, n_out):
    super().__init__()
    self.n_in = n_in; self.n_out = n_out
    if n_out % n_in != 0:
      raise ValueError("n_out must be divisible by n_in")
    n = n_out // n_in  # num nodes per input

    self.lst_modules = \
      T.nn.ModuleList([SkipLinear.Core(n) for \
        i in range(n_in)])

  def forward(self, x):
    # x has shape [bs, n_in]; each input feeds its own Core
    lst_nodes = []
    for i in range(self.n_in):
      xi = x[:,i].reshape(-1,1)
      oupt = self.lst_modules[i](xi)
      lst_nodes.append(oupt)
    result = T.cat(lst_nodes, dim=1)  # [bs, n_out]
    return result

# -----------------------------------------------------------

class TransformerNet(T.nn.Module):  # (6--24)-T-10-3
  def __init__(self):
    super(TransformerNet, self).__init__()  # old syntax

    # numeric pseudo-embedding, dim=4
    self.embed = SkipLinear(6, 24)  # 6 inputs, each goes to 4 

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.00)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=10, 
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=2)  # default num_layers is 6

    # People dataset has 6 inputs
    self.fc1 = T.nn.Linear(4*6, 10)  # 10 hidden nodes
    self.fc2 = T.nn.Linear(10, 3)    # 3 classes

  def forward(self, x):
    # x = 6 inputs, fixed length
    z = self.embed(x)  # 6 inputs to 24 pseudo-embed values
    z = z.reshape(-1, 6, 4)  # bat seq embed 
    z = self.pos_enc(z) 
    z = self.trans_enc(z) 
    z = z.reshape(-1, 4*6)  # flatten to shape [bs, 24]
    z = T.tanh(self.fc1(z))
    z = T.log_softmax(self.fc2(z), dim=1)  # NLLLoss()
    return z 

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):  # documentation code
  def __init__(self, d_model: int, dropout: float=0.1,
   max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
    pe = T.zeros(max_len, d_model)  # like 10x4
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    pe = pe.unsqueeze(0)  # shape (1, max_len, d_model)
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):
    # x has shape (bs, seq_len, d_model) because batch_first=True
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # log-probabilities

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def main():
  # 0. setup
  print("\nBegin Transformer architecture People politics demo ")
  np.random.seed(1)  # 0, 2000, .02 = 93.5 77.5; 
  T.manual_seed(1)   # 1, 2000, .025 = 82.5 77.5

  # 1. create Dataset
  print("\nCreating 200-item train Dataset from text file ")
  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Transformer network ")
  net = TransformerNet().to(device)
  
# -----------------------------------------------------------

  # 3. train model
  max_epochs = 2000 
  ep_log_interval = 400
  lrn_rate = 0.025
  
  loss_func = T.nn.NLLLoss()  # assumes log-softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)
  # optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate)
  
  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("lrn_rate = %0.3f " % lrn_rate)
  print("max_epochs = %3d " % max_epochs)

  print("\nStarting training")
  net.train()  # set mode
  for epoch in range(0, max_epochs):
    ep_loss = 0.0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      (X, y) = batch  # X = predictors, y = target labels
      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, y)  # a tensor
      ep_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute grads
      optimizer.step()     # update weights
    if epoch % ep_log_interval == 0:
      print("epoch = %4d   |  loss = %9.4f" % (epoch, ep_loss))

  print("Done ") 

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on training data = %0.4f" % acc_train)

  net.eval()
  acc_test = accuracy(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

# -----------------------------------------------------------

  # 5. use model
  print("\nPredicting politics for M  30  oklahoma  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]], dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    log_probs = net(X)  # log-probabilities from log_softmax()
  probs = T.exp(log_probs)  # probabilities; sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True)
  print(probs)

# -----------------------------------------------------------

  # 6. save model
  print("\nSaving trained model state")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)  

  print("\nEnd Transformer demo ")

if __name__ == "__main__":
  main()