Adding a Transformer Module to a PyTorch Regression Network – No Numeric Pseudo-Embedding

I’ve been looking at adding a Transformer module to a PyTorch regression network. Because the key functionality of a Transformer is the attention mechanism, I’ve also been looking at adding a custom Attention module instead of a Transformer.

There are dozens of design alternatives, and many architecture and training hyperparameters. For the baseline architecture against which I’m comparing, I use a standard PyTorch TransformerEncoder layer, a numeric pseudo-embedding layer, and a simplified positional encoding layer.

All of my ideas are based on natural language processing systems. In NLP, each word/token is mapped to an integer, such as “fly” = 2863. Then the integer is mapped to a vector of real values, typically about 100 to 1000 of them, called the embedding. The idea is to capture different dimensions of the meaning of the source word/token: “fly” can be a type of insect, a way to travel, a baseball play, and so on. After embedding, an NLP system needs positional encoding because the order of words in a sentence is critical. I intend to examine the architecture options as best I can.
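To make the token-to-vector step concrete, here is a minimal sketch using the PyTorch Embedding class. The vocabulary size (5000) and the embedding width (8, rather than the 100 to 1000 typical in NLP) are assumptions picked only to keep things small; the token ID 2863 for “fly” is the illustrative value from the paragraph above.

```python
import torch as T

T.manual_seed(0)
vocab_size = 5000   # assumed vocabulary size (hypothetical)
embed_dim = 8       # assumed; real NLP systems use about 100-1000

embed = T.nn.Embedding(vocab_size, embed_dim)  # lookup table

token_ids = T.tensor([2863], dtype=T.int64)  # "fly" = 2863
vec = embed(token_ids)  # one learned vector per token ID

print(vec.shape)  # (number of tokens, embed_dim)
```

The values in the vector start out random and are learned during training, which is how the different senses of a word can end up in different dimensions.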

For this investigation, I’m eliminating the numeric pseudo-embedding but leaving in the baseline simplified positional encoding layer. My working hypothesis going in was that the architecture would result in a model similar to one without a Transformer module. I was very wrong. The no-embedding architecture resulted in a very poor model. Strong implication: when using a Transformer, you need to use some form of embedding even when the inputs are already numeric.
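The simplified positional encoding that stays in this experiment is easy to inspect on its own. The sketch below reuses the formula from the PositionEncode class in the demo code; the helper name simple_pos_enc is mine and is not part of the demo.

```python
def simple_pos_enc(n_features):
  # same formula as the demo's PositionEncode class:
  # pe[i] = i * (0.01 / n_features), no sin or cos
  return [i * (0.01 / n_features) for i in range(n_features)]

pe = simple_pos_enc(6)
for (i, v) in enumerate(pe):
  print("position %d  offset = %0.5f" % (i, v))
```

With six features every offset is below 0.01, so the positional signal is a very small nudge added to each value.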



A diagram of a 6-6-PE-T-(8-8)-1 system, not the 6-6-PE-T-(10-10)-1 system of the demo.


The demo uses one of my standard synthetic datasets. The data looks like:

-0.1660,  0.4406, -0.9998, -0.3953, -0.7065, -0.8153,  0.7022
-0.2065,  0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.5666
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192, -0.6038,  0.8186
 0.7528,  0.7892, -0.8299, -0.9219, -0.6603,  0.7563,  0.3687
. . .

The first six values on each line are predictors. The last value on each line is the target to predict. The data was generated by a 6-10-1 neural network with random weights and biases. There are 200 training items and 40 test items.
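For readers who want similar data, here is a rough sketch of how a 6-10-1 network with random weights could generate it. The activation functions (tanh hidden, sigmoid output), the weight and predictor ranges, and the seed are all my assumptions, since the post does not state them, so this will not reproduce the exact files listed at the end.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed

# assumed: weights and biases uniform in [-1, +1)
w_ih = rng.uniform(-1.0, 1.0, size=(6, 10))  # input-to-hidden
b_h = rng.uniform(-1.0, 1.0, size=10)
w_ho = rng.uniform(-1.0, 1.0, size=(10, 1))  # hidden-to-output
b_o = rng.uniform(-1.0, 1.0, size=1)

def make_rows(n):
  x = rng.uniform(-1.0, 1.0, size=(n, 6))      # six predictors
  h = np.tanh(x @ w_ih + b_h)                  # assumed tanh
  y = 1.0 / (1.0 + np.exp(-(h @ w_ho + b_o)))  # assumed sigmoid
  return np.hstack([x, y])                     # target is last column

train = make_rows(200)
test = make_rows(40)
print(train.shape, test.shape)
```

A sigmoid output keeps the targets between 0 and 1, which matches the range of the target values shown above.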

The baseline architecture is 6-24-PE-T-(10-10)-1, meaning 6 inputs, each mapped to a 4-value pseudo-embedding (24 values total), with positional encoding added, fed to a Transformer, sent to two hidden layers of 10 nodes each, sent to a single output prediction node. The demo architecture explored in this blog post is 6-6-PE-T-(10-10)-1, meaning 6 inputs, fed to a preliminary fully connected layer in place of an embedding layer, with positional encoding added, sent to a Transformer (with embedding dimension set to 1), sent to two fully connected layers with 10 nodes each, sent to a single output node.
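The difference between the two architectures is easiest to see as tensor shapes going into the Transformer. In the sketch below, the demo branch follows the demo code (Linear(6, 6), then reshape), but the baseline branch is an assumption: this post does not show how the baseline builds its pseudo-embedding, so a Linear(6, 24) layer followed by a reshape is only a stand-in.

```python
import torch as T

x = T.zeros(10, 6)  # one batch of 10 items, 6 predictors each

# baseline 6-24-PE-T: assumed pseudo-embedding via a Linear layer
z_base = T.nn.Linear(6, 24)(x).reshape(-1, 6, 4)
# demo 6-6-PE-T: preliminary Linear layer, width-1 "tokens"
z_demo = T.nn.Linear(6, 6)(x).reshape(-1, 6, 1)

print(z_base.shape)  # (batch, sequence length, d_model=4)
print(z_demo.shape)  # (batch, sequence length, d_model=1)
```

In both cases the Transformer sees a sequence of six positions; what changes is how many values describe each position.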

. . . 
self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
  num_layers=2, enable_nested_tensor=False)
. . .

I discovered that I needed to set the enable_nested_tensor parameter to False in the TransformerEncoder layer, in order to deal with an embedding dimension of 1.
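An embedding dimension of 1 has a side effect that is easy to check in isolation. With default settings, a TransformerEncoderLayer ends with a LayerNorm over the last dimension, and normalizing a single value always gives 0 (the value minus its own mean), so the result is just the LayerNorm bias no matter what went in. The sketch below is not part of the demo. It uses the same layer settings plus dropout=0.0 (my addition, to make the comparison deterministic), and it shows one plausible contributor to the results, not a confirmed diagnosis.

```python
import torch as T

T.manual_seed(0)

# LayerNorm over a single value: (v - mean) is always 0
ln = T.nn.LayerNorm(1)
print(ln(T.tensor([[3.5], [-2.0], [0.1]])))  # all zeros at init

# same settings as the demo; dropout=0.0 is an added assumption
enc_layer = T.nn.TransformerEncoderLayer(d_model=1, nhead=1,
  dim_feedforward=10, dropout=0.0, batch_first=True)
trans_enc = T.nn.TransformerEncoder(enc_layer, num_layers=2,
  enable_nested_tensor=False)

a = T.randn(4, 6, 1)  # two unrelated input batches
b = T.randn(4, 6, 1)
out_a = trans_enc(a)
out_b = trans_enc(b)
print(T.allclose(out_a, out_b))  # True: output ignores the input
```

If the same thing happens inside the demo network, the layers after the Transformer receive identical values for every data item, which would be consistent with a model that can do little more than predict a constant.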

Here’s the output of one demo run:

Begin Transformer regression on synthetic data
No pseudo-embedding but using position encoding

Loading train (200) and test (40) data to memory
Done

First three rows of training predictors:
tensor([-0.1660,  0.4406, -0.9998, -0.3953, -0.7065, -0.8153])
tensor([-0.2065,  0.0776, -0.1616,  0.3704, -0.5911,  0.7562])
tensor([-0.9452,  0.3409, -0.1654,  0.1174, -0.7192, -0.6038])

First three target y values:
0.7022
0.5666
0.8186

Creating 6-6-PE-T-(10-10)-1 regression model

bat_size = 10
loss = MSELoss()
optimizer = Adam
lrn_rate = 0.001

Starting training
epoch =    0  |  loss = 0.7518
epoch =   20  |  loss = 0.5424
epoch =   40  |  loss = 0.5409
epoch =   60  |  loss = 0.5426
epoch =   80  |  loss = 0.5405
Done

Computing model accuracy (within 0.10 of true)
Accuracy on train data = 0.2100
Accuracy on test data = 0.2250

MSE on train data = 0.0270
MSE on test data = 0.0328

Predicting target y for train[0]:
Predicted y = 0.5489

End demo

The model accuracy (21% = 42 out of 200 correct on the training data, where a prediction counts as correct if it is within 10% of the true target value) and mean squared error (0.0270 on training data) were very poor compared to the baseline architecture (84% accuracy and MSE = 0.0012).
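The loss values logged during training look much larger than the final MSE, but they measure different things: the train() function adds up the MSE of every batch in an epoch, while mean_sq_err() averages over individual items. A quick check of the arithmetic, using only numbers from the demo output:

```python
n_train = 200
bat_size = 10
n_batches = n_train // bat_size      # 20 batches per epoch

epoch_loss = 0.5405                  # logged at epoch 80
avg_batch_loss = epoch_loss / n_batches
print("avg batch loss = %0.4f" % avg_batch_loss)

acc_train = 0.2100
n_correct = round(acc_train * n_train)
print("correct train items = %d" % n_correct)
```

The average per-batch loss comes out at roughly 0.0270, in line with the reported MSE on the training data, so the flat loss curve and the final MSE tell the same story.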

An interesting experiment.



I’m a big fan of old science fiction and fantasy movies from the 1950s and 1960s. Several old films featured a woman who could transform into a snake. Here are three examples.

Left: The UK movie “The Reptile” (1966) takes place in a small English village in the 1800s. Mysterious deaths. Anna, the daughter of the sinister Dr. Franklyn, was cursed by a Malay snake cult and periodically transforms into a not-very-nice snake woman. I give this movie a B- grade.

Center: “The 7th Voyage of Sinbad” (1958) features wonderful stop motion special effects by the well-known Ray Harryhausen. In this scene, a wizard transforms a dancing girl into a dancing snake-girl. I give this film an A- grade.

Right: In “Cult of the Cobra” (1955), a group of six U.S. servicemen sneak into a secret ceremony of a cult of snake worshippers. Not a good idea. A woman can turn into a deadly cobra. The movie is effective because the transformation is never shown. This movie has a lot of actors who went on to very successful careers. I give the movie a B grade.


Demo code. Anywhere “lt” appears in a comparison, replace it with the less-than operator symbol (my blog editor chokes on symbols).

# synthetic_transformer_no_embed_w_pe.py

# regression with Transformer, no pseudo-embedding,
# with positional encoding on a synthetic dataset 
# PyTorch 2.3.1-CPU  Anaconda3-2023.09  Python 3.11.5
# Windows 10/11 

import numpy as np
import torch as T  # non-standard alias

device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class SynthDataset(T.utils.data.Dataset):
  # six predictors followed by target
  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, delimiter=",",
      usecols=[0,1,2,3,4,5], dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=6, delimiter=",",
      dtype=np.float32)
    tmp_y = tmp_y.reshape(-1,1)  # 2D required

    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y, dtype=T.float32).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return (preds, trgts)  # as a tuple

# -----------------------------------------------------------

class PositionEncode(T.nn.Module):
  # simplified version
  def __init__(self, n_features):
    super(PositionEncode, self).__init__()  # old syntax
    self.nf = n_features
    self.pe = T.zeros(n_features, dtype=T.float32)
    for i in range(n_features):
      self.pe[i] = i * (0.01 / n_features)  # no sin, cos

  def forward(self, x):
    # x is (batch, n_features, embed_dim); add pe[j] to every
    # value at position j, without modifying x in place
    return x + self.pe.reshape(1, -1, 1)

# -----------------------------------------------------------

class TransformerNet(T.nn.Module):
  def __init__(self):
    super(TransformerNet, self).__init__()
    self.prelim = T.nn.Linear(6, 6)
    self.pos_enc = \
      PositionEncode(6)  # positional
    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=1,
      nhead=1, dim_feedforward=10, 
      batch_first=True)  # d_model divisible by nhead
    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=2,
      enable_nested_tensor=False)  # nn.Transformer default is 6
    self.fc1 = T.nn.Linear(6, 10)  # 6-6-PE-T-10-10-1
    self.fc2 = T.nn.Linear(10, 10)
    self.fc3 = T.nn.Linear(10, 1)

    # use default weight and bias initialization

  def forward(self, x):
    z = self.prelim(x)
    z = z.reshape(-1, 6, 1)
    z = self.pos_enc(z) 
    z = self.trans_enc(z) 
    z = z.reshape(-1, 6)
    z = T.tanh(self.fc1(z))
    z = T.tanh(self.fc2(z))
    z = self.fc3(z)  # regression: no activation
    return z

# -----------------------------------------------------------

def accuracy(model, ds, pct_close):
  # assumes model.eval()
  # correct if within pct_close of true target
  n_correct = 0; n_wrong = 0

  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)
    with T.no_grad():
      oupt = model(X)         # predicted target

    if T.abs(oupt - Y) < T.abs(pct_close * Y):
      n_correct += 1
    else:
      n_wrong += 1
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def mean_sq_err(model, ds):
  # assumes model.eval()
  n = len(ds)
  sum = 0.0

  for i in range(n):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    y = ds[i][1].reshape(1)
    with T.no_grad():
      oupt = model(X)         # predicted target
    diff = oupt.item() - y.item()
    sum += diff * diff 
  mse = sum / n
  return mse

# -----------------------------------------------------------

def train(model, ds, bs, lr, me, le):
  # dataset, bat_size, lrn_rate, max_epochs, log interval
  train_ldr = T.utils.data.DataLoader(ds, batch_size=bs,
    shuffle=True)
  loss_func = T.nn.MSELoss()
  optimizer = T.optim.Adam(model.parameters(), lr=lr)
  # optimizer = T.optim.SGD(model.parameters(), lr=lr)

  for epoch in range(0, me):
    epoch_loss = 0.0  # for one full epoch
    for (b_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # predictors
      y = batch[1]  # target
      optimizer.zero_grad()
      oupt = model(X)
      loss_val = loss_func(oupt, y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()  # compute gradients
      optimizer.step()     # update weights

    if epoch % le == 0:
      print("epoch = %4d  |  loss = %0.4f" % \
        (epoch, epoch_loss)) 

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Transformer regression on synthetic data ")
  print("No pseudo-embedding but using position encoding ")
  np.random.seed(0)
  T.manual_seed(0) 

  # 1. load data
  print("\nLoading train (200) and test (40) data to memory ")
  train_file = ".\\Data\\synthetic_train.txt"
  train_ds = SynthDataset(train_file)  # 200 rows
  test_file = ".\\Data\\synthetic_test.txt"
  test_ds = SynthDataset(test_file)    # 40 rows
  print("Done ")

  print("\nFirst three rows of training predictors: ")
  for i in range(3):
    print(train_ds[i][0])
  print("\nFirst three target y values: ")
  for i in range(3):
    print("%0.4f " % train_ds[i][1])

  # 2. create regression model
  print("\nCreating 6-6-PE-T-(10-10)-1 regression model ")
  net = TransformerNet().to(device)

  # 3. train model
  print("\nbat_size = 10 ")
  print("loss = MSELoss() ")
  print("optimizer = Adam ")
  print("lrn_rate = 0.001 ")

  print("\nStarting training")
  net.train()
  train(net, train_ds, bs=10, lr=0.001, me=100, le=20)
  print("Done ")

  # 4. evaluate model
  net.eval()
  print("\nComputing model accuracy (within 0.10 of true) ")
  acc_train = accuracy(net, train_ds, 0.10)  # item-by-item
  print("Accuracy on train data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds, 0.10) 
  print("Accuracy on test data = %0.4f" % acc_test)

  mse_train = mean_sq_err(net, train_ds)
  print("\nMSE on train data = %0.4f" % mse_train)
  mse_test = mean_sq_err(net, test_ds)
  print("MSE on test data = %0.4f" % mse_test)

  # 5. make a prediction
  print("\nPredicting target y for train[0]: ")
  x = train_ds[0][0].reshape(1,-1)  # item - predictors

  with T.no_grad():
    y = net(x)
  pred_raw = y.item()  # scalar
  print("Predicted y = %0.4f" % pred_raw)  

  # 6. TODO: save model (state_dict approach)

  print("\nEnd demo ")

# -----------------------------------------------------------

if __name__=="__main__":
  main()
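Step 6 in the demo is left as a TODO. A minimal sketch of the state_dict approach is below. The small Sequential network is a hypothetical stand-in for the trained TransformerNet so that the snippet runs on its own, and the file name is also made up.

```python
import os
import tempfile
import torch as T

# stand-in for the trained TransformerNet (hypothetical model)
net = T.nn.Sequential(T.nn.Linear(6, 10), T.nn.Tanh(),
  T.nn.Linear(10, 1))

path = os.path.join(tempfile.gettempdir(),
  "synthetic_model.pt")  # hypothetical file name
T.save(net.state_dict(), path)  # weights and biases only

# later: rebuild the same architecture, then load the weights
net2 = T.nn.Sequential(T.nn.Linear(6, 10), T.nn.Tanh(),
  T.nn.Linear(10, 1))
net2.load_state_dict(T.load(path))
net2.eval()

x = T.zeros(1, 6)
with T.no_grad():
  same = T.allclose(net.eval()(x), net2(x))
print(same)  # reloaded model gives the same prediction
```

Because only the weights and biases are saved, the code that loads the model must define the same architecture before calling load_state_dict().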

Training data:

# synthetic_train.txt
#
-0.1660,  0.4406, -0.9998, -0.3953, -0.7065, -0.8153,  0.7022
-0.2065,  0.0776, -0.1616,  0.3704, -0.5911,  0.7562,  0.5666
-0.9452,  0.3409, -0.1654,  0.1174, -0.7192, -0.6038,  0.8186
 0.7528,  0.7892, -0.8299, -0.9219, -0.6603,  0.7563,  0.3687
-0.8033, -0.1578,  0.9158,  0.0663,  0.3838, -0.3690,  0.7535
-0.9634,  0.5003,  0.9777,  0.4963, -0.4391,  0.5786,  0.7076
-0.7935, -0.1042,  0.8172, -0.4128, -0.4244, -0.7399,  0.8454
-0.0169, -0.8933,  0.1482, -0.7065,  0.1786,  0.3995,  0.7302
-0.7953, -0.1719,  0.3888, -0.1716, -0.9001,  0.0718,  0.8692
 0.8892,  0.1731,  0.8068, -0.7251, -0.7214,  0.6148,  0.4740
-0.2046, -0.6693,  0.8550, -0.3045,  0.5016,  0.4520,  0.6714
 0.5019, -0.3022, -0.4601,  0.7918, -0.1438,  0.9297,  0.4331
 0.3269,  0.2434, -0.7705,  0.8990, -0.1002,  0.1568,  0.3716
 0.8068,  0.1474, -0.9943,  0.2343, -0.3467,  0.0541,  0.3829
 0.7719, -0.2855,  0.8171,  0.2467, -0.9684,  0.8589,  0.4700
 0.8652,  0.3936, -0.8680,  0.5109,  0.5078,  0.8460,  0.2648
 0.4230, -0.7515, -0.9602, -0.9476, -0.9434, -0.5076,  0.8059
 0.1056,  0.6841, -0.7517, -0.4416,  0.1715,  0.9392,  0.3512
 0.1221, -0.9627,  0.6013, -0.5341,  0.6142, -0.2243,  0.6840
 0.1125, -0.7271, -0.8802, -0.7573, -0.9109, -0.7850,  0.8640
-0.5486,  0.4260,  0.1194, -0.9749, -0.8561,  0.9346,  0.6109
-0.4953,  0.4877, -0.6091,  0.1627,  0.9400,  0.6937,  0.3382
-0.5203, -0.0125,  0.2399,  0.6580, -0.6864, -0.9628,  0.7400
 0.2127,  0.1377, -0.3653,  0.9772,  0.1595, -0.2397,  0.4081
 0.1019,  0.4907,  0.3385, -0.4702, -0.8673, -0.2598,  0.6582
 0.5055, -0.8669, -0.4794,  0.6095, -0.6131,  0.2789,  0.6644
 0.0493,  0.8496, -0.4734, -0.8681,  0.4701,  0.5444,  0.3214
 0.9004,  0.1133,  0.8312,  0.2831, -0.2200, -0.0280,  0.3149
 0.2086,  0.0991,  0.8524,  0.8375, -0.2102,  0.9265,  0.3619
-0.7298,  0.0113, -0.9570,  0.8959,  0.6542, -0.9700,  0.6451
-0.6476, -0.3359, -0.7380,  0.6190, -0.3105,  0.8802,  0.6606
 0.6895,  0.8108, -0.0802,  0.0927,  0.5972, -0.4286,  0.2427
-0.0195,  0.1982, -0.9689,  0.1870, -0.1326,  0.6147,  0.4773
 0.1557, -0.6320,  0.5759,  0.2241, -0.8922, -0.1596,  0.7581
 0.3581,  0.8372, -0.9992,  0.9535, -0.2468,  0.9476,  0.2962
 0.1494,  0.2562, -0.4288,  0.1737,  0.5000,  0.7166,  0.3513
 0.5102,  0.3961,  0.7290, -0.3546,  0.3416, -0.0983,  0.3153
-0.1970, -0.3652,  0.2438, -0.1395,  0.9476,  0.3556,  0.4719
-0.6029, -0.1466, -0.3133,  0.5953,  0.7600,  0.8077,  0.3875
-0.4953,  0.7098,  0.0554,  0.6043,  0.1450,  0.4663,  0.4739
 0.0380,  0.5418,  0.1377, -0.0686, -0.3146, -0.8636,  0.6048
 0.9656, -0.6368,  0.6237,  0.7499,  0.3768,  0.1390,  0.3705
-0.6781, -0.0662, -0.3097, -0.5499,  0.1850, -0.3755,  0.7668
-0.6141, -0.0008,  0.4572, -0.5836, -0.5039,  0.7033,  0.7301
-0.1683,  0.2334, -0.5327, -0.7961,  0.0317, -0.0457,  0.5777
 0.0880,  0.3083, -0.7109,  0.5031, -0.5559,  0.0387,  0.5118
 0.5706, -0.9553, -0.3513,  0.7458,  0.6894,  0.0769,  0.4329
-0.8025,  0.3026,  0.4070,  0.2205,  0.5992, -0.9309,  0.7098
 0.5405,  0.4635, -0.4806, -0.4859,  0.2646, -0.3094,  0.3566
 0.5655,  0.9809, -0.3995, -0.7140,  0.8026,  0.0831,  0.2551
 0.9495,  0.2732,  0.9878,  0.0921,  0.0529, -0.7291,  0.3074
-0.6792,  0.4913, -0.9392, -0.2669,  0.7247,  0.3854,  0.4362
 0.3819, -0.6227, -0.1162,  0.1632,  0.9795, -0.5922,  0.4435
 0.5003, -0.0860, -0.8861,  0.0170, -0.5761,  0.5972,  0.5136
-0.4053, -0.9448,  0.1869,  0.6877, -0.2380,  0.4997,  0.7859
 0.9189,  0.6079, -0.9354,  0.4188, -0.0700,  0.8951,  0.2696
-0.5571, -0.4659, -0.8371, -0.1428, -0.7820,  0.2676,  0.8566
 0.5324, -0.3151,  0.6917, -0.1425,  0.6480,  0.2530,  0.4252
-0.7132, -0.8432, -0.9633, -0.8666, -0.0828, -0.7733,  0.9217
-0.0952, -0.0998, -0.0439, -0.0520,  0.6063, -0.1952,  0.5140
 0.8094, -0.9259,  0.5477, -0.7487,  0.2370, -0.9793,  0.5562
 0.9024,  0.8108,  0.5919,  0.8305, -0.7089, -0.6845,  0.2993
-0.6247,  0.2450,  0.8116,  0.9799,  0.4222,  0.4636,  0.4619
-0.5003, -0.6531, -0.7611,  0.6252, -0.7064, -0.4714,  0.8452
 0.6382, -0.3788,  0.9648, -0.4667,  0.0673, -0.3711,  0.5070
-0.1328,  0.0246,  0.8778, -0.9381,  0.4338,  0.7820,  0.5680
-0.9454,  0.0441, -0.3480,  0.7190,  0.1170,  0.3805,  0.6562
-0.4198, -0.9813,  0.1535, -0.3771,  0.0345,  0.8328,  0.7707
-0.1471, -0.5052, -0.2574,  0.8637,  0.8737,  0.6887,  0.3436
-0.3712, -0.6505,  0.2142, -0.1728,  0.6327, -0.6297,  0.7430
 0.4038, -0.5193,  0.1484, -0.3020, -0.8861, -0.5424,  0.7499
 0.0380, -0.6506,  0.1414,  0.9935,  0.6337,  0.1887,  0.4509
 0.9520,  0.8031,  0.1912, -0.9351, -0.8128, -0.8693,  0.5336
 0.9507, -0.6640,  0.9456,  0.5349,  0.6485,  0.2652,  0.3616
 0.3375, -0.0462, -0.9737, -0.2940, -0.0159,  0.4602,  0.4840
-0.7247, -0.9782,  0.5166, -0.3601,  0.9688, -0.5595,  0.7751
-0.3226,  0.0478,  0.5098, -0.0723, -0.7504, -0.3750,  0.8025
 0.5403, -0.7393, -0.9542,  0.0382,  0.6200, -0.9748,  0.5359
 0.3449,  0.3736, -0.1015,  0.8296,  0.2887, -0.9895,  0.4390
 0.6608,  0.2983,  0.3474,  0.1570, -0.4518,  0.1211,  0.3624
 0.3435, -0.2951,  0.7117, -0.6099,  0.4946, -0.4208,  0.5283
 0.6154, -0.2929, -0.5726,  0.5346, -0.3827,  0.4665,  0.4907
 0.4889, -0.5572, -0.5718, -0.6021, -0.7150, -0.2458,  0.7202
-0.8389, -0.5366, -0.5847,  0.8347,  0.4226,  0.1078,  0.6391
-0.3910,  0.6697, -0.1294,  0.8469,  0.4121, -0.0439,  0.4693
-0.1376, -0.1916, -0.7065,  0.4586, -0.6225,  0.2878,  0.6695
 0.5086, -0.5785,  0.2019,  0.4979,  0.2764,  0.1943,  0.4666
 0.8906, -0.1489,  0.5644, -0.8877,  0.6705, -0.6155,  0.3480
-0.2098, -0.3998, -0.8398,  0.8093, -0.2597,  0.0614,  0.6341
-0.5871, -0.8476,  0.0158, -0.4769, -0.2859, -0.7839,  0.9006
 0.5751, -0.7868,  0.9714, -0.6457,  0.1448, -0.9103,  0.6049
 0.0558,  0.4802, -0.7001,  0.1022, -0.5668,  0.5184,  0.4612
 0.4458, -0.6469,  0.7239, -0.9604,  0.7205,  0.1178,  0.5941
 0.4339,  0.9747, -0.4438, -0.9924,  0.8678,  0.7158,  0.2627
 0.4577,  0.0334,  0.4139,  0.5611, -0.2502,  0.5406,  0.3847
-0.1963,  0.3946, -0.9938,  0.5498,  0.7928, -0.5214,  0.5025
-0.7585, -0.5594, -0.3958,  0.7661,  0.0863, -0.4266,  0.7481
 0.2277, -0.3517, -0.0853, -0.1118,  0.6563, -0.1473,  0.4798
-0.3086,  0.3499, -0.5570, -0.0655, -0.3705,  0.2537,  0.5768
 0.5689, -0.0861,  0.3125, -0.7363, -0.1340,  0.8186,  0.5035
 0.2110,  0.5335,  0.0094, -0.0039,  0.6858, -0.8644,  0.4243
 0.0357, -0.6111,  0.6959, -0.4967,  0.4015,  0.0805,  0.6611
 0.8977,  0.2487,  0.6760, -0.9841,  0.9787, -0.8446,  0.2873
-0.9821,  0.6455,  0.7224, -0.1203, -0.4885,  0.6054,  0.6908
-0.0443, -0.7313,  0.8557,  0.7919, -0.0169,  0.7134,  0.6039
-0.2040,  0.0115, -0.6209,  0.9300, -0.4116, -0.7931,  0.6495
-0.7114, -0.9718,  0.4319,  0.1290,  0.5892,  0.0142,  0.7675
 0.5557, -0.1870,  0.2955, -0.6404, -0.3564, -0.6548,  0.6295
-0.1827, -0.5172, -0.1862,  0.9504, -0.3594,  0.9650,  0.5685
 0.7150,  0.2392, -0.4959,  0.5857, -0.1341, -0.2850,  0.3585
-0.3394,  0.3947, -0.4627,  0.6166, -0.4094,  0.0882,  0.5962
 0.7768, -0.6312,  0.1707,  0.7964, -0.1078,  0.8437,  0.4243
-0.4420,  0.2177,  0.3649, -0.5436, -0.9725, -0.1666,  0.8086
 0.5595, -0.6505, -0.3161, -0.7108,  0.4335,  0.3986,  0.5846
 0.3770, -0.4932,  0.3847, -0.5454, -0.1507, -0.2562,  0.6335
 0.2633,  0.4146,  0.2272,  0.2966, -0.6601, -0.7011,  0.5653
 0.0284,  0.7507, -0.6321, -0.0743, -0.1421, -0.0054,  0.4219
-0.4762,  0.6891,  0.6007, -0.1467,  0.2140, -0.7091,  0.6098
 0.0192, -0.4061,  0.7193,  0.3432,  0.2669, -0.7505,  0.6549
 0.8966,  0.2902, -0.6966,  0.2783,  0.1313, -0.0627,  0.2876
-0.1439,  0.1985,  0.6999,  0.5022,  0.1587,  0.8494,  0.3872
 0.2473, -0.9040, -0.4308, -0.8779,  0.4070,  0.3369,  0.6825
-0.2428, -0.6236,  0.4940, -0.3192,  0.5906, -0.0242,  0.6770
 0.2885, -0.2987, -0.5416, -0.1322, -0.2351, -0.0604,  0.6106
 0.9590, -0.2712,  0.5488,  0.1055,  0.7783, -0.2901,  0.2956
-0.9129,  0.9015,  0.1128, -0.2473,  0.9901, -0.8833,  0.6500
 0.0334, -0.9378,  0.1424, -0.6391,  0.2619,  0.9618,  0.7033
 0.4169,  0.5549, -0.0103,  0.0571, -0.6984, -0.2612,  0.4935
-0.7156,  0.4538, -0.0460, -0.1022,  0.7720,  0.0552,  0.4983
-0.8560, -0.1637, -0.9485, -0.4177,  0.0070,  0.9319,  0.6445
-0.7812,  0.3461, -0.0001,  0.5542, -0.7128, -0.8336,  0.7720
-0.6166,  0.5356, -0.4194, -0.5662, -0.9666, -0.2027,  0.7401
-0.2378,  0.3187, -0.8582, -0.6948, -0.9668, -0.7724,  0.7670
-0.3579,  0.1158,  0.9869,  0.6690,  0.3992,  0.8365,  0.4184
-0.9205, -0.8593, -0.0520, -0.3017,  0.8745, -0.0209,  0.7723
-0.1067,  0.7541, -0.4928, -0.4524, -0.3433,  0.0951,  0.4645
-0.5597,  0.3429, -0.7144, -0.8118,  0.7404, -0.5263,  0.6117
 0.0516, -0.8480,  0.7483,  0.9023,  0.6250, -0.4324,  0.5987
 0.0557, -0.3212,  0.1093,  0.9488, -0.3766,  0.3376,  0.5739
-0.3484,  0.7797,  0.5034,  0.5253, -0.0610, -0.5785,  0.5365
-0.9170, -0.3563, -0.9258,  0.3877,  0.3407, -0.1391,  0.7131
-0.9203, -0.7304, -0.6132, -0.3287, -0.8954,  0.2102,  0.9329
 0.0241,  0.2349, -0.1353,  0.6954, -0.0919, -0.9692,  0.5744
 0.6460,  0.9036, -0.8982, -0.5299, -0.8733, -0.1567,  0.4425
 0.7277, -0.8368, -0.0538, -0.7489,  0.5458,  0.6828,  0.5848
-0.5212,  0.9049,  0.8878,  0.2279,  0.9470, -0.3103,  0.4255
 0.7957, -0.1308, -0.5284,  0.8817,  0.3684, -0.8702,  0.3969
 0.2099,  0.4647, -0.4931,  0.2010,  0.6292, -0.8918,  0.4620
-0.7390,  0.6849,  0.2367,  0.0626, -0.5034, -0.4098,  0.7137
-0.8711,  0.7940, -0.5932,  0.6525,  0.7635, -0.0265,  0.5705
 0.1969,  0.0545,  0.2496,  0.7101, -0.4357,  0.7675,  0.4242
-0.5460,  0.1920, -0.5211, -0.7372, -0.6763,  0.6897,  0.6769
 0.2044,  0.9271, -0.3086,  0.1913,  0.1980,  0.2314,  0.2998
-0.6149,  0.5059, -0.9854, -0.3435,  0.8352,  0.1767,  0.4497
 0.7104,  0.2093,  0.6452,  0.7590, -0.3580, -0.7541,  0.4076
-0.7465,  0.1796, -0.9279, -0.5996,  0.5766, -0.9758,  0.7713
-0.3933, -0.9572,  0.9950,  0.1641, -0.4132,  0.8579,  0.7421
 0.1757, -0.4717, -0.3894, -0.2567, -0.5111,  0.1691,  0.7088
 0.3917, -0.8561,  0.9422,  0.5061,  0.6123,  0.5033,  0.4824
-0.1087,  0.3449, -0.1025,  0.4086,  0.3633,  0.3943,  0.3760
 0.2372, -0.6980,  0.5216,  0.5621,  0.8082, -0.5325,  0.5297
-0.3589,  0.6310,  0.2271,  0.5200, -0.1447, -0.8011,  0.5903
-0.7699, -0.2532, -0.6123,  0.6415,  0.1993,  0.3777,  0.6039
-0.5298, -0.0768, -0.6028, -0.9490,  0.4588,  0.4498,  0.6159
-0.3392,  0.6870, -0.1431,  0.7294,  0.3141,  0.1621,  0.4501
 0.7889, -0.3900,  0.7419,  0.8175, -0.3403,  0.3661,  0.4087
 0.7984, -0.8486,  0.7572, -0.6183,  0.6995,  0.3342,  0.5025
 0.2707,  0.6956,  0.6437,  0.2565,  0.9126,  0.1798,  0.2331
-0.6043, -0.1413, -0.3265,  0.9839, -0.2395,  0.9854,  0.5444
-0.8509, -0.2594, -0.7532,  0.2690, -0.1722,  0.9818,  0.6516
 0.8599, -0.7015, -0.2102, -0.0768,  0.1219,  0.5607,  0.4747
-0.4760,  0.8216, -0.9555,  0.6422, -0.6231,  0.3715,  0.5485
-0.2896,  0.9484, -0.7545, -0.6249,  0.7789,  0.1668,  0.3415
-0.5931,  0.7926,  0.7462,  0.4006, -0.0590,  0.6543,  0.4781
-0.0083, -0.2730, -0.4488,  0.8495, -0.2260, -0.0142,  0.5854
-0.2335, -0.4049,  0.4352, -0.6183, -0.7636,  0.6740,  0.7596
 0.4883,  0.1810, -0.5142,  0.2465,  0.2767, -0.3449,  0.3995
-0.4922,  0.1828, -0.1424, -0.2358, -0.7466, -0.5115,  0.7968
-0.8413, -0.3943,  0.4834,  0.2300,  0.3448, -0.9832,  0.7989
-0.5382, -0.6502, -0.6300,  0.6885,  0.9652,  0.8275,  0.4353
-0.3053,  0.5604,  0.0929,  0.6329, -0.0325,  0.1799,  0.4848
 0.0740, -0.2680,  0.2086,  0.9176, -0.2144, -0.2141,  0.5856
 0.5813,  0.2902, -0.2122,  0.3779, -0.1920, -0.7278,  0.4079
-0.5641,  0.8515,  0.3793,  0.1976,  0.4933,  0.0839,  0.4716
 0.4011,  0.8611,  0.7252, -0.6651, -0.4737, -0.8568,  0.5708
-0.5785,  0.0056, -0.7901, -0.2223,  0.0760, -0.3216,  0.7252
 0.1118,  0.0735, -0.2188,  0.3925,  0.3570,  0.3746,  0.3688
 0.2262,  0.8715,  0.1938,  0.9592, -0.1180,  0.4792,  0.2952
-0.9248,  0.5295,  0.0366, -0.9894, -0.4456,  0.0697,  0.7335
 0.2992,  0.8629, -0.8505, -0.4464,  0.8385,  0.5300,  0.2702
 0.1995,  0.6659,  0.7921,  0.9454,  0.9970, -0.7207,  0.2996
-0.3066, -0.2927, -0.4923,  0.8220,  0.4513, -0.9481,  0.6617
-0.0770, -0.4374, -0.9421,  0.7694,  0.5420, -0.3405,  0.5131
-0.3842,  0.8562,  0.9538,  0.0471,  0.9039,  0.7760,  0.3215
 0.0361, -0.2545,  0.4207, -0.0887,  0.2104,  0.9808,  0.5202
-0.8220, -0.6302,  0.0537, -0.1658,  0.6013,  0.8664,  0.6598
-0.6443,  0.7201,  0.9148,  0.9189, -0.9243, -0.8848,  0.6095
-0.2880,  0.9074, -0.0461, -0.4435,  0.0060,  0.2867,  0.4025
-0.7775,  0.5161,  0.7039,  0.6885,  0.7810, -0.2363,  0.5234
-0.5484,  0.9426, -0.4308,  0.8148,  0.7811,  0.8450,  0.3479

Test data:

# synthetic_test.txt
#
-0.6877,  0.7594,  0.2640, -0.5787, -0.3098, -0.6802,  0.7071
-0.6694, -0.6056,  0.3821,  0.1476,  0.7466, -0.5107,  0.7282
 0.2592, -0.9311,  0.0324,  0.7265,  0.9683, -0.9803,  0.5832
-0.9049, -0.9797, -0.0196, -0.9090, -0.4433,  0.2799,  0.9018
-0.4106, -0.4607,  0.1811, -0.2389,  0.4050, -0.0078,  0.6916
-0.4259, -0.7336,  0.8742,  0.6097,  0.8761, -0.6292,  0.6728
 0.8663,  0.8715, -0.4329, -0.4507,  0.1029, -0.6294,  0.2936
 0.8948, -0.0124,  0.9278,  0.2899, -0.0314,  0.9354,  0.3160
-0.7136,  0.2647,  0.3238, -0.1323, -0.8813, -0.0146,  0.8133
-0.4867, -0.2171, -0.5197,  0.3729,  0.9798, -0.6451,  0.5820
 0.6429, -0.5380, -0.8840, -0.7224,  0.8703,  0.7771,  0.5777
 0.6999, -0.1307, -0.0639,  0.2597, -0.6839, -0.9704,  0.5796
-0.4690, -0.9691,  0.3490,  0.1029, -0.3567,  0.5604,  0.8151
-0.4154, -0.6081, -0.8241,  0.7400, -0.8236,  0.3674,  0.7881
-0.7592, -0.9786,  0.1145,  0.8142,  0.7209, -0.3231,  0.6968
 0.3393,  0.6156,  0.7950, -0.0923,  0.1157,  0.0123,  0.3229
 0.3840,  0.3658,  0.0406,  0.6569,  0.0116,  0.6497,  0.2879
 0.9397,  0.4839, -0.4804,  0.1625,  0.9105, -0.8385,  0.2410
-0.8329,  0.2383, -0.5510,  0.5304,  0.1363,  0.3324,  0.5862
-0.8255, -0.2579,  0.3443, -0.6208,  0.7915,  0.8997,  0.6109
 0.9231,  0.4602, -0.1874,  0.4875, -0.4240, -0.3712,  0.3165
 0.7573, -0.4908,  0.5324,  0.8820, -0.9979, -0.0478,  0.6093
 0.3141,  0.6866, -0.6325,  0.7123, -0.2713,  0.7845,  0.3050
-0.1647, -0.6616,  0.2998, -0.9260, -0.3768, -0.3530,  0.8315
 0.2149,  0.3017,  0.6921,  0.8552,  0.3209,  0.1563,  0.3157
-0.6918,  0.7902, -0.3780,  0.0970,  0.3641, -0.5271,  0.6323
-0.6645,  0.0170,  0.5837,  0.3848, -0.7621,  0.8015,  0.7440
 0.1069, -0.8304, -0.5951,  0.7085,  0.4119,  0.7899,  0.4998
-0.3417,  0.0560,  0.3008,  0.1886, -0.5371, -0.1464,  0.7339
 0.9734, -0.8669,  0.4279, -0.3398,  0.2509, -0.4837,  0.4665
 0.3020, -0.2577, -0.4104,  0.8235,  0.8850,  0.2271,  0.3066
-0.5766,  0.6603, -0.5198,  0.2632,  0.4215,  0.4848,  0.4478
-0.2195,  0.5197,  0.8059,  0.1748, -0.8192, -0.7420,  0.6740
-0.9212, -0.5169,  0.7581,  0.9470,  0.2108,  0.9525,  0.6180
-0.9131,  0.8971, -0.3774,  0.5979,  0.6213,  0.7200,  0.4642
-0.4842,  0.8689,  0.2382,  0.9709, -0.9347,  0.4503,  0.5662
 0.1311, -0.0152, -0.4816, -0.3463, -0.5011, -0.5615,  0.6979
-0.8336,  0.5540,  0.0673,  0.4788,  0.0308, -0.2001,  0.6917
 0.9725, -0.9435,  0.8655,  0.8617, -0.2182, -0.5711,  0.6021
 0.6064, -0.4921, -0.4184,  0.8318,  0.8058,  0.0708,  0.3221