The PyTorch LSTM Module Input Shape Tricked Me Again

Grrrr. I’ve been wrestling with PyTorch LSTM (long short-term memory) code for many weeks now. Some time ago I thought I had figured out how to deal with the shapes of the input and output tensors. But I discovered I had been wrong.

Here’s an example of how crazy tricky these things are. I imagined a scenario where I have a bunch of sentences, such as “I like this” and “Not a good movie”. Each word will be converted to a vector embedding, such as “movie” = (0.12, 0.23, 0.34). And sentences would be batched together in groups so that the network can train more easily.

So I set up dummy input like:

import numpy as np
import torch as T  # the code in this post uses T as the alias for torch

inpt = np.array(
 [[[0.01, 0.02, 0.03, 0.04, 0.05],   
   [0.06, 0.07, 0.08, 0.09, 0.10],  
   [0.11, 0.12, 0.13, 0.14, 0.15],  
   [0.16, 0.17, 0.18, 0.19, 0.20]], 

  [[0.21, 0.22, 0.23, 0.24, 0.25], 
   [0.26, 0.27, 0.28, 0.29, 0.30],
   [0.31, 0.32, 0.33, 0.34, 0.35],
   [0.36, 0.37, 0.38, 0.39, 0.40]],

  [[0.41, 0.42, 0.43, 0.44, 0.45], 
   [0.46, 0.47, 0.48, 0.49, 0.50],
   [0.51, 0.52, 0.53, 0.54, 0.55],
   [0.56, 0.57, 0.58, 0.59, 0.60]]], dtype=np.float32)

inptT = T.tensor(inpt, dtype=T.float32)  

The shape is (3,4,5) and intuitively, this should represent three sentences where each sentence has four words and each word has five values. In other words, I thought that the first sentence was:

  [[[0.01, 0.02, 0.03, 0.04, 0.05],  # 1st word 
    [0.06, 0.07, 0.08, 0.09, 0.10],  # 2nd word 
    [0.11, 0.12, 0.13, 0.14, 0.15],  # 3rd word  
    [0.16, 0.17, 0.18, 0.19, 0.20]], # 4th word

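Wrong, wrong, wrong. The layout is completely unintuitive. By default (batch_first=False), the LSTM module reads the first dimension as the position of a word within its sentence, not as the sentence, so the data is actually interpreted as: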

inpt = np.array(
 [[[0.01, 0.02, 0.03, 0.04, 0.05],  # 1st word of first sentence 
   [0.06, 0.07, 0.08, 0.09, 0.10],  # 1st word of second sentence
   [0.11, 0.12, 0.13, 0.14, 0.15],  # 1st word of third sentence
   [0.16, 0.17, 0.18, 0.19, 0.20]], # 1st word of fourth sentence

  [[0.21, 0.22, 0.23, 0.24, 0.25],  # 2nd word of first sentence
   [0.26, 0.27, 0.28, 0.29, 0.30],
   [0.31, 0.32, 0.33, 0.34, 0.35],
   [0.36, 0.37, 0.38, 0.39, 0.40]],

  [[0.41, 0.42, 0.43, 0.44, 0.45],  # 3rd word of first sentence
   [0.46, 0.47, 0.48, 0.49, 0.50],
   [0.51, 0.52, 0.53, 0.54, 0.55],
   [0.56, 0.57, 0.58, 0.59, 0.60]]], dtype=np.float32)  # end batch

So a PyTorch LSTM input shape of (3,4,5) means each sentence has 3 words, there are 4 sentences in a batch, and each word is represented by 5 numeric values. Argh!
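
To double-check the layout, here is a minimal sketch that feeds the same (3,4,5) data to a T.nn.LSTM and looks at the shapes that come back. The hidden size of 2 and the variable names are my own arbitrary choices for illustration. The sketch also assumes the batch_first=True constructor argument, which tells the module to expect the more intuitive (sentences, words, values) ordering instead.

```python
import torch as T

T.manual_seed(0)  # reproducible random LSTM weights

seq_len, batch_size, embed_dim = 3, 4, 5
hidden_dim = 2  # arbitrary, chosen only for this sketch

# Same values as the dummy input: 0.01, 0.02, ..., 0.60
inpt = (T.arange(1, 61, dtype=T.float32) / 100).reshape(
    seq_len, batch_size, embed_dim)

# Default layout is (seq_len, batch, input_size)
lstm = T.nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim)
output, (h_n, c_n) = lstm(inpt)
print(output.shape)  # torch.Size([3, 4, 2]): one output per word per sentence
print(h_n.shape)     # torch.Size([1, 4, 2]): final hidden state per sentence

# The first sentence is spread across the first dimension
first_sentence = inpt[:, 0, :]  # shape (3, 5): its 1st, 2nd, 3rd words
print(first_sentence)

# Same weights, but with the (batch, seq_len, input_size) layout
lstm_bf = T.nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim,
                    batch_first=True)
lstm_bf.load_state_dict(lstm.state_dict())
inpt_bf = inpt.transpose(0, 1).contiguous()  # shape (4, 3, 5)
output_bf, _ = lstm_bf(inpt_bf)
print(output_bf.shape)  # torch.Size([4, 3, 2])

# Both layouts should give the same numbers, just arranged differently
same = T.allclose(output, output_bf.transpose(0, 1), atol=1e-5)
print(same)
```

Slicing with inpt[:, 0, :] is the quick sanity check here: the rows that come back start at 0.01, 0.21 and 0.41, which are the three words of the first sentence under the default layout.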

One of the things that tricked me was the special case where a batch contains only a single sentence. For example, an input with shape (3,1,5) such as:

  [[[0.01, 0.02, 0.03, 0.04, 0.05]],
   [[0.06, 0.07, 0.08, 0.09, 0.10]],
   [[0.11, 0.12, 0.13, 0.14, 0.15]]]

Because the batch dimension has size 1, each outer row holds exactly one word, so the single-sentence-per-batch scenario does have an intuitive representation:

  [[[0.01, 0.02, 0.03, 0.04, 0.05]],  # 1st word of only sentence
   [[0.06, 0.07, 0.08, 0.09, 0.10]],  # 2nd word
   [[0.11, 0.12, 0.13, 0.14, 0.15]]]  # 3rd word
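
A rough sketch of the single-sentence case, assuming you start from a plain (3,5) tensor of word vectors: unsqueeze(1) inserts the batch dimension of size 1 in the middle. The hidden size of 2 and the variable names are made up for this example.

```python
import torch as T

# One sentence: 3 words, 5 values per word, shape (3, 5)
sentence = T.tensor([[0.01, 0.02, 0.03, 0.04, 0.05],
                     [0.06, 0.07, 0.08, 0.09, 0.10],
                     [0.11, 0.12, 0.13, 0.14, 0.15]], dtype=T.float32)

# Insert a batch dimension of size 1 in the middle: (3, 5) -> (3, 1, 5)
batch_of_one = sentence.unsqueeze(1)
print(batch_of_one.shape)  # torch.Size([3, 1, 5])

lstm = T.nn.LSTM(input_size=5, hidden_size=2)  # hidden_size=2 is arbitrary
output, (h_n, c_n) = lstm(batch_of_one)
print(output.shape)  # torch.Size([3, 1, 2]): one output per word
```

With only one sentence, reading the outer rows top to bottom gives the words in order, which is why this case hides the layout problem.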

Early on I focused on the simple example and didn’t keep both eyes open to other possibilities. Well, this little exercise cost me hours and hours and hours of time. But this is the price you pay when dealing with new, immature technologies that have skimpy documentation.

