A PyTorch torch.utils.data Dataset Class Wrapper Over the torchtext.datasets.IMDB Class

The problem addressed by this blog post requires a somewhat lengthy explanation, so if you’re reading this, bear with me. The PyTorch torch.utils.data module has a Dataset class and a DataLoader class. Using these two classes is the de facto standard way to read data and serve it up in batches for tabular data problems, such as predicting the species of an iris flower from sepal length and width, and petal length and width. The torch.utils.data module is very nicely implemented.
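For readers who haven't used the pattern, here is a minimal sketch of a tabular Dataset and DataLoader along the lines of the iris example. The class name TabularDataset and the five data rows are made up for illustration; the values are placeholders, not real iris measurements. What matters is the division of labor: the Dataset supplies __len__ and __getitem__, and the DataLoader does the batching.

```python
# Minimal sketch of the torch.utils.data pattern for tabular data.
# TabularDataset and the rows below are hypothetical placeholders,
# shaped like iris data (4 measurements, 3 species).
import torch as T

class TabularDataset(T.utils.data.Dataset):
  def __init__(self, rows, labels):
    # hold all data in memory as tensors
    self.x = T.tensor(rows, dtype=T.float32)
    self.y = T.tensor(labels, dtype=T.int64)

  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    # one (predictors, label) pair
    return self.x[idx], self.y[idx]

rows = [[5.1, 3.5, 1.4, 0.2], [6.4, 3.2, 4.5, 1.5],
        [6.3, 3.3, 6.0, 2.5], [5.0, 3.6, 1.4, 0.2],
        [5.7, 2.8, 4.1, 1.3]]
labels = [0, 1, 2, 0, 1]  # species encoded as 0, 1, 2

ds = TabularDataset(rows, labels)
ldr = T.utils.data.DataLoader(ds, batch_size=2, shuffle=False)
for (bx, by) in ldr:
  print(bx.shape, by)
```

With five rows and batch_size=2, the loader serves batches of 2, 2, and 1 items, because the default drop_last=False keeps the final partial batch.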

Because PyTorch is open source, its growth and organization have been, and continue to be, chaotic. In particular, there is a companion torchvision library (for image processing), and a companion torchtext library (for natural language processing). These two companion libraries have their own Dataset classes for images and text sources — and none of these Dataset classes are compatible with each other.

The torchtext library is especially weak, and almost everything in it has been deprecated since November 2020. The new torchtext library will have a Dataset class that is largely compatible with the torch.utils.data Dataset, but there’s no telling when the new torchtext will be released — it’s already taken months and it could take months longer.

So, if, like me, you want to work with text problems, in particular with the new Transformer architecture, you face a significant problem. Suppose you want to work with the IMDB movie review sentiment analysis data.

You can use the old, deprecated torchtext built-in IMDB Dataset, but when the new torchtext Dataset is finally released, you’ll likely have to make major modifications to all your code.

Another option is to just wait until the new torchtext Dataset is released. It will have a built-in IMDB Dataset. But this is unappealing. Most of my colleagues and I have a strong passion for learning new things, and waiting is painful.

A third option is to implement a torch.utils.data.Dataset for IMDB data from scratch. But this is a lot of work, really a lot of work: at least a week of dev time.

A fourth option, and the one that’s the topic of this blog post, is to write a custom torch.utils.data Dataset that is a wrapper around the old, deprecated torchtext.datasets.IMDB class. The idea is to use the functionality of the deprecated code behind an interface that (hopefully) matches the upcoming new interface, so that, in theory, you can later drop the new torchtext code into your program with few changes.
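The wrapper idea can be sketched in plain Python. LegacySource and LegacyExample are hypothetical stand-ins for an old-style dataset object that holds a list of examples with .text and .label fields (loosely modeled on the deprecated torchtext examples); they are not the real torchtext classes. The wrapper exposes only __len__ and __getitem__, the map-style protocol that DataLoader relies on, so code that consumes it never touches the legacy object directly. In a real program the wrapper would also subclass torch.utils.data.Dataset, as the code at the end of this post does.

```python
# Sketch of the wrapper idea. LegacyExample and LegacySource are
# hypothetical stand-ins, NOT the real torchtext classes.
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class LegacyExample:
  text: List[str]  # tokenized review
  label: str       # "neg" or "pos"

class LegacySource:
  def __init__(self, examples: List[LegacyExample]):
    self.examples = examples

class WrappedDataset:
  # exposes only __len__ and __getitem__ (the map-style protocol),
  # so callers never depend on the legacy object's own API
  def __init__(self, source: LegacySource, vocab: dict):
    self.source = source
    self.vocab = vocab
    self.label_map = {"neg": 0, "pos": 1}

  def __len__(self) -> int:
    return len(self.source.examples)

  def __getitem__(self, idx: int) -> Tuple[List[int], int]:
    ex = self.source.examples[idx]
    unk = self.vocab["<unk>"]
    # map tokens to ids; unknown tokens map to the <unk> id
    ids = [self.vocab.get(tok, unk) for tok in ex.text]
    return ids, self.label_map[ex.label]

vocab = {"<unk>": 0, "great": 1, "movie": 2, "boring": 3}
src = LegacySource([LegacyExample(["great", "movie"], "pos"),
                    LegacyExample(["boring", "plot"], "neg")])
ds = WrappedDataset(src, vocab)
print(len(ds), ds[1])
```

If a new backend arrives, only the body of WrappedDataset has to change; a training loop that calls len(ds) and ds[idx] stays the same. That is the whole bet of the wrapper approach.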

So, I set out to experiment with this fourth option.

The bottom line is that the wrapper idea is viable, but it was much more difficult to implement than I thought it’d be. The main problem is that library APIs have a huge number of parameters because they have to accommodate every possible programming scenario. Wrapping a complex API with another API is just really, really complicated.
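One partial mitigation for the parameter explosion is to pin only the settings a program actually cares about and forward everything else with **kwargs, instead of re-declaring every parameter in the wrapper. In the hypothetical sketch below, legacy_load stands in for a library call with many parameters; it is not a real torchtext function.

```python
# Hypothetical: legacy_load stands in for a many-parameter library
# call. It is not a real torchtext function.
def legacy_load(split, lower=False, max_size=None, min_freq=1,
                fix_length=None, tokenize="split"):
  # pretend to load data; just report the settings it received
  return {"split": split, "lower": lower, "max_size": max_size,
          "min_freq": min_freq, "fix_length": fix_length,
          "tokenize": tokenize}

class Wrapper:
  # pin the few settings this program cares about and forward
  # everything else unchanged, rather than re-declaring each one
  def __init__(self, split, **legacy_kwargs):
    settings = {"lower": True, "max_size": 50_000}
    settings.update(legacy_kwargs)  # caller-supplied values win
    self.data = legacy_load(split, **settings)

w = Wrapper("train", fix_length=500)
print(w.data)
```

The trade-off: callers who pass extra keyword arguments are writing against the old API's parameter names, so that code will still need edits when the new library arrives.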

In the end, I couldn’t decide how best to proceed. I’ll probably use my custom wrapper approach as a stopgap until the new torchtext library is released.



Three dresses made from recycled wrappers. Left: Dress made from Cadbury chocolate wrappers. Center: Dress made from generic brown paper wrappers. Right: Dress made from 10,000 Starburst candy wrappers.


# experiment.py

# The torchtext datasets have been deprecated for months.
# The new torchtext will be compatible with torch.utils.data
# Dataset and DataLoader. Exploring the idea of a wrapper
# torch.utils.data.Dataset class over the deprecated
# torchtext IMDB class.

# Python 3.7.6  PyTorch 1.7.0
# Windows 10  CPU

import torch as T
import torchtext as tt

# -----------------------------------------------------------

class MyIMDB_Dataset(T.utils.data.Dataset):
  # a wrapper over the deprecated torchtext IMDB class
  # for use until new torchtext IMDB is released

  def __init__(self, train_or_test):
    # get IMDB data into memory
    # build vocab from train

    self.train_or_test = train_or_test

    TEXT = tt.data.Field(lower=True, include_lengths=True,
      batch_first=True, tokenize="basic_english")
    LABEL = tt.data.Field(sequential=False)

    # fetch all data and create vocabulary
    train_ds, test_ds = tt.datasets.IMDB.splits(TEXT, LABEL)
    # print(train_ds.examples[0].text)  # words

    TEXT.build_vocab(train_ds, max_size=50_000) 
    LABEL.build_vocab(train_ds)  # unk neg, pos
    if train_or_test == "train":
      self.vocab = TEXT.vocab
    else:
      self.vocab = None

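    # fetch the requested split as one giant batch so that every
    # review is converted to ids and padded to the same length
    # (the length of the longest review in the split)
    src_ds = train_ds if train_or_test == "train" else test_ds
    itr = tt.data.Iterator(src_ds, batch_size=len(src_ds),
      shuffle=False)
    batch = next(iter(itr))  # the one and only batch

    self.all_x = batch.text[0]  # [0] = padded ids, [1] = lengths
    self.all_y = batch.label - 1  # labels 1=neg, 2=pos become 0, 1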

  def __len__(self):
    return len(self.all_x)

  def __getitem__(self, idx):
    # get one (review, 0/1 label)
    # review as ids, not strings
    review = self.all_x[idx, :]
    lbl = self.all_y[idx]
    return (review, lbl)

# -----------------------------------------------------------
 
def main():

  print("\nIMDB utils.data Dataset wrapper over tt Dataset")
  train_ds = MyIMDB_Dataset("train")
  test_ds = MyIMDB_Dataset("test")

  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=10, shuffle=False)
  for idx, batch in enumerate(train_ldr):
    if idx == 2:
      break
    print(batch[0])          # padded reviews
    print(batch[0].shape)
    print(batch[1])          # batch of 0/1 labels
    print(batch[1].shape)

  print("\nEnd experiment ")

if __name__ == "__main__":
  main()
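
Because the wrapper fetches each split as one giant batch, every review is padded to the length of the longest review in the split, which makes all_x large. One alternative, sketched here under assumptions, is to keep reviews at their natural lengths and pad each DataLoader batch separately with a collate_fn and torch.nn.utils.rnn.pad_sequence. RaggedDataset, collate, and the token-id lists are made up for illustration, and PAD_ID = 1 assumes the <pad> token sits at index 1, which is where the legacy torchtext Field vocabulary usually puts it.

```python
# Sketch: per-batch padding instead of padding a whole split up front.
# RaggedDataset, collate, and the id lists are hypothetical examples.
import torch as T
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 1  # assumption: id of the <pad> token in the vocabulary

class RaggedDataset(T.utils.data.Dataset):
  def __init__(self, reviews, labels):
    # each review keeps its own length
    self.reviews = [T.tensor(r, dtype=T.int64) for r in reviews]
    self.labels = T.tensor(labels, dtype=T.int64)

  def __len__(self):
    return len(self.reviews)

  def __getitem__(self, idx):
    return self.reviews[idx], self.labels[idx]

def collate(batch):
  # batch is a list of (review, label) pairs with varying lengths
  reviews, labels = zip(*batch)
  lengths = T.tensor([len(r) for r in reviews])
  # pad only up to the longest review in this batch
  padded = pad_sequence(list(reviews), batch_first=True,
    padding_value=PAD_ID)
  return padded, lengths, T.stack(labels)

ds = RaggedDataset([[5, 8, 2], [7, 4], [9, 3, 6, 2, 4], [11]],
  [1, 0, 1, 0])
ldr = T.utils.data.DataLoader(ds, batch_size=2, shuffle=False,
  collate_fn=collate)
for x, lens, y in ldr:
  print(x.shape, lens, y)
```

Returning the lengths alongside the padded ids mirrors what include_lengths=True gives in the wrapper, so downstream code can still mask or pack the padding.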

This entry was posted in PyTorch.