Adding an Attention Layer to a PyTorch Neural Network

A few weeks ago, I explored adding a Transformer Encoder component to a PyTorch neural network regression system. A Transformer is a very complex component, but the core internal functionality is an Attention module. I decided to investigate replacing the Transformer Encoder with a custom, from-scratch Attention layer.

It was a very difficult problem and took many hours, but eventually I got a demo up and running. The demo is one of the most complex pieces of software I’ve ever written, from both a conceptual point of view and an engineering/implementation point of view.

The bottom line is that it’s not conclusive whether adding an Attention layer to a PyTorch neural network regression system helps, hurts, or has no significant effect.

I can’t possibly explain all of the details of my demo program, but I’ll try to point out a few of the main ideas.

I used one of my standard examples. The data looks like:

 1, 0.24, 1, 0, 0, 0.2950, 0, 0, 1
-1, 0.39, 0, 0, 1, 0.5120, 0, 1, 0
 1, 0.63, 0, 1, 0, 0.7580, 1, 0, 0
-1, 0.36, 1, 0, 0, 0.4450, 0, 1, 0
 1, 0.27, 0, 1, 0, 0.2860, 0, 0, 1
. . .

The comma-delimited fields are sex (male = -1, female = +1), age (divided by 100), state (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by 100,000), and politics (conservative = 100, moderate = 010, liberal = 001). The goal is to predict income from sex, age, state, and political leaning. The data is synthetic. There are 200 training items and 40 test items.
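As a concrete example of the encoding, here is how a 34-year-old male from Oklahoma who is a political moderate maps to a length-8 input vector (a quick sketch; the variable names are just for illustration):

```python
# encode "M 34 Oklahoma moderate" as a length-8 input vector;
# income is the value to predict so it's not part of the input
sex = [-1]              # male = -1, female = +1
age = [34 / 100.0]      # age divided by 100
state = [0, 0, 1]       # Oklahoma = 001
politics = [0, 1, 0]    # moderate = 010
x = sex + age + state + politics
print(x)  # [-1, 0.34, 0, 0, 1, 0, 1, 0]
```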

For my demo, I created an 8-32-A-(10-10)-1 neural network. There are 8 input nodes, a numeric pseudo-embedding layer that maps each input node to 4 values, a from-scratch Attention layer, two fully connected hidden layers with 10 nodes each and tanh activation, and an output layer with a single node (the predicted income).

Because Attention assigns an importance weighting to each of the predictor variables, it’s necessary to implement a Positional Encoding layer too — at least I think it’s necessary; there are several ideas that I’m not 100% sure of at this point. I used the standard sine-cosine positional encoding algorithm.
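The standard scheme computes, for position pos and embedding index i, sin(pos / 10000^(2i/d)) for even indices and cos(pos / 10000^(2i/d)) for odd indices. A minimal NumPy sketch of just that core computation (not the class used in the demo below):

```python
import numpy as np

def sin_cos_pos_enc(seq_len, d_model):
  # standard Transformer positional encoding:
  # PE[pos, 2i]   = sin(pos / 10000^(2i/d_model))
  # PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
  pe = np.zeros((seq_len, d_model))
  pos = np.arange(seq_len).reshape(-1, 1)
  div = np.exp(np.arange(0, d_model, 2) * (-np.log(10_000.0) / d_model))
  pe[:, 0::2] = np.sin(pos * div)
  pe[:, 1::2] = np.cos(pos * div)
  return pe

pe = sin_cos_pos_enc(8, 4)  # 8 positions, embedding dim 4
print(pe.shape)             # (8, 4)
print(pe[0])                # position 0: [0. 1. 0. 1.]
```

Each position gets a distinct, fixed pattern of values, which is what lets the attention mechanism distinguish the first predictor variable from the second, and so on.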

To implement the from-scratch attention layer, I relied on dozens of Internet resources. The self-attention (I use the terms attention and self-attention interchangeably — that’s another topic) layer definition I finally came up with is:

class SelfAttention(T.nn.Module):
  def __init__(self, embed_dim):
    super(SelfAttention, self).__init__()
    self.dim = embed_dim
    self.Q = T.nn.Linear(embed_dim, embed_dim)
    self.K = T.nn.Linear(embed_dim, embed_dim)
    self.V = T.nn.Linear(embed_dim, embed_dim)
    self.O = T.nn.Linear(embed_dim, embed_dim)
    self.soft = T.nn.Softmax(dim=2)
   
  def forward(self, x): # x is (batch, seq, embed)
    q = self.Q(x)
    k = self.K(x)
    v = self.V(x)

    scores = T.matmul(q, k.transpose(1,2))
    scores = scores / np.sqrt(self.dim)
    attn = self.soft(scores)
    z = T.matmul(attn, v)
    return self.O(z)  # final layer makes big diff
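The core computation inside the layer — scaled dot-product attention — can be sanity-checked in isolation. This sketch applies the scores / softmax / weighted-sum steps to random tensors (stand-ins for Q(x), K(x), V(x)) and verifies the shapes:

```python
import torch as T

T.manual_seed(0)
batch, seq, embed = 10, 8, 4    # matches the demo's shapes
q = T.randn(batch, seq, embed)  # stand-in for self.Q(x)
k = T.randn(batch, seq, embed)  # stand-in for self.K(x)
v = T.randn(batch, seq, embed)  # stand-in for self.V(x)

scores = T.matmul(q, k.transpose(1, 2)) / (embed ** 0.5)  # (10, 8, 8)
attn = T.softmax(scores, dim=2)  # each row of weights sums to 1
z = T.matmul(attn, v)            # (10, 8, 4): same shape as v

print(scores.shape)  # torch.Size([10, 8, 8])
print(z.shape)       # torch.Size([10, 8, 4])
```

The output shape matches the input shape, which is why the layer can be dropped into the middle of a network without disturbing the layers before and after it.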

There are many possibilities for an attention layer algorithm. I’m not satisfied I have the best design and I’ll have to investigate a lot more.

It’s important to note that my approach to applying attention to a regression problem uses the same pattern as that needed for a natural language problem. In NLP, each word token maps to a vector of embedding values. This results in an input shape of [batch_size, seq_length, embedding_dim]. For my regression demo, I use pseudo-embedding via the custom SkipLinear layer (a terrible name but I couldn’t come up with anything better when I designed it). For regression, there’s no technical need to map each numeric input (like age) to an embedded vector of multiple values. This is a very tricky topic and one that I haven’t explored in depth — it will require a significant effort when I have time.
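For shape purposes, the net effect of the pseudo-embedding is that each of the 8 scalar inputs gets its own private weight vector mapping it to 4 values, with no cross-connections between inputs. A rough NumPy sketch of just that shape transformation (the weights here are random stand-ins, not the trained SkipLinear parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((10, 8)).astype(np.float32)  # batch of 10, 8 numeric inputs

# one independent 4-value weight vector and bias per input --
# input i connects only to its own 4 embedding values
W = (rng.random((8, 4)).astype(np.float32) - 0.5) * 0.02
b = np.zeros((8, 4), dtype=np.float32)

z = x[:, :, None] * W[None, :, :] + b[None, :, :]
print(z.shape)  # (10, 8, 4) = [batch_size, seq_length, embedding_dim]
```

The result has the [batch_size, seq_length, embedding_dim] shape that the attention layer expects, with the 8 predictor variables playing the role of a sequence of 8 tokens.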

The Attention layer is complex, but far more difficult was figuring out how to integrate it into a standard PyTorch neural network regression system. Anyway, the output of the demo program is:

Begin regression with Attention demo

Creating People Dataset objects

Creating 8-32-A-(10-10)-1 regression neural network

bat_size = 10
loss = MSELoss()
optimizer = Adam
lrn_rate = 0.01

Starting training
epoch =    0  |  loss =    1.3310
epoch =   20  |  loss =    0.4219
epoch =   40  |  loss =    0.0993
epoch =   60  |  loss =    0.0649
epoch =   80  |  loss =    0.0823
epoch =  100  |  loss =    0.0550
epoch =  120  |  loss =    0.0420
epoch =  140  |  loss =    0.0447
epoch =  160  |  loss =    0.0312
epoch =  180  |  loss =    0.0388
Done

Computing model accuracy (within 0.10 of true)
Accuracy on train data = 0.8100
Accuracy on test data = 0.7500

Predicting income for M 34 Oklahoma moderate:
$45906.71

End People income regression demo

I implemented a program-defined accuracy() function where a correct income prediction is one that’s within a specified percentage of the true income. After training, using a 10% closeness percentage, my model scored 81.00% accuracy on the training data (162 of 200 correct), and 75.00% accuracy on the test data (30 of 40 correct). These results are comparable to a regression system without Attention, and also with a regression system with a Transformer Encoder layer.
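The closeness check is simple arithmetic. For example, with a normalized true income of 0.5000 ($50,000), the 10% threshold is 0.05, so a prediction of 0.46 counts as correct but 0.44 does not (a minimal standalone sketch, not the demo's actual accuracy() function):

```python
def is_correct(pred, actual, pct_close=0.10):
  # a prediction counts as correct if it's within
  # pct_close (here 10%) of the true normalized income
  return abs(pred - actual) < abs(pct_close * actual)

print(is_correct(0.46, 0.50))  # True:  off by 0.04, threshold is 0.05
print(is_correct(0.44, 0.50))  # False: off by 0.06, threshold is 0.05
```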

It’s not possible to draw any strong conclusions because the dataset is so small and the neural architecture has many hyperparameters that I didn’t experiment with. But it was a fascinating exploration and I learned a ton about the attention mechanism.



I grew up in the era before the Internet, when music was dominated by vinyl albums. It was critically important for an album cover to attract a buyer’s attention. Here are three old faith-related album covers that might have benefited from a bit more reflection on the text used.


Demo program. Very long and extremely complex. The training and test data are at: https://jamesmccaffreyblog.com/2023/12/01/regression-using-a-pytorch-neural-network-with-a-transformer-component/.

# people_income_attention.py
# predict income from sex, age, city, politics
# from-scratch attention layer
# PyTorch 2.1.2-CPU Anaconda3-2023.09-1  Python 3.11.5
# Windows 10/11 

import numpy as np
import torch as T

device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  def __init__(self, src_file):
    # sex age   state   income   politics
    # -1  0.27  0 1 0   0.7610   0 0 1
    # +1  0.19  0 0 1   0.6550   1 0 0

    # two-pass technique
    tmp_x = np.loadtxt(src_file, usecols=[0,1,2,3,4,6,7,8],
      delimiter=",", comments="#", dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=5, delimiter=",",
      comments="#", dtype=np.float32)
    tmp_y = tmp_y.reshape(-1,1)  # 2D required

    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y, dtype=T.float32).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    incom = self.y_data[idx] 
    return (preds, incom)  # as a tuple

# -----------------------------------------------------------

class SkipLinear(T.nn.Module):

  # -----

  class Core(T.nn.Module):
    def __init__(self, n):
      super().__init__()
      # 1 node to n nodes, n >= 2
      self.weights = T.nn.Parameter(T.zeros((n,1),
        dtype=T.float32))
      self.biases = T.nn.Parameter(T.zeros(n,
        dtype=T.float32))
      lim = 0.01
      T.nn.init.uniform_(self.weights, -lim, lim)
      T.nn.init.zeros_(self.biases)

    def forward(self, x):
      wx = T.mm(x, self.weights.t())
      v = T.add(wx, self.biases)
      return v

  # -----

  def __init__(self, n_in, n_out):
    super().__init__()
    self.n_in = n_in; self.n_out = n_out
    if n_out % n_in != 0:
      raise ValueError("n_out must be divisible by n_in")
    n = n_out // n_in  # num nodes per input

    self.lst_modules = \
      T.nn.ModuleList([SkipLinear.Core(n) for \
        i in range(n_in)])

  def forward(self, x):
    lst_nodes = []
    for i in range(self.n_in):
      xi = x[:,i].reshape(-1,1)
      oupt = self.lst_modules[i](xi)
      lst_nodes.append(oupt)
    result = T.cat(lst_nodes, dim=1)  # (batch, n_out)
    return result

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):  # documentation code
  def __init__(self, d_model: int, dropout: float=0.1,
   max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
    pe = T.zeros(max_len, d_model)  # like 10x4
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    pe = pe.unsqueeze(0)  # (1, max_len, d_model) for batch-first input
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):  # x is (batch, seq, embed)
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

# -----------------------------------------------------------

class SelfAttention(T.nn.Module):
  def __init__(self, embed_dim):
    super(SelfAttention, self).__init__()
    self.dim = embed_dim
    self.Q = T.nn.Linear(embed_dim, embed_dim)
    self.K = T.nn.Linear(embed_dim, embed_dim)
    self.V = T.nn.Linear(embed_dim, embed_dim)
    self.O = T.nn.Linear(embed_dim, embed_dim)
    self.soft = T.nn.Softmax(dim=2)
   
  def forward(self, x): # x is (batch, seq, embed)
    q = self.Q(x)
    k = self.K(x)
    v = self.V(x)

    scores = T.matmul(q, k.transpose(1,2))
    scores = scores / np.sqrt(self.dim)
    attn = self.soft(scores)
    z = T.matmul(attn, v)
    return self.O(z)

# -----------------------------------------------------------

class AttentionNet(T.nn.Module):
  def __init__(self):
    super(AttentionNet, self).__init__()
    self.embed = SkipLinear(8, 32)  # 8 inputs, each to 4
    self.pos_enc = \
      PositionalEncoding(4, dropout=0.10)  # positional
    self.att = SelfAttention(4)  # embed dim
    self.hid1 = T.nn.Linear(32, 10)  # 32-(10-10)-1
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 1)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = self.embed(x)  # 8 inputs to 32 embed
    z = z.reshape(-1, 8, 4)  # to 3D
    z = self.pos_enc(z)
    z = self.att(z)  # (batch, 8, 4)
    z = z.reshape(-1,32)
    z = T.tanh(self.hid1(z))
    z = T.tanh(self.hid2(z))
    z = self.oupt(z)  # regression: no activation
    return z

# -----------------------------------------------------------

def accuracy(model, ds, pct_close):
  # assumes model.eval()
  # correct within pct of true income
  n_correct = 0; n_wrong = 0

  for i in range(len(ds)):
    X = ds[i][0].reshape(1,8)  # 2D
    Y = ds[i][1]   # target income, shape (1,)
    with T.no_grad():
      oupt = model(X)         # computed income

    # print("predicted = "); print(oupt)
    # print("actual = "); print(Y); input()

    if T.abs(oupt - Y) < T.abs(pct_close * Y):
      n_correct += 1
    else:
      n_wrong += 1
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def train(model, ds, bs, lr, me, le):
  # dataset, bat_size, lrn_rate, max_epochs, log interval
  train_ldr = T.utils.data.DataLoader(ds, batch_size=bs,
    shuffle=True)
  loss_func = T.nn.MSELoss()
  optimizer = T.optim.Adam(model.parameters(), lr=lr)

  for epoch in range(0, me):
    epoch_loss = 0.0  # for one full epoch
    for (b_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # predictors [10,8]
      y = batch[1]  # target income
      optimizer.zero_grad()
      oupt = model(X)
      loss_val = loss_func(oupt, y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute gradients
      optimizer.step()     # update weights

    if epoch % le == 0:
      print("epoch = %4d  |  loss = %9.4f" % \
        (epoch, epoch_loss)) 

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin regression with Attention demo ")
  T.manual_seed(0)
  np.random.seed(0)
  
  # 1. create Dataset objects
  print("\nCreating People Dataset objects ")
  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)  # 40 rows

  # 2. create network
  print("\nCreating 8-32-A-(10-10)-1 regression neural network ")
  net = AttentionNet().to(device)

# -----------------------------------------------------------

  # 3. train model
  print("\nbat_size = 10 ")
  print("loss = MSELoss() ")
  print("optimizer = Adam ")
  print("lrn_rate = 0.01 ")

  print("\nStarting training")
  net.train()  # set mode
  train(net, train_ds, bs=10, lr=0.01, me=200, le=20)
  print("Done ")

# -----------------------------------------------------------

  # 4. evaluate model accuracy
  print("\nComputing model accuracy (within 0.10 of true) ")
  net.eval()
  acc_train = accuracy(net, train_ds, 0.10)  # item-by-item
  print("Accuracy on train data = %0.4f" % acc_train)

  acc_test = accuracy(net, test_ds, 0.10)
  print("Accuracy on test data = %0.4f" % acc_test)

# -----------------------------------------------------------

  # 5. make a prediction
  print("\nPredicting income for M 34 Oklahoma moderate: ")
  x = np.array([[-1, 0.34, 0,0,1,  0,1,0]],
    dtype=np.float32)  # 2D input
  x = T.tensor(x, dtype=T.float32).to(device) 

  with T.no_grad():
    pred_inc = net(x)
  pred_inc = pred_inc.item()  # scalar
  print("$%0.2f" % (pred_inc * 100_000))  # un-normalized

# -----------------------------------------------------------

  # 6. save model (state_dict approach)
  # print("\nSaving trained model state")
  # fn = ".\\Models\\people_income_model.pt"
  # T.save(net.state_dict(), fn)

  # model = Net()
  # model.load_state_dict(T.load(fn))
  # use model to make prediction(s)

  print("\nEnd People income regression demo ")

if __name__ == "__main__":
  main()