Adding a Transformer Module to a PyTorch Regression Network – Linear Layer Pseudo-Embedding

I’ve been looking at adding a Transformer module to a PyTorch regression network. Because the key functionality of a Transformer is the attention mechanism, I’ve also been looking at adding a custom Attention module instead of a Transformer.

There are dozens of design alternatives, and many architecture and training hyperparameters. For the baseline architecture against which I’m comparing, I use a standard PyTorch TransformerEncoder layer, a numeric pseudo-embedding layer, and a simplified positional encoding layer.

All of my ideas are based on natural language processing systems. In NLP, each word/token is mapped to an integer, such as “date” = 3572. Then the integer is mapped to a vector of real values, typically about 100 to 1,000 of them, called the embedding. The idea is to capture different dimensions of the meaning of the source word/token, such as “date” being a calendar day, or a social engagement, or a kind of fruit. After embedding, an NLP system needs positional encoding because the order of the words in a sentence is critical. I intend to examine the architecture options as best I can.
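
For readers who haven't seen NLP embeddings, here is a minimal sketch using PyTorch's nn.Embedding (the vocabulary size, embedding dimension, and token ID are made up for illustration):

import torch as T

embed = T.nn.Embedding(num_embeddings=10000, embedding_dim=4)  # hypothetical vocab of 10,000 tokens
token_id = T.tensor([3572])  # e.g., "date" = 3572 (made-up ID)
vec = embed(token_id)        # shape [1, 4] -- the learned embedding vector
print(vec)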

For this investigation, I’m replacing the numeric pseudo-embedding with a fully connected linear layer. My working hypothesis going in was that this architecture would produce a model similar to one that uses numeric pseudo-embedding, and that turned out to be correct: the linear layer pseudo-embedding architecture resulted in a good model. Implication: when using a Transformer, you need some form of embedding even when the inputs are already numeric.
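
To make that concrete, here is a minimal sketch contrasting the two pseudo-embedding approaches. The baseline's numeric pseudo-embedding isn't shown in this post, so the shared Linear(1, 4) form below is only my assumption of one plausible implementation; the Linear(6, 24) version is the one used in the demo:

import torch as T

x = T.randn(8, 6)  # a batch of 8 items, 6 numeric predictors each

# assumed form of a numeric pseudo-embedding: each scalar predictor
# is pushed through a shared Linear(1, 4)
pseudo_embed = T.nn.Linear(1, 4)
z1 = pseudo_embed(x.reshape(8, 6, 1))  # shape [8, 6, 4]

# this post's approach: one fully connected layer produces all 24
# values at once, then the result is viewed as 6 tokens of 4 values
linear_embed = T.nn.Linear(6, 24)
z2 = linear_embed(x).reshape(8, 6, 4)  # shape [8, 6, 4]

print(z1.shape, z2.shape)  # both torch.Size([8, 6, 4])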



A diagram of a 6-L12-PE-T-(8-8)-1 system, not the 6-L24-PE-T-(10-10)-1 system of the demo.


The demo uses one of my standard synthetic datasets. The data looks like:

-0.1660,  0.4406, -0.9998, -0.3953, -0.7065, -0.8153,  0.7022
-0.2065,  0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.5666
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192, -0.6038,  0.8186
 0.7528,  0.7892, -0.8299, -0.9219, -0.6603,  0.7563,  0.3687
. . .

The first six values on each line are predictors. The last value on each line is the target to predict. The data was generated by a 6-10-1 neural network with random weights and biases. There are 200 training items and 40 test items.
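
The data-generation program isn't shown here, but a minimal sketch of how a 6-10-1 network with random weights and biases might produce such items looks like the code below (the tanh and sigmoid activations, the weight range, and the seed are my assumptions, not the actual generator):

import numpy as np

rnd = np.random.RandomState(1)  # arbitrary seed
W1 = rnd.uniform(-1, 1, (6, 10)); b1 = rnd.uniform(-1, 1, 10)
W2 = rnd.uniform(-1, 1, (10, 1)); b2 = rnd.uniform(-1, 1, 1)

def make_item():
  x = rnd.uniform(-1, 1, 6)                 # six predictors in [-1, 1]
  h = np.tanh(x @ W1 + b1)                  # hidden layer, tanh activation
  y = 1.0 / (1.0 + np.exp(-(h @ W2 + b2)))  # sigmoid output in (0, 1)
  return x, y[0]

x, y = make_item()
print(np.round(x, 4), round(float(y), 4))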

The baseline architecture is 6-24-PE-T-(10-10)-1, meaning 6 inputs, each mapped to an embedding of 4 values (24 values total), plus positional encoding, fed to a Transformer, sent to two hidden layers of 10 nodes each, sent to a single output prediction node. The demo architecture explored in this blog post is 6-L24-PE-T-(10-10)-1, meaning 6 inputs, fed to a fully connected Linear(6, 24) layer in place of an embedding layer, reshaped to 6 tokens of 4 values each, positional encoding added, sent to a Transformer (with embedding dimension d_model = 4), sent to two fully connected layers with 10 nodes each, sent to a single output node.

Here’s the output of one demo run:

Begin Transformer regression on synthetic data
Linear layer pseudo-embedding

Loading train (200) and test (40) data to memory
Done

First three rows of training predictors:
tensor([-0.1660,  0.4406, -0.9998, -0.3953, -0.7065, -0.8153])
tensor([-0.2065,  0.0776, -0.1616,  0.3704, -0.5911,  0.7562])
tensor([-0.9452,  0.3409, -0.1654,  0.1174, -0.7192, -0.6038])

First three target y values:
0.7022
0.5666
0.8186

Creating 6-L24-PE-T-(10-10)-1 regression model

bat_size = 10
loss = MSELoss()
optimizer = Adam
lrn_rate = 0.001

Starting training
epoch =    0  |  loss = 10.0987
epoch =   20  |  loss = 0.0568
epoch =   40  |  loss = 0.0313
epoch =   60  |  loss = 0.0245
epoch =   80  |  loss = 0.0175
Done

Computing model accuracy (within 0.10 of true)
Accuracy on train data = 0.9750
Accuracy on test data = 0.8500

MSE on train data = 0.0004
MSE on test data = 0.0010

Predicting target y for train[0]:
Predicted y = 0.7185

End demo

The model accuracy (97.5% = 195 out of 200 correct on the training data) and mean squared error (0.0004 on training data) were even better than those of the baseline architecture (84% accuracy and MSE = 0.0012).

A very interesting experiment.



The classic story of transformation is “(Strange Case of) Dr. Jekyll and Mr. Hyde” (1886) by Robert Louis Stevenson. The story and variations have appeared in dozens of film and TV adaptations. Here are three not-so-serious adaptations that I like.

Left: “Abbott and Costello Meet Dr. Jekyll and Mr. Hyde” (1953) is one of my favorites. Slim and Tubby (Abbott and Costello respectively) are two inept American policemen who go to London to learn from Scotland Yard. Actor Boris Karloff plays Jekyll/Hyde, and his serum accidentally transforms Tubby too.

Center: “Nowhere to Hyde” (season 2, episode 1, Sept. 1970) has the Scooby-Doo Where Are You team investigating jewel thefts, apparently by the ghost of Mr. Hyde. In the end the culprit is a descendant of Dr. Jekyll who needs money to fund his experiments. I always enjoy Scooby-Doo cartoons.

Right: “Spooks” (1953) is a short subject (~25 minutes) featuring the Three Stooges, Larry, Moe, and Shemp. The trio are private investigators who manage to rescue kidnapped Mary Bopper from a not-nice Dr. Jekyll who intends to use her in his experiments. One of just two Stooges movies filmed in 3D.


Demo code.

# synthetic_transformer_linear_embed.py

# regression with Transformer and linear layer embedding,
# simple positional encoding on a synthetic dataset 
# PyTorch 2.3.1-CPU  Anaconda3-2023.09  Python 3.11.5
# Windows 10/11 

import numpy as np
import torch as T  # non-standard alias

device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class SynthDataset(T.utils.data.Dataset):
  # six predictors followed by target
  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, delimiter=",",
      usecols=[0,1,2,3,4,5], dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=6, delimiter=",",
      dtype=np.float32)
    tmp_y = tmp_y.reshape(-1,1)  # 2D required

    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y, dtype=T.float32).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return (preds, trgts)  # as a tuple

# -----------------------------------------------------------

class PositionEncode(T.nn.Module):
  def __init__(self, n_features):
    super(PositionEncode, self).__init__()  # old syntax
    self.nf = n_features
    self.pe = T.zeros(n_features, dtype=T.float32)
    for i in range(n_features):
      self.pe[i] = i * (0.01 / n_features)  # no sin, cos

  def forward(self, x):
    # x has shape (bat, seq, embed); pe[j] is a scalar added to every
    # embed value at sequence position j, so only the first seq_len
    # entries of pe are actually used
    for i in range(len(x)):
      for j in range(len(x[0])):
        x[i][j] += self.pe[j]
    return x

# -----------------------------------------------------------

class TransformerNet(T.nn.Module):
  def __init__(self):

    super(TransformerNet, self).__init__()
    self.linear_embed = T.nn.Linear(6, 24)  # 6 inputs to 24 values (4 per input)
    self.pos_enc = PositionEncode(24)  # simple positional encoding
    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=10, 
      batch_first=True)  # d_model divisible by nhead
    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=2)  # nn.Transformer default is 6 layers
    
    self.fc1 = T.nn.Linear(24, 10)  # 6-L24-PE-T-(10-10)-1
    self.fc2 = T.nn.Linear(10, 10)
    self.fc3 = T.nn.Linear(10, 1)

    # default weight and bias initialization

  def forward(self, x):
    z = self.linear_embed(x)  # 6 inputs to 24 pseudo-embed values
    z = z.reshape(-1, 6, 4)  # (bat, seq, embed)
    z = self.pos_enc(z) 
    z = self.trans_enc(z) 
    z = z.reshape(-1, 24)  # torch.Size([bs, 24])
    z = T.tanh(self.fc1(z))
    z = T.tanh(self.fc2(z))
    z = self.fc3(z)  # regression: no activation
    return z

# -----------------------------------------------------------

def accuracy(model, ds, pct_close):
  # assumes model.eval()
  # correct within pct of true target value
  n_correct = 0; n_wrong = 0

  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)
    with T.no_grad():
      oupt = model(X)    # predicted target

    if T.abs(oupt - Y) < T.abs(pct_close * Y):
      n_correct += 1
    else:
      n_wrong += 1
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def mean_sq_err(model, ds):
  # assumes model.eval()
  n = len(ds)
  sum = 0.0

  for i in range(n):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    y = ds[i][1].reshape(1)
    with T.no_grad():
      oupt = model(X)         # predicted target
    diff = oupt.item() - y.item()
    sum += diff * diff 
  mse = sum / n
  return mse

# -----------------------------------------------------------

def train(model, ds, bs, lr, me, le):
  # dataset, bat_size, lrn_rate, max_epochs, log interval
  train_ldr = T.utils.data.DataLoader(ds, batch_size=bs,
    shuffle=True)
  loss_func = T.nn.MSELoss()
  optimizer = T.optim.Adam(model.parameters(), lr=lr)
  # optimizer = T.optim.SGD(model.parameters(), lr=lr)

  for epoch in range(0, me):
    epoch_loss = 0.0  # for one full epoch
    for (b_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # predictors
      y = batch[1]  # target
      optimizer.zero_grad()
      oupt = model(X)
      loss_val = loss_func(oupt, y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute gradients
      optimizer.step()     # update weights

    if epoch % le == 0:
      print("epoch = %4d  |  loss = %0.4f" % \
        (epoch, epoch_loss)) 

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Transformer regression on synthetic data ")
  print("Linear layer pseudo-embedding ")
  np.random.seed(0)
  T.manual_seed(0) 

  # 1. load data
  print("\nLoading train (200) and test (40) data to memory ")
  train_file = ".\\Data\\synthetic_train.txt"
  train_ds = SynthDataset(train_file)  # 200 rows
  test_file = ".\\Data\\synthetic_test.txt"
  test_ds = SynthDataset(test_file)    # 40 rows
  print("Done ")

  print("\nFirst three rows of training predictors: ")
  for i in range(3):
    print(train_ds[i][0])
  print("\nFirst three target y values: ")
  for i in range(3):
    print("%0.4f " % train_ds[i][1])

  # 2. create regression model
  print("\nCreating 6--L24-PE-T-(10-10)-1 regression model ")
  net = TransformerNet().to(device)

  # 3. train model
  print("\nbat_size = 10 ")
  print("loss = MSELoss() ")
  print("optimizer = Adam ")
  print("lrn_rate = 0.001 ")

  print("\nStarting training")
  net.train()
  train(net, train_ds, bs=10, lr=0.001, me=100, le=20)
  print("Done ")

  # 4. evaluate model
  net.eval()
  print("\nComputing model accuracy (within 0.10 of true) ")
  acc_train = accuracy(net, train_ds, 0.10)  # item-by-item
  print("Accuracy on train data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds, 0.10) 
  print("Accuracy on test data = %0.4f" % acc_test)

  mse_train = mean_sq_err(net, train_ds)
  print("\nMSE on train data = %0.4f" % mse_train)
  mse_test = mean_sq_err(net, test_ds)
  print("MSE on test data = %0.4f" % mse_test)

  # 5. make a prediction
  print("\nPredicting target y for train[0]: ")
  x = train_ds[0][0].reshape(1,-1)  # item - predictors

  with T.no_grad():
    y = net(x)
  pred_raw = y.item()  # scalar
  print("Predicted y = %0.4f" % pred_raw)  

  # 6. TODO: save model (state_dict approach)
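  # a minimal sketch of the state_dict approach (file path is hypothetical):
  # T.save(net.state_dict(), ".\\Models\\transformer_linear_embed.pt")
  # to reload later: net.load_state_dict(T.load(".\\Models\\transformer_linear_embed.pt"))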

  print("\nEnd demo ")

# -----------------------------------------------------------

if __name__=="__main__":
  main()

Training data:

# synthetic_train.txt
#
-0.1660,  0.4406, -0.9998, -0.3953, -0.7065, -0.8153,  0.7022
-0.2065,  0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.5666
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192, -0.6038,  0.8186
 0.7528,  0.7892, -0.8299, -0.9219, -0.6603,  0.7563,  0.3687
-0.8033, -0.1578,  0.9158,  0.0663,  0.3838, -0.3690,  0.7535
-0.9634,  0.5003,  0.9777,  0.4963, -0.4391,  0.5786,  0.7076
-0.7935, -0.1042,  0.8172, -0.4128, -0.4244, -0.7399,  0.8454
-0.0169, -0.8933,  0.1482, -0.7065,  0.1786,  0.3995,  0.7302
-0.7953, -0.1719,  0.3888, -0.1716, -0.9001,  0.0718,  0.8692
 0.8892,  0.1731,  0.8068, -0.7251, -0.7214,  0.6148,  0.4740
-0.2046, -0.6693,  0.8550, -0.3045,  0.5016,  0.4520,  0.6714
 0.5019, -0.3022, -0.4601,  0.7918, -0.1438,  0.9297,  0.4331
 0.3269,  0.2434, -0.7705,  0.8990, -0.1002,  0.1568,  0.3716
 0.8068,  0.1474, -0.9943,  0.2343, -0.3467,  0.0541,  0.3829
 0.7719, -0.2855,  0.8171,  0.2467, -0.9684,  0.8589,  0.4700
 0.8652,  0.3936, -0.8680,  0.5109,  0.5078,  0.8460,  0.2648
 0.4230, -0.7515, -0.9602, -0.9476, -0.9434, -0.5076,  0.8059
 0.1056,  0.6841, -0.7517, -0.4416,  0.1715,  0.9392,  0.3512
 0.1221, -0.9627,  0.6013, -0.5341,  0.6142, -0.2243,  0.6840
 0.1125, -0.7271, -0.8802, -0.7573, -0.9109, -0.7850,  0.8640
-0.5486,  0.4260,  0.1194, -0.9749, -0.8561,  0.9346,  0.6109
-0.4953,  0.4877, -0.6091,  0.1627,  0.9400,  0.6937,  0.3382
-0.5203, -0.0125,  0.2399,  0.6580, -0.6864, -0.9628,  0.7400
 0.2127,  0.1377, -0.3653,  0.9772,  0.1595, -0.2397,  0.4081
 0.1019,  0.4907,  0.3385, -0.4702, -0.8673, -0.2598,  0.6582
 0.5055, -0.8669, -0.4794,  0.6095, -0.6131,  0.2789,  0.6644
 0.0493,  0.8496, -0.4734, -0.8681,  0.4701,  0.5444,  0.3214
 0.9004,  0.1133,  0.8312,  0.2831, -0.2200, -0.0280,  0.3149
 0.2086,  0.0991,  0.8524,  0.8375, -0.2102,  0.9265,  0.3619
-0.7298,  0.0113, -0.9570,  0.8959,  0.6542, -0.9700,  0.6451
-0.6476, -0.3359, -0.7380,  0.6190, -0.3105,  0.8802,  0.6606
 0.6895,  0.8108, -0.0802,  0.0927,  0.5972, -0.4286,  0.2427
-0.0195,  0.1982, -0.9689,  0.1870, -0.1326,  0.6147,  0.4773
 0.1557, -0.6320,  0.5759,  0.2241, -0.8922, -0.1596,  0.7581
 0.3581,  0.8372, -0.9992,  0.9535, -0.2468,  0.9476,  0.2962
 0.1494,  0.2562, -0.4288,  0.1737,  0.5000,  0.7166,  0.3513
 0.5102,  0.3961,  0.7290, -0.3546,  0.3416, -0.0983,  0.3153
-0.1970, -0.3652,  0.2438, -0.1395,  0.9476,  0.3556,  0.4719
-0.6029, -0.1466, -0.3133,  0.5953,  0.7600,  0.8077,  0.3875
-0.4953,  0.7098,  0.0554,  0.6043,  0.1450,  0.4663,  0.4739
 0.0380,  0.5418,  0.1377, -0.0686, -0.3146, -0.8636,  0.6048
 0.9656, -0.6368,  0.6237,  0.7499,  0.3768,  0.1390,  0.3705
-0.6781, -0.0662, -0.3097, -0.5499,  0.1850, -0.3755,  0.7668
-0.6141, -0.0008,  0.4572, -0.5836, -0.5039,  0.7033,  0.7301
-0.1683,  0.2334, -0.5327, -0.7961,  0.0317, -0.0457,  0.5777
 0.0880,  0.3083, -0.7109,  0.5031, -0.5559,  0.0387,  0.5118
 0.5706, -0.9553, -0.3513,  0.7458,  0.6894,  0.0769,  0.4329
-0.8025,  0.3026,  0.4070,  0.2205,  0.5992, -0.9309,  0.7098
 0.5405,  0.4635, -0.4806, -0.4859,  0.2646, -0.3094,  0.3566
 0.5655,  0.9809, -0.3995, -0.7140,  0.8026,  0.0831,  0.2551
 0.9495,  0.2732,  0.9878,  0.0921,  0.0529, -0.7291,  0.3074
-0.6792,  0.4913, -0.9392, -0.2669,  0.7247,  0.3854,  0.4362
 0.3819, -0.6227, -0.1162,  0.1632,  0.9795, -0.5922,  0.4435
 0.5003, -0.0860, -0.8861,  0.0170, -0.5761,  0.5972,  0.5136
-0.4053, -0.9448,  0.1869,  0.6877, -0.2380,  0.4997,  0.7859
 0.9189,  0.6079, -0.9354,  0.4188, -0.0700,  0.8951,  0.2696
-0.5571, -0.4659, -0.8371, -0.1428, -0.7820,  0.2676,  0.8566
 0.5324, -0.3151,  0.6917, -0.1425,  0.6480,  0.2530,  0.4252
-0.7132, -0.8432, -0.9633, -0.8666, -0.0828, -0.7733,  0.9217
-0.0952, -0.0998, -0.0439, -0.0520,  0.6063, -0.1952,  0.5140
 0.8094, -0.9259,  0.5477, -0.7487,  0.2370, -0.9793,  0.5562
 0.9024,  0.8108,  0.5919,  0.8305, -0.7089, -0.6845,  0.2993
-0.6247,  0.2450,  0.8116,  0.9799,  0.4222,  0.4636,  0.4619
-0.5003, -0.6531, -0.7611,  0.6252, -0.7064, -0.4714,  0.8452
 0.6382, -0.3788,  0.9648, -0.4667,  0.0673, -0.3711,  0.5070
-0.1328,  0.0246,  0.8778, -0.9381,  0.4338,  0.7820,  0.5680
-0.9454,  0.0441, -0.3480,  0.7190,  0.1170,  0.3805,  0.6562
-0.4198, -0.9813,  0.1535, -0.3771,  0.0345,  0.8328,  0.7707
-0.1471, -0.5052, -0.2574,  0.8637,  0.8737,  0.6887,  0.3436
-0.3712, -0.6505,  0.2142, -0.1728,  0.6327, -0.6297,  0.7430
 0.4038, -0.5193,  0.1484, -0.3020, -0.8861, -0.5424,  0.7499
 0.0380, -0.6506,  0.1414,  0.9935,  0.6337,  0.1887,  0.4509
 0.9520,  0.8031,  0.1912, -0.9351, -0.8128, -0.8693,  0.5336
 0.9507, -0.6640,  0.9456,  0.5349,  0.6485,  0.2652,  0.3616
 0.3375, -0.0462, -0.9737, -0.2940, -0.0159,  0.4602,  0.4840
-0.7247, -0.9782,  0.5166, -0.3601,  0.9688, -0.5595,  0.7751
-0.3226,  0.0478,  0.5098, -0.0723, -0.7504, -0.3750,  0.8025
 0.5403, -0.7393, -0.9542,  0.0382,  0.6200, -0.9748,  0.5359
 0.3449,  0.3736, -0.1015,  0.8296,  0.2887, -0.9895,  0.4390
 0.6608,  0.2983,  0.3474,  0.1570, -0.4518,  0.1211,  0.3624
 0.3435, -0.2951,  0.7117, -0.6099,  0.4946, -0.4208,  0.5283
 0.6154, -0.2929, -0.5726,  0.5346, -0.3827,  0.4665,  0.4907
 0.4889, -0.5572, -0.5718, -0.6021, -0.7150, -0.2458,  0.7202
-0.8389, -0.5366, -0.5847,  0.8347,  0.4226,  0.1078,  0.6391
-0.3910,  0.6697, -0.1294,  0.8469,  0.4121, -0.0439,  0.4693
-0.1376, -0.1916, -0.7065,  0.4586, -0.6225,  0.2878,  0.6695
 0.5086, -0.5785,  0.2019,  0.4979,  0.2764,  0.1943,  0.4666
 0.8906, -0.1489,  0.5644, -0.8877,  0.6705, -0.6155,  0.3480
-0.2098, -0.3998, -0.8398,  0.8093, -0.2597,  0.0614,  0.6341
-0.5871, -0.8476,  0.0158, -0.4769, -0.2859, -0.7839,  0.9006
 0.5751, -0.7868,  0.9714, -0.6457,  0.1448, -0.9103,  0.6049
 0.0558,  0.4802, -0.7001,  0.1022, -0.5668,  0.5184,  0.4612
 0.4458, -0.6469,  0.7239, -0.9604,  0.7205,  0.1178,  0.5941
 0.4339,  0.9747, -0.4438, -0.9924,  0.8678,  0.7158,  0.2627
 0.4577,  0.0334,  0.4139,  0.5611, -0.2502,  0.5406,  0.3847
-0.1963,  0.3946, -0.9938,  0.5498,  0.7928, -0.5214,  0.5025
-0.7585, -0.5594, -0.3958,  0.7661,  0.0863, -0.4266,  0.7481
 0.2277, -0.3517, -0.0853, -0.1118,  0.6563, -0.1473,  0.4798
-0.3086,  0.3499, -0.5570, -0.0655, -0.3705,  0.2537,  0.5768
 0.5689, -0.0861,  0.3125, -0.7363, -0.1340,  0.8186,  0.5035
 0.2110,  0.5335,  0.0094, -0.0039,  0.6858, -0.8644,  0.4243
 0.0357, -0.6111,  0.6959, -0.4967,  0.4015,  0.0805,  0.6611
 0.8977,  0.2487,  0.6760, -0.9841,  0.9787, -0.8446,  0.2873
-0.9821,  0.6455,  0.7224, -0.1203, -0.4885,  0.6054,  0.6908
-0.0443, -0.7313,  0.8557,  0.7919, -0.0169,  0.7134,  0.6039
-0.2040,  0.0115, -0.6209,  0.9300, -0.4116, -0.7931,  0.6495
-0.7114, -0.9718,  0.4319,  0.1290,  0.5892,  0.0142,  0.7675
 0.5557, -0.1870,  0.2955, -0.6404, -0.3564, -0.6548,  0.6295
-0.1827, -0.5172, -0.1862,  0.9504, -0.3594,  0.9650,  0.5685
 0.7150,  0.2392, -0.4959,  0.5857, -0.1341, -0.2850,  0.3585
-0.3394,  0.3947, -0.4627,  0.6166, -0.4094,  0.0882,  0.5962
 0.7768, -0.6312,  0.1707,  0.7964, -0.1078,  0.8437,  0.4243
-0.4420,  0.2177,  0.3649, -0.5436, -0.9725, -0.1666,  0.8086
 0.5595, -0.6505, -0.3161, -0.7108,  0.4335,  0.3986,  0.5846
 0.3770, -0.4932,  0.3847, -0.5454, -0.1507, -0.2562,  0.6335
 0.2633,  0.4146,  0.2272,  0.2966, -0.6601, -0.7011,  0.5653
 0.0284,  0.7507, -0.6321, -0.0743, -0.1421, -0.0054,  0.4219
-0.4762,  0.6891,  0.6007, -0.1467,  0.2140, -0.7091,  0.6098
 0.0192, -0.4061,  0.7193,  0.3432,  0.2669, -0.7505,  0.6549
 0.8966,  0.2902, -0.6966,  0.2783,  0.1313, -0.0627,  0.2876
-0.1439,  0.1985,  0.6999,  0.5022,  0.1587,  0.8494,  0.3872
 0.2473, -0.9040, -0.4308, -0.8779,  0.4070,  0.3369,  0.6825
-0.2428, -0.6236,  0.4940, -0.3192,  0.5906, -0.0242,  0.6770
 0.2885, -0.2987, -0.5416, -0.1322, -0.2351, -0.0604,  0.6106
 0.9590, -0.2712,  0.5488,  0.1055,  0.7783, -0.2901,  0.2956
-0.9129,  0.9015,  0.1128, -0.2473,  0.9901, -0.8833,  0.6500
 0.0334, -0.9378,  0.1424, -0.6391,  0.2619,  0.9618,  0.7033
 0.4169,  0.5549, -0.0103,  0.0571, -0.6984, -0.2612,  0.4935
-0.7156,  0.4538, -0.0460, -0.1022,  0.7720,  0.0552,  0.4983
-0.8560, -0.1637, -0.9485, -0.4177,  0.0070,  0.9319,  0.6445
-0.7812,  0.3461, -0.0001,  0.5542, -0.7128, -0.8336,  0.7720
-0.6166,  0.5356, -0.4194, -0.5662, -0.9666, -0.2027,  0.7401
-0.2378,  0.3187, -0.8582, -0.6948, -0.9668, -0.7724,  0.7670
-0.3579,  0.1158,  0.9869,  0.6690,  0.3992,  0.8365,  0.4184
-0.9205, -0.8593, -0.0520, -0.3017,  0.8745, -0.0209,  0.7723
-0.1067,  0.7541, -0.4928, -0.4524, -0.3433,  0.0951,  0.4645
-0.5597,  0.3429, -0.7144, -0.8118,  0.7404, -0.5263,  0.6117
 0.0516, -0.8480,  0.7483,  0.9023,  0.6250, -0.4324,  0.5987
 0.0557, -0.3212,  0.1093,  0.9488, -0.3766,  0.3376,  0.5739
-0.3484,  0.7797,  0.5034,  0.5253, -0.0610, -0.5785,  0.5365
-0.9170, -0.3563, -0.9258,  0.3877,  0.3407, -0.1391,  0.7131
-0.9203, -0.7304, -0.6132, -0.3287, -0.8954,  0.2102,  0.9329
 0.0241,  0.2349, -0.1353,  0.6954, -0.0919, -0.9692,  0.5744
 0.6460,  0.9036, -0.8982, -0.5299, -0.8733, -0.1567,  0.4425
 0.7277, -0.8368, -0.0538, -0.7489,  0.5458,  0.6828,  0.5848
-0.5212,  0.9049,  0.8878,  0.2279,  0.9470, -0.3103,  0.4255
 0.7957, -0.1308, -0.5284,  0.8817,  0.3684, -0.8702,  0.3969
 0.2099,  0.4647, -0.4931,  0.2010,  0.6292, -0.8918,  0.4620
-0.7390,  0.6849,  0.2367,  0.0626, -0.5034, -0.4098,  0.7137
-0.8711,  0.7940, -0.5932,  0.6525,  0.7635, -0.0265,  0.5705
 0.1969,  0.0545,  0.2496,  0.7101, -0.4357,  0.7675,  0.4242
-0.5460,  0.1920, -0.5211, -0.7372, -0.6763,  0.6897,  0.6769
 0.2044,  0.9271, -0.3086,  0.1913,  0.1980,  0.2314,  0.2998
-0.6149,  0.5059, -0.9854, -0.3435,  0.8352,  0.1767,  0.4497
 0.7104,  0.2093,  0.6452,  0.7590, -0.3580, -0.7541,  0.4076
-0.7465,  0.1796, -0.9279, -0.5996,  0.5766, -0.9758,  0.7713
-0.3933, -0.9572,  0.9950,  0.1641, -0.4132,  0.8579,  0.7421
 0.1757, -0.4717, -0.3894, -0.2567, -0.5111,  0.1691,  0.7088
 0.3917, -0.8561,  0.9422,  0.5061,  0.6123,  0.5033,  0.4824
-0.1087,  0.3449, -0.1025,  0.4086,  0.3633,  0.3943,  0.3760
 0.2372, -0.6980,  0.5216,  0.5621,  0.8082, -0.5325,  0.5297
-0.3589,  0.6310,  0.2271,  0.5200, -0.1447, -0.8011,  0.5903
-0.7699, -0.2532, -0.6123,  0.6415,  0.1993,  0.3777,  0.6039
-0.5298, -0.0768, -0.6028, -0.9490,  0.4588,  0.4498,  0.6159
-0.3392,  0.6870, -0.1431,  0.7294,  0.3141,  0.1621,  0.4501
 0.7889, -0.3900,  0.7419,  0.8175, -0.3403,  0.3661,  0.4087
 0.7984, -0.8486,  0.7572, -0.6183,  0.6995,  0.3342,  0.5025
 0.2707,  0.6956,  0.6437,  0.2565,  0.9126,  0.1798,  0.2331
-0.6043, -0.1413, -0.3265,  0.9839, -0.2395,  0.9854,  0.5444
-0.8509, -0.2594, -0.7532,  0.2690, -0.1722,  0.9818,  0.6516
 0.8599, -0.7015, -0.2102, -0.0768,  0.1219,  0.5607,  0.4747
-0.4760,  0.8216, -0.9555,  0.6422, -0.6231,  0.3715,  0.5485
-0.2896,  0.9484, -0.7545, -0.6249,  0.7789,  0.1668,  0.3415
-0.5931,  0.7926,  0.7462,  0.4006, -0.0590,  0.6543,  0.4781
-0.0083, -0.2730, -0.4488,  0.8495, -0.2260, -0.0142,  0.5854
-0.2335, -0.4049,  0.4352, -0.6183, -0.7636,  0.6740,  0.7596
 0.4883,  0.1810, -0.5142,  0.2465,  0.2767, -0.3449,  0.3995
-0.4922,  0.1828, -0.1424, -0.2358, -0.7466, -0.5115,  0.7968
-0.8413, -0.3943,  0.4834,  0.2300,  0.3448, -0.9832,  0.7989
-0.5382, -0.6502, -0.6300,  0.6885,  0.9652,  0.8275,  0.4353
-0.3053,  0.5604,  0.0929,  0.6329, -0.0325,  0.1799,  0.4848
 0.0740, -0.2680,  0.2086,  0.9176, -0.2144, -0.2141,  0.5856
 0.5813,  0.2902, -0.2122,  0.3779, -0.1920, -0.7278,  0.4079
-0.5641,  0.8515,  0.3793,  0.1976,  0.4933,  0.0839,  0.4716
 0.4011,  0.8611,  0.7252, -0.6651, -0.4737, -0.8568,  0.5708
-0.5785,  0.0056, -0.7901, -0.2223,  0.0760, -0.3216,  0.7252
 0.1118,  0.0735, -0.2188,  0.3925,  0.3570,  0.3746,  0.3688
 0.2262,  0.8715,  0.1938,  0.9592, -0.1180,  0.4792,  0.2952
-0.9248,  0.5295,  0.0366, -0.9894, -0.4456,  0.0697,  0.7335
 0.2992,  0.8629, -0.8505, -0.4464,  0.8385,  0.5300,  0.2702
 0.1995,  0.6659,  0.7921,  0.9454,  0.9970, -0.7207,  0.2996
-0.3066, -0.2927, -0.4923,  0.8220,  0.4513, -0.9481,  0.6617
-0.0770, -0.4374, -0.9421,  0.7694,  0.5420, -0.3405,  0.5131
-0.3842,  0.8562,  0.9538,  0.0471,  0.9039,  0.7760,  0.3215
 0.0361, -0.2545,  0.4207, -0.0887,  0.2104,  0.9808,  0.5202
-0.8220, -0.6302,  0.0537, -0.1658,  0.6013,  0.8664,  0.6598
-0.6443,  0.7201,  0.9148,  0.9189, -0.9243, -0.8848,  0.6095
-0.2880,  0.9074, -0.0461, -0.4435,  0.0060,  0.2867,  0.4025
-0.7775,  0.5161,  0.7039,  0.6885,  0.7810, -0.2363,  0.5234
-0.5484,  0.9426, -0.4308,  0.8148,  0.7811,  0.8450,  0.3479

Test data:

# synthetic_test.txt
#
-0.6877,  0.7594,  0.2640, -0.5787, -0.3098, -0.6802,  0.7071
-0.6694, -0.6056,  0.3821,  0.1476,  0.7466, -0.5107,  0.7282
 0.2592, -0.9311,  0.0324,  0.7265,  0.9683, -0.9803,  0.5832
-0.9049, -0.9797, -0.0196, -0.9090, -0.4433,  0.2799,  0.9018
-0.4106, -0.4607,  0.1811, -0.2389,  0.4050, -0.0078,  0.6916
-0.4259, -0.7336,  0.8742,  0.6097,  0.8761, -0.6292,  0.6728
 0.8663,  0.8715, -0.4329, -0.4507,  0.1029, -0.6294,  0.2936
 0.8948, -0.0124,  0.9278,  0.2899, -0.0314,  0.9354,  0.3160
-0.7136,  0.2647,  0.3238, -0.1323, -0.8813, -0.0146,  0.8133
-0.4867, -0.2171, -0.5197,  0.3729,  0.9798, -0.6451,  0.5820
 0.6429, -0.5380, -0.8840, -0.7224,  0.8703,  0.7771,  0.5777
 0.6999, -0.1307, -0.0639,  0.2597, -0.6839, -0.9704,  0.5796
-0.4690, -0.9691,  0.3490,  0.1029, -0.3567,  0.5604,  0.8151
-0.4154, -0.6081, -0.8241,  0.7400, -0.8236,  0.3674,  0.7881
-0.7592, -0.9786,  0.1145,  0.8142,  0.7209, -0.3231,  0.6968
 0.3393,  0.6156,  0.7950, -0.0923,  0.1157,  0.0123,  0.3229
 0.3840,  0.3658,  0.0406,  0.6569,  0.0116,  0.6497,  0.2879
 0.9397,  0.4839, -0.4804,  0.1625,  0.9105, -0.8385,  0.2410
-0.8329,  0.2383, -0.5510,  0.5304,  0.1363,  0.3324,  0.5862
-0.8255, -0.2579,  0.3443, -0.6208,  0.7915,  0.8997,  0.6109
 0.9231,  0.4602, -0.1874,  0.4875, -0.4240, -0.3712,  0.3165
 0.7573, -0.4908,  0.5324,  0.8820, -0.9979, -0.0478,  0.6093
 0.3141,  0.6866, -0.6325,  0.7123, -0.2713,  0.7845,  0.3050
-0.1647, -0.6616,  0.2998, -0.9260, -0.3768, -0.3530,  0.8315
 0.2149,  0.3017,  0.6921,  0.8552,  0.3209,  0.1563,  0.3157
-0.6918,  0.7902, -0.3780,  0.0970,  0.3641, -0.5271,  0.6323
-0.6645,  0.0170,  0.5837,  0.3848, -0.7621,  0.8015,  0.7440
 0.1069, -0.8304, -0.5951,  0.7085,  0.4119,  0.7899,  0.4998
-0.3417,  0.0560,  0.3008,  0.1886, -0.5371, -0.1464,  0.7339
 0.9734, -0.8669,  0.4279, -0.3398,  0.2509, -0.4837,  0.4665
 0.3020, -0.2577, -0.4104,  0.8235,  0.8850,  0.2271,  0.3066
-0.5766,  0.6603, -0.5198,  0.2632,  0.4215,  0.4848,  0.4478
-0.2195,  0.5197,  0.8059,  0.1748, -0.8192, -0.7420,  0.6740
-0.9212, -0.5169,  0.7581,  0.9470,  0.2108,  0.9525,  0.6180
-0.9131,  0.8971, -0.3774,  0.5979,  0.6213,  0.7200,  0.4642
-0.4842,  0.8689,  0.2382,  0.9709, -0.9347,  0.4503,  0.5662
 0.1311, -0.0152, -0.4816, -0.3463, -0.5011, -0.5615,  0.6979
-0.8336,  0.5540,  0.0673,  0.4788,  0.0308, -0.2001,  0.6917
 0.9725, -0.9435,  0.8655,  0.8617, -0.2182, -0.5711,  0.6021
 0.6064, -0.4921, -0.4184,  0.8318,  0.8058,  0.0708,  0.3221