Anomaly Detection for Tabular Data Using a PyTorch Transformer with Numeric Embedding

I’ve been looking at unsupervised anomaly detection using a PyTorch Transformer module. My first set of experiments used the UCI Digits dataset because the inputs (64 pixels with values between 0 and 16, a scaled-down MNIST) are all integers. That means they map easily to a Transformer, which was designed to accept integer tokens that represent words.

The idea is that a Transformer model was originally designed for input items that are a sequence of words (or tokens), such as “I think therefore I am”. Each word/token is mapped to an integer ID, such as [17, 283, 167, 17, 35]. Then each word/token ID is mapped to a vector called a word embedding, such as 17 = [0.1234, 0.9876, 0.2468, 0.1357]. For UCI Digits data items, each pixel value plays the role of a word/token.
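A minimal sketch of the integer-token idea, using PyTorch's `torch.nn.Embedding`. It assumes 17 possible pixel values (0 through 16) and an embedding dimension of 4. The pixel values are random stand-ins, not actual UCI Digits data, and the dimension used in the Digits experiments isn't stated here.

```python
import torch as T

T.manual_seed(1)
# UCI Digits pixels are integers 0..16, so there are 17 possible "tokens";
# each token ID gets its own learned vector of 4 values (dim is an assumption)
embed = T.nn.Embedding(num_embeddings=17, embedding_dim=4)

pixels = T.randint(low=0, high=17, size=(2, 64))  # 2 fake images, int64 IDs
z = embed(pixels)                                  # one vector per pixel
print(z.shape)                                     # torch.Size([2, 64, 4])

# the same pixel value always maps to the same vector (a table lookup)
same = T.equal(embed(T.tensor([5])), embed(T.tensor([5])))
print(same)                                        # True
```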


My experiment seems to work quite well.

My next set of experiments looked at mixed tabular data, such as an Employee item with sex (Boolean), age (integer), city (categorical), income (real), and job type (categorical). A standard Transformer network starts with an Embedding layer, which accepts only integer token IDs, so it can't take real values such as age and income directly. Therefore I had to simulate embedding using a fully connected layer.
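A sketch of what simulated embedding with a fully connected layer looks like, assuming a `torch.nn.Linear(9, 36)` layer whose output is reshaped to 9 pseudo-tokens of 4 values each (the same layer appears commented out as `fc1` in the demo code). The check at the end illustrates one side effect: because the layer is dense, the pseudo-token for sex also changes when only the age changes.

```python
import torch as T

T.manual_seed(1)
fc = T.nn.Linear(9, 36)           # simulated embedding: 9 values -> 36 values

x = T.tensor([[1.0, 0.24, 1, 0, 0, 0.2950, 0, 0, 1]])  # one Employee item
z = fc(x).reshape(-1, 9, 4)       # 9 pseudo-tokens of 4 values each

# change only the age (feature 1), then compare pseudo-token 0 (sex)
x2 = x.clone()
x2[0, 1] = 0.99
z2 = fc(x2).reshape(-1, 9, 4)
token0_changed = not T.allclose(z[:, 0, :], z2[:, 0, :])
print(z.shape, token0_changed)    # torch.Size([1, 9, 4]) True
```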

Recently, after a lot of thrashing around, I implemented a custom numeric embedding layer that maps real values (rather than integers) to a vector. So I figured I’d try to modify my program that used simulated embedding to use my new numeric embedding code.
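One way to express a numeric embedding, sketched under the assumption that each feature value v is mapped to v * w_i + b_i, where w_i and b_i are learned length-4 vectors that belong to feature i only. `NumericEmbed` is a hypothetical name; it is a vectorized illustration of the per-feature idea behind the `SkipLinear` class in the demo code, not that class itself. Unlike the fully connected version, changing one feature moves only that feature's own token.

```python
import torch as T

class NumericEmbed(T.nn.Module):
    """Hypothetical per-feature numeric embedding (sketch).

    Feature i with value v becomes v * w[i] + b[i], where w[i] and b[i]
    are learned vectors of length dim used by feature i only.
    """
    def __init__(self, n_features, dim):
        super().__init__()
        self.w = T.nn.Parameter(T.empty(n_features, dim).uniform_(-0.01, 0.01))
        self.b = T.nn.Parameter(T.zeros(n_features, dim))

    def forward(self, x):                          # x: [bs, n_features]
        return x.unsqueeze(-1) * self.w + self.b   # [bs, n_features, dim]

T.manual_seed(1)
emb = NumericEmbed(9, 4)
x = T.tensor([[1.0, 0.24, 1, 0, 0, 0.2950, 0, 0, 1]])
z = emb(x)

x2 = x.clone()
x2[0, 1] = 0.99                                    # change only the age
z2 = emb(x2)
print(z.shape)                                     # torch.Size([1, 9, 4])
print(T.equal(z[:, 0, :], z2[:, 0, :]))            # True: sex token unaffected
print(T.equal(z[:, 1, :], z2[:, 1, :]))            # False: age token moves
```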

Bottom line: It seemed to work very well.

The synthetic dataset to analyze for anomalies looks like:

 1   0.24   1 0 0   0.2950   0 0 1
-1   0.39   0 0 1   0.5120   0 1 0
 1   0.63   0 1 0   0.7580   1 0 0
-1   0.36   1 0 0   0.4450   0 1 0
. . .

The columns are sex (male = -1, female = +1), age (divided by 100), city (Anaheim, Boulder, Concord, one-hot encoded), income (divided by 100,000) and job type (Mgmt, Supp, Tech, one-hot encoded).
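A small sketch of the encoding, assuming raw records arrive as sex ('M' or 'F'), age in years, lowercase city name, income in dollars, and lowercase job type. The helper names `encode_employee` and `one_hot` are hypothetical. Applied to a female, age 24, from Anaheim, earning $29,500, in tech, it reproduces the first data row shown above.

```python
# hypothetical helpers that turn one raw Employee record into 9 model inputs
CITIES = ["anaheim", "boulder", "concord"]
JOBS = ["mgmt", "supp", "tech"]

def one_hot(value, categories):
    return [1.0 if value == c else 0.0 for c in categories]

def encode_employee(sex, age, city, income, job):
    """sex is 'M' or 'F' (assumed); age in years; income in dollars."""
    vec = [1.0 if sex == "F" else -1.0]   # male = -1, female = +1
    vec.append(age / 100)                  # age divided by 100
    vec += one_hot(city, CITIES)           # Anaheim, Boulder, Concord
    vec.append(income / 100_000)           # income divided by 100,000
    vec += one_hot(job, JOBS)              # Mgmt, Supp, Tech
    return vec

print(encode_employee("F", 24, "anaheim", 29_500, "tech"))
# [1.0, 0.24, 1.0, 0.0, 0.0, 0.295, 0.0, 0.0, 1.0]
```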

My Transformer-based autoencoder has a (9–36)-T-18-9 architecture. The 9 input values are fed to a numeric embedding layer with dim=4, which produces 9 x 4 = 36 values. Those 36 values, treated as a sequence of 9 tokens with 4 values each and given positional encoding, are fed to a Transformer Encoder. The encoder output is flattened back to 36 values and fed to a fully connected layer with 18 nodes, which is mapped to 9 output nodes. For a particular data item, the 9 output node values should closely match the 9 input values. If there is a big difference (a large reconstruction error), then the data item is likely anomalous.
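The shape flow of the (9–36)-T-18-9 design can be traced with stock PyTorch modules. This is a simplified sketch, not the demo's `Transformer_Net`: it uses random stand-in inputs, writes the per-feature embedding as a broadcast multiply, and omits positional encoding. The last step computes one reconstruction error per item (mean squared difference between input and output), which is the quantity used to flag anomalies.

```python
import torch as T

T.manual_seed(1)
bs, n_feat, dim = 10, 9, 4
x = T.rand(bs, n_feat)                        # stand-in batch of 10 items

# per-feature numeric embedding as a broadcast (sketch, untrained weights)
w = T.empty(n_feat, dim).uniform_(-0.01, 0.01)
b = T.zeros(n_feat, dim)
enc_layer = T.nn.TransformerEncoderLayer(d_model=dim, nhead=2,
    dim_feedforward=10, batch_first=True)
trans_enc = T.nn.TransformerEncoder(enc_layer, num_layers=2)
dec1 = T.nn.Linear(36, 18)
dec2 = T.nn.Linear(18, 9)

z = x.unsqueeze(-1) * w + b                   # [10, 9, 4]   9 -> 36 values
z = trans_enc(z)                              # [10, 9, 4]   T
z = T.tanh(dec1(z.reshape(bs, n_feat * dim))) # [10, 18]
y = dec2(z)                                   # [10, 9]

err = ((x - y) ** 2).mean(dim=1)              # one reconstruction error per item
print(y.shape, err.shape)                     # torch.Size([10, 9]) torch.Size([10])
```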

It’s important to note that Transformer modules were designed for sequential input (like a sentence) where order matters. The experiment in this blog post uses tabular data that’s not inherently sequential. That’s a whole other topic to explore.
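One concrete consequence of the sequence design: without positional encoding, a Transformer encoder has no notion of column order. The sketch below, with dropout disabled so the output is deterministic, shuffles the 9 tokens of a random item and checks that the encoder output is the same set of rows shuffled the same way (permutation equivariance). Adding positional encoding, as the demo does, is what ties each token to a fixed column position.

```python
import torch as T

T.manual_seed(1)
layer = T.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=10,
    dropout=0.0, batch_first=True)
enc = T.nn.TransformerEncoder(layer, num_layers=2)
enc.eval()

z = T.randn(1, 9, 4)        # one item: 9 tokens of 4 values, no positional encoding
perm = T.randperm(9)        # shuffle the order of the 9 "columns"

with T.no_grad():
    out = enc(z)
    out_perm = enc(z[:, perm, :])

# shuffling the input tokens only shuffles the output tokens the same way
print(T.allclose(out_perm, out[:, perm, :], atol=1e-5))  # True
```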

Interesting stuff. It will require a lot of additional work to see if a Transformer-based autoencoder anomaly detection system is better than, worse than, or equivalent to a standard autoencoder system.



There are many analogies between the development of aviation technologies and ML technologies. A successful software system is almost always the result of many partial failures where a valuable lesson was learned. The same is true for aviation technology. Note: For the U.S. military, XF indicates an experimental model and YF indicates a prototype.

Left: The McDonnell XF-88 Voodoo (1948) was not accepted for production, but much of the technology was adapted into the very successful F-101 Voodoo (1954).

Center: The Northrop YF-17 Cobra (1974) lost a competition to the General Dynamics YF-16 Fighting Falcon. However, much of the technology was adapted into the very successful F/A-18 Hornet (1978).

Right: The Northrop-McDonnell YF-23 Black Widow (1990) lost a competition to the Lockheed YF-22 Raptor. But some of the technology was adapted into the successful Lockheed-Northrop F-35 Lightning II (2006).


Demo code. Replace “lt”, “gt”, “lte”, “gte” with the corresponding comparison operators (<, >, <=, >=). The Employee data can be found at:

jamesmccaffrey.wordpress.com/2022/05/17/autoencoder-anomaly-detection-using-pytorch-1-10-on-windows-11/

# emp_trans_num_embed_anom.py
# Transformer based reconstruction error anomaly detection
# uses custom numeric embedding layer instead of FC layer

# PyTorch 2.0.0-CPU Anaconda3-2022.10  Python 3.9.13
# Windows 10/11

import numpy as np
import torch as T

device = T.device('cpu') 
T.set_num_threads(1)

# -----------------------------------------------------------

class EmployeeDataset(T.utils.data.Dataset):
  # sex  age   city     income  job
  # -1   0.27  0  1  0  0.7610  0  0  1
  # +1   0.19  0  0  1  0.6550  0  1  0
  # sex: -1 = male, +1 = female
  # city: anaheim, boulder, concord
  # job: mgmt, supp, tech

  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, usecols=range(0,9),
      delimiter="\t", comments="#", dtype=np.float32)
    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx, :]  # row idx, all cols
    sample = { 'predictors' : preds }  # as Dictionary
    return sample  

# -----------------------------------------------------------

class SkipLinear(T.nn.Module):

  # -----

  class Core(T.nn.Module):
    def __init__(self, n):
      super().__init__()
      # 1 node to n nodes, n >= 2
      self.weights = T.nn.Parameter(T.zeros((n,1),
        dtype=T.float32))
      self.biases = T.nn.Parameter(T.zeros(n,
        dtype=T.float32))  # one bias per node, not a scalar
      lim = 0.01
      T.nn.init.uniform_(self.weights, -lim, lim)
      T.nn.init.zeros_(self.biases)

    def forward(self, x):
      wx= T.mm(x, self.weights.t())
      v = T.add(wx, self.biases)
      return v

  # -----

  def __init__(self, n_in, n_out):
    super().__init__()
    self.n_in = n_in; self.n_out = n_out
    if n_out % n_in != 0:
      raise ValueError("n_out must be divisible by n_in")
    n = n_out // n_in  # num nodes per input

    self.lst_modules = \
      T.nn.ModuleList([SkipLinear.Core(n) for \
        i in range(n_in)])

  def forward(self, x):
    # each input column i is expanded by its own Core module
    lst_nodes = []
    for i in range(self.n_in):
      xi = x[:,i].reshape(-1,1)       # [bs, 1]
      oupt = self.lst_modules[i](xi)  # [bs, n]
      lst_nodes.append(oupt)
    result = T.cat(lst_nodes, dim=1)  # [bs, n_out]
    return result

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):  # documentation code
  def __init__(self, d_model: int, dropout: float=0.1,
   max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
    pe = T.zeros(max_len, d_model)  # like 10x4
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    pe = pe.unsqueeze(0)  # [1, max_len, d_model] for batch_first
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):
    # x is [bs, seq_len, d_model]: add one encoding per position
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

# -----------------------------------------------------------

class Transformer_Net(T.nn.Module):
  def __init__(self):
    # 9 numeric inputs: no exact word embedding equivalent
    # pseudo embed_dim = 4
    # seq_len = 9
    super(Transformer_Net, self).__init__()

    # self.fc1 = T.nn.Linear(9, 9*4)  # pseudo-embedding
    self.embed = SkipLinear(9, 36)  # 9 inputs, each goes to 4

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.00)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=10, 
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=2)

    self.dec1 = T.nn.Linear(36, 18)
    self.dec2 = T.nn.Linear(18, 9)

    # use default weight initialization

  def forward(self, x):
    # x is Size([bs, 9])
    z = self.embed(x)         # [bs,36]
    z = z.reshape(-1, 9, 4)   # [bs, 9, 4] 
    z = self.pos_enc(z)       # [bs, 9, 4]
    z = self.trans_enc(z)     # [bs, 9, 4]

    z = z.reshape(-1, 36)              # [bs, 36]
    z = T.tanh(self.dec1(z))           # [bs, 18]
    z = self.dec2(z)  # no activation  # [bs, 9]
  
    return z

# -----------------------------------------------------------

def analyze_error(model, ds):
  largest_err = 0.0
  worst_x = None
  worst_y = None
  n_features = len(ds[0]['predictors'])

  for i in range(len(ds)):
    X = ds[i]['predictors']
    X = X.reshape(1, n_features)  # make it a batch of one
    with T.no_grad():
      Y = model(X)  # should be close to X
    err = T.sum((X-Y)*(X-Y)).item()  # SSE all features
    err = err / n_features           # mean squared error per feature

    if err > largest_err:
      largest_err = err
      worst_x = X
      worst_y = Y

  np.set_printoptions(formatter={'float': '{: 0.4f}'.format})
  print("Largest reconstruction error: %0.4f" % largest_err)
  print("Worst data item    = ")
  print(worst_x.numpy())
  print("Its reconstruction = " )
  print(worst_y.numpy())

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nEmployee transformer numeric embed anomaly detect ")
  T.manual_seed(1)
  np.random.seed(1)
  
  # 1. create DataLoader objects
  print("\nCreating Employee Dataset ")

  data_file = ".\\Data\\employee_all.txt"
  data_ds = EmployeeDataset(data_file)  # 240 rows

  bat_size = 10
  data_ldr = T.utils.data.DataLoader(data_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Transformer encoder-decoder network ")
  net = Transformer_Net().to(device)

# -----------------------------------------------------------

  # 3. train autoencoder model
  max_epochs = 100
  ep_log_interval = 10
  lrn_rate = 0.005

  loss_func = T.nn.MSELoss()
  optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = Adam")
  print("lrn_rate = %0.3f " % lrn_rate)
  print("max_epochs = %3d " % max_epochs)
  
  print("\nStarting training")
  net.train()
  for epoch in range(0, max_epochs):
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(data_ldr):
      X = batch['predictors']  # [bs,9]
      Y = batch['predictors']  # same

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %4d  |  loss = %0.4f" % \
       (epoch, epoch_loss))
  print("Done ")

# -----------------------------------------------------------

  # 4. find item with largest reconstruction error
  print("\nAnalyzing data for largest reconstruction error \n")
  net.eval()
  analyze_error(net, data_ds)

  print("\nEnd transformer autoencoder anomaly demo ")

if __name__ == "__main__":
  main()
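The `analyze_error()` function above scores items one at a time and reports only the single worst one. A possible variation, sketched here, scores every item in a single forward pass and returns the indices of the k largest errors, using the same error measure (SSE divided by the number of features). The function name `rank_by_reconstruction_error` and its `k` parameter are hypothetical, and the stand-in model simply clips values to [-1, 1] so the expected ranking is easy to verify by hand.

```python
import torch as T

def rank_by_reconstruction_error(model, x, k=3):
    """Score every row of x; return (errors, indices of the k largest errors)."""
    with T.no_grad():
        y = model(x)                      # [n_items, n_features]
    err = ((x - y) ** 2).mean(dim=1)      # SSE / n_features, one value per item
    k = min(k, err.numel())
    top = T.topk(err, k).indices          # largest error first
    return err, top

# stand-in "model" that clips values to [-1, 1]
x = T.zeros(5, 9)
x[3, 2] = 4.0                             # error (4 - 1)^2 / 9 = 1.0
x[1, 0] = 2.0                             # error (2 - 1)^2 / 9 = 0.111
err, top = rank_by_reconstruction_error(lambda t: t.clamp(-1.0, 1.0), x, k=2)
print(top.tolist())                       # [3, 1]
```

In the demo, it could be called after `net.eval()` as `err, top = rank_by_reconstruction_error(net, data_ds.x_data, k=5)`; with only 240 rows, one batch is small enough.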

This entry was posted in PyTorch, Transformers.