PyTorch IMDB Example Using LSTM Batch-First Geometry

This post assumes you know what the IMDB movie review problem is, and what LSTMs are.


Demo run using the batch-first approach.

If you implement a standard PyTorch Dataset object for IMDB movie reviews, data will be served up by the associated DataLoader in a “batch first” geometry. Weirdly, the default geometry for an LSTM layer/module is “sequence first”, not batch first. To use the easier-to-understand batch-first approach, you 1.) set batch_first=True in the LSTM definition, 2.) serve up batches of training data without any changes, and 3.) fetch the last-time-step output as lstm_out[:, -1] rather than lstm_out[-1]. Here are some side-by-side code fragments to illustrate.

When defining the LSTM layer in the overall neural network:

self.lstm = T.nn.LSTM(32, 100)  # sequence-first

self.lstm = T.nn.LSTM(32, 100, 
  batch_first=True)             # batch-first

When grabbing the output in the forward() method in the neural network:

z = lstm_oupt[-1]     # sequence-first

z = lstm_oupt[:, -1]  # batch-first

When serving up data and sending it to the network during training:

# sequence-first
for (batch_idx, batch) in enumerate(train_ldr):
  X = T.transpose(batch[0], 0, 1)  # seq first
  Y = batch[1]
  oupt = net(X)

# batch-first:
for (batch_idx, batch) in enumerate(train_ldr):
  X = batch[0]
  Y = batch[1]
  oupt = net(X)

In most cases when there are two ways to implement code, I’ll have a preference for one way or the other. But for LSTM geometry, the pros and cons of batch-first vs. sequence-first are mostly a matter of style, and the two approaches balance out, for me at least.



I usually prefer certain types of coding style. But for music, I enjoy almost every style I can think of. Here are three examples of musical style that are all wildly different, but styles that I like and have a personal connection to.

Left: The group “Gentle Soul” featured Pamela Polland and Rick Stanley. Beautiful harmonies. I sent a cold-email to Pamela and she graciously replied. She seems like a super nice person in addition to being an incredible songwriter and singer.

Center: Herb Alpert and the Tijuana Brass was an international sensation in the 1960s. This album cover is one of the most famous in history. The model is Dolores Erickson. She grew up in, and still lives in, the Pacific Northwest where I live too. I met her at a party several years ago. She was very courteous and friendly.

Right: The “Strawberry Alarm Clock” band had a gigantic psychedelic hit “Incense and Peppermints” in 1967. My guitar instructor in Fullerton, California, was Larry Samson who knew members of the band and (I think) sometimes played with them. He was a great guy as well as an incredible guitarist. He taught me and my pal Paul Ruiz to play “Pipeline” as our first song.

Machine learning and coding are important, but person-to-person connections are ultimately more important.


Demo code:

# imdb_lstm_baatch_first.py
# PyTorch 1.10.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11

import numpy as np
import torch as T
device = T.device("cpu")  # the demo targets a CPU-only PyTorch build

# -----------------------------------------------------------

class LSTM_Net(T.nn.Module):
  """Embedding -> batch-first LSTM -> linear layer sentiment classifier.

  Input: int64 token IDs shaped (batch, seq_len). A single review shaped
  (seq_len,) is treated as a batch of one, and any extra leading dims are
  flattened into the batch dimension.
  Output: log-probabilities shaped (batch, n_classes), intended for use
  with T.nn.NLLLoss().

  The default arguments reproduce the original demo network: vocabulary
  of 129892 token IDs, 32-dim embeddings, 100 LSTM hidden units and
  2 output classes (0=neg, 1=pos).
  """

  def __init__(self, vocab_size=129892, embed_dim=32, hidden_dim=100,
      n_classes=2):
    super(LSTM_Net, self).__init__()
    self.embed = T.nn.Embedding(vocab_size, embed_dim)
    # self.lstm = T.nn.LSTM(embed_dim, hidden_dim)  # sequence-first
    # batch_first=True: the LSTM consumes/returns (batch, seq, feature)
    self.lstm = T.nn.LSTM(embed_dim, hidden_dim, batch_first=True)
    self.fc1 = T.nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    # x = review(s) of fixed length (padded); last dim is the sequence
    z = self.embed(x)  # (..., seq) -> (..., seq, embed_dim)
    # Collapse all leading dims into the batch dim. Sizes are taken from
    # the tensor itself instead of a hard-coded (-1, 50, 32), so reviews
    # of any padded length are batched correctly.
    z = z.reshape(-1, z.shape[-2], z.shape[-1])  # (batch, seq, embed)
    lstm_oupt, (h_n, c_n) = self.lstm(z)  # (batch, seq, hidden_dim)
    # z = lstm_oupt[-1]   # for seq first
    z = lstm_oupt[:, -1]  # last time step of each review (batch first)
    z = T.log_softmax(self.fc1(z), dim=1)  # log-probs for NLLLoss()
    return z

# -----------------------------------------------------------

class IMDB_Dataset(T.utils.data.Dataset):
  """IMDB reviews stored as space-delimited rows of integers.

  Each row holds `n_tokens` token IDs (the demo files use 50) followed by
  one 0/1 sentiment label. Lines beginning with '#' are ignored.
  Items are (tokens, label) tensors: tokens shaped (n_tokens,), label a
  scalar, both int64.
  """

  def __init__(self, src_file, n_tokens=50):
    # ndmin=2 keeps a single-row file 2-D so the column slicing below
    # works; without it np.loadtxt returns a 1-D array for one row.
    all_xy = np.loadtxt(src_file, usecols=range(0, n_tokens + 1),
      delimiter=" ", comments="#", dtype=np.int64, ndmin=2)
    tmp_x = all_xy[:, 0:n_tokens]  # cols [0, n_tokens) = token IDs
    tmp_y = all_xy[:, n_tokens]    # last col = label
    self.x_data = T.tensor(tmp_x, dtype=T.int64)
    self.y_data = T.tensor(tmp_y, dtype=T.int64)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    tokens = self.x_data[idx]
    trgts = self.y_data[idx]
    return (tokens, trgts)

# -----------------------------------------------------------

def accuracy(model, dataset, batch_size=64):
  """Return the percentage (0-100) of `dataset` items `model` gets right.

  `dataset` yields (inputs, target_label) pairs. `model` maps a batch of
  inputs to per-class scores shaped (batch, n_classes); log-probs in this
  demo. The predicted class is the argmax of those scores.
  The caller is responsible for putting the model in eval mode first.
  Items are scored `batch_size` at a time rather than one per forward
  pass. Returns 0.0 for an empty dataset instead of dividing by zero.
  """
  ldr = T.utils.data.DataLoader(dataset,
    batch_size=batch_size, shuffle=False)
  num_correct = 0
  num_total = 0
  with T.no_grad():
    for (X, Y) in ldr:  # X = inputs, Y = target sentiment labels
      preds = T.argmax(model(X), dim=1)  # predicted class per item
      num_correct += (preds == Y).sum().item()
      num_total += len(Y)
  if num_total == 0:
    return 0.0
  return (num_correct * 100.0) / num_total

# -----------------------------------------------------------

def main():
  """Train the batch-first LSTM on IMDB reviews, then evaluate, save, demo.

  Reads Data/imdb_train_50w.txt and Data/imdb_test_50w.txt, relative to
  the working directory. Writes the trained state_dict to
  Models/imdb_model.pt, creating the Models folder if it is missing.
  """
  import os  # portable paths (the original ".\\" paths were Windows-only)

  # 0. get started
  print("\nBegin PyTorch IMDB LSTM demo ")
  print("Using only reviews with 50 or less words ")
  T.manual_seed(1)
  np.random.seed(1)

  # 1. load data
  print("\nLoading preprocessed train and test data ")
  train_file = os.path.join("Data", "imdb_train_50w.txt")
  train_ds = IMDB_Dataset(train_file)

  test_file = os.path.join("Data", "imdb_test_50w.txt")
  test_ds = IMDB_Dataset(test_file)

  bat_size = 8
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True, drop_last=True)
  n_train = len(train_ds)
  n_test = len(test_ds)
  print("Num train = %d Num test = %d " % (n_train, n_test))

# -----------------------------------------------------------

  # 2. create network
  net = LSTM_Net().to(device)

  # 3. train model
  loss_func = T.nn.NLLLoss()  # pairs with the net's log_softmax() output
  lrn_rate = 1.0e-3
  optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate)
  max_epochs = 20
  log_interval = 2  # display progress every log_interval epochs

  print("\nbatch size = " + str(bat_size))
  print("loss func = " + str(loss_func))
  print("optimizer = Adam ")
  print("learn rate = %0.3f " % lrn_rate)
  print("max_epochs = %d " % max_epochs)

  print("\nStarting training ")
  net.train()  # set training mode
  for epoch in range(0, max_epochs):
    tot_err = 0.0  # sum of batch losses for one epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      # batch first: the loader's (batch, seq) layout is used as-is,
      # no transpose needed (seq first would use T.transpose(X, 0, 1))
      X = batch[0].to(device)
      Y = batch[1].to(device)
      optimizer.zero_grad()
      oupt = net(X)

      loss_val = loss_func(oupt, Y)
      tot_err += loss_val.item()
      loss_val.backward()  # compute gradients
      optimizer.step()     # update weights

    if epoch % log_interval == 0:
      print("epoch = %4d  |" % epoch, end="")
      print("   loss = %10.4f  |" % tot_err, end="")

      net.eval()
      train_acc = accuracy(net, train_ds)
      print("  accuracy = %8.2f%%" % train_acc)
      net.train()
  print("Training complete")

# -----------------------------------------------------------

  # 4. evaluate model
  net.eval()
  test_acc = accuracy(net, test_ds)
  print("\nAccuracy on test data = %8.2f%%" % test_acc)

  # 5. save model
  print("\nSaving trained model state")
  model_dir = "Models"
  os.makedirs(model_dir, exist_ok=True)  # T.save won't create folders
  fn = os.path.join(model_dir, "imdb_model.pt")
  T.save(net.state_dict(), fn)

  # saved_model = LSTM_Net()
  # saved_model.load_state_dict(T.load(fn))
  # use saved_model to make prediction(s)

  # 6. use model
  print("\nSentiment for \"the movie was a great waste of my time\"")
  print("0 = negative, 1 = positive ")
  review = np.array([4, 20, 16, 6, 86, 425, 7, 58, 64],
    dtype=np.int64)
  padding = np.zeros(41, dtype=np.int64)  # 41 + 9 = 50 tokens
  review = np.concatenate([padding, review])  # zeros go in front
  # explicit leading batch dim: shape (1, 50), i.e. a batch of one review
  review = T.tensor(review, dtype=T.int64).unsqueeze(0).to(device)

  net.eval()
  with T.no_grad():
    prediction = net(review)  # log-probs
  print("raw output : ", end=""); print(prediction)
  print("pseud-probs: ", end=""); print(T.exp(prediction))

  print("\nEnd PyTorch IMDB LSTM sentiment demo")

if __name__ == "__main__":
  main()
This entry was posted in PyTorch. Bookmark the permalink.