Understanding the PyTorch TransformerEncoderLayer

The hottest thing in natural language processing is the neural Transformer architecture. A Transformer can be used for sequence-to-sequence tasks such as summarizing a document to an abstract, or translating an English document to German.

I’ve been slowly but surely learning about Transformers. They are easily the most complex software components I’ve encountered. Ever.

I know from experience that gaining expertise in Transformer architecture will take many months. So I’m exploring the architecture one piece at a time, because it’s important to have time between explorations of complex topics.

I spent a few hours looking at the source code on GitHub. The key takeaway is that a Transformer is made of a TransformerEncoder and a TransformerDecoder, and these are made of TransformerEncoderLayer objects and TransformerDecoderLayer objects respectively:

A PyTorch top-level Transformer class contains one
 TransformerEncoder object and one TransformerDecoder object.

A TransformerEncoder class contains one or more
 TransformerEncoderLayer objects (six by default in Transformer).
A TransformerDecoder class contains one or more
 TransformerDecoderLayer objects (six by default in Transformer).

A TransformerEncoderLayer class contains one MultiheadAttention
 object and one ordinary neural network
 (2048 hidden nodes by default).
A TransformerDecoderLayer class contains two MultiheadAttention
 objects and one ordinary neural network
 (2048 hidden nodes by default).
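The hierarchy above can be checked directly by building a Transformer and poking at its submodules. This is a minimal sketch. It assumes the attribute names I see in the current PyTorch source (`encoder`, `decoder`, `layers`, `linear1`). Those are internal details rather than documented API, so they could change between versions. The helper `count_attention` is a hypothetical name I made up for this demo.

```python
import torch as T

# Layer counts and dim_feedforward are left at their defaults;
# d_model is shrunk from 512 to 16 only to keep the model small
model = T.nn.Transformer(d_model=16, nhead=2)

enc_layers = model.encoder.layers  # TransformerEncoderLayer objects
dec_layers = model.decoder.layers  # TransformerDecoderLayer objects
print(len(enc_layers), len(dec_layers))  # 6 6

def count_attention(module):
    # Count the MultiheadAttention objects anywhere inside a module
    return sum(isinstance(m, T.nn.MultiheadAttention) for m in module.modules())

enc0, dec0 = enc_layers[0], dec_layers[0]
print(count_attention(enc0), count_attention(dec0))  # 1 2

# The "ordinary neural network" is linear1 -> activation -> linear2
print(enc0.linear1.out_features, dec0.linear1.out_features)  # 2048 2048
```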

My initial thought was to look at the MultiheadAttention class first. But I quickly realized that it’s at too low a level. The source for the class is nearly 1,000 lines of very complex code. I’ll tackle MultiheadAttention later. For now, I decided to move up one level of abstraction and look at the TransformerEncoderLayer.

I wrote a demo to explore IO for the class. The key lines of code in the demo are:

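import torch as T
device = T.device("cpu")  # or T.device("cuda") if a GPU is available
encoder_layer = T.nn.TransformerEncoderLayer(d_model=4,
  nhead=2).to(device)
src = T.rand(3, 5, 4).to(device)  # random 3d input, shape (3, 5, 4)
oupt = encoder_layer(src)  # output has the same shape, (3, 5, 4)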

The d_model parameter is the word embedding dimension, which is sort of synonymous with the overall model dimension. If I didn’t fully understand word embeddings, I’d have been sunk right there. The nhead parameter is the number of attention heads. I believe that using multiple heads is similar to running an experiment multiple times and then averaging the results to get a better answer. I’m not sure, though; it’s one of zillions of details that I’ll need to master . . . eventually.
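Whatever the best intuition for heads turns out to be, nhead has a concrete effect on the shapes inside the layer. In PyTorch’s MultiheadAttention, d_model must be divisible by nhead, each head works on a d_model / nhead slice of the embedding, and the head outputs are concatenated and passed through one final linear layer. Here is a small sketch that inspects this for the demo’s d_model=4, nhead=2. The `embed_dim`, `num_heads`, `head_dim`, and `out_proj` attribute names come from my reading of the PyTorch source, not from a documented public API.

```python
import torch as T

layer = T.nn.TransformerEncoderLayer(d_model=4, nhead=2)
attn = layer.self_attn  # the single MultiheadAttention inside the layer

# Each of the 2 heads sees a 4 / 2 = 2 dimensional slice of the embedding
print(attn.embed_dim, attn.num_heads, attn.head_dim)  # 4 2 2

# The 2-dim head outputs are concatenated back to 4 dims,
# then mixed by one final 4-to-4 linear projection
print(attn.out_proj.in_features, attn.out_proj.out_features)  # 4 4

# d_model must divide evenly by nhead, so 4 / 3 is rejected at construction time
try:
    T.nn.TransformerEncoderLayer(d_model=4, nhead=3)
    rejected = False
except (AssertionError, ValueError):
    rejected = True
print("nhead=3 rejected:", rejected)  # nhead=3 rejected: True
```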

The demo input is a tensor that has shape (3,5,4). Because multiple TransformerEncoderLayer objects are chained together internally, the meaning of the three dimensions isn’t obvious to me, but I believe the first dimension is the number of words in a sentence, and the second dimension is the number of sentences in a training batch. The third dimension is the embedding dimension again.
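One way to test this belief is with an experiment. If the second dimension really indexes sentences in a batch, then changing one sentence should not affect the output for any other sentence. If the first dimension indexes words, then changing one word should affect the other words of the same sentence, because attention mixes words together. This is a minimal sketch; the eval() call turns off dropout so that repeated calls give identical results.

```python
import torch as T

T.manual_seed(0)
layer = T.nn.TransformerEncoderLayer(d_model=4, nhead=2)
layer.eval()  # disable dropout so outputs are deterministic

src = T.rand(3, 5, 4)
base = layer(src)

# Experiment 1: change every word of sentence 1 (index 1 along dim 1)
src_a = src.clone()
src_a[:, 1, :] += 1.0
out_a = layer(src_a)
# If dim 1 indexes sentences, sentence 0's output is unaffected
print(T.allclose(base[:, 0, :], out_a[:, 0, :]))  # True

# Experiment 2: change word 2 of sentence 0 (index 2 along dim 0)
src_b = src.clone()
src_b[2, 0, :] += 1.0
out_b = layer(src_b)
# If dim 0 indexes words, attention should carry the change over to word 0,
# so this comparison is expected to come out False
print(T.allclose(base[0, 0, :], out_b[0, 0, :]))
```

If both experiments come out as predicted, the layout is (words per sentence, sentences per batch, embedding dim). That matches the documented default of batch_first=False. Constructing the layer with batch_first=True expects the first two dimensions swapped.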

So the demo input has shape (3,5,4), and the output shape is also (3,5,4). Therefore, at this point in my understanding, a TransformerEncoderLayer accepts a 3d tensor and emits a 3d tensor of the same shape. This makes sense because TransformerEncoderLayer objects are chained together inside a TransformerEncoder object, so each layer’s output must be a valid input for the next layer. And because the output of the last TransformerEncoderLayer in a TransformerEncoder is passed to a TransformerDecoder, which has multiple TransformerDecoderLayer objects chained together, I’m guessing that the input and output shapes for both a TransformerEncoder and its component TransformerEncoderLayer objects will also be the same. I’ll find out soon.
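The guess about the TransformerEncoder can be checked with a small sketch. It wraps the demo layer in a TransformerEncoder. The choice of num_layers=6 is mine, mirroring the Transformer default; TransformerEncoder itself requires you to pass num_layers. The sketch then pushes the input through the copied layers one at a time to show that every link in the chain keeps the (3, 5, 4) shape.

```python
import torch as T

# The layer acts as a template: TransformerEncoder stores num_layers copies of it
layer = T.nn.TransformerEncoderLayer(d_model=4, nhead=2)
encoder = T.nn.TransformerEncoder(layer, num_layers=6)

src = T.rand(3, 5, 4)
oupt = encoder(src)
print(tuple(src.shape), tuple(oupt.shape))  # (3, 5, 4) (3, 5, 4)

# Walk the chain by hand: each layer's output feeds the next layer
x = src
for i, lyr in enumerate(encoder.layers):
    x = lyr(x)
    print(i, tuple(x.shape))  # every line shows (3, 5, 4)
```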

One of many things I don’t understand yet is masking — you can pass masks as arguments. The source code comments read:

src_mask: the mask for the src sequence.
src_key_padding_mask: the mask for the src keys per batch.

Classic documentation — essentially zero information. But I’m used to this situation with new, complex code libraries.
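Here is my current reading of the two masks, as a sketch to experiment with rather than a definitive explanation. Based on the PyTorch docs, src_mask has shape (S, S), where S is the sentence length, and it applies to every sentence in the batch. The src_key_padding_mask has shape (N, S), one row per sentence, and it marks padding positions. For boolean masks, True means "do not attend here". The causal mask and the choice of which slot is padding are made-up examples.

```python
import torch as T

S, N, E = 3, 5, 4  # words per sentence, sentences per batch, embed dim
T.manual_seed(0)
layer = T.nn.TransformerEncoderLayer(d_model=E, nhead=2)
layer.eval()  # disable dropout so repeated calls match
src = T.rand(S, N, E)

# src_mask, shape (S, S), shared by all sentences.
# True above the diagonal: word i may only attend to words 0..i (a causal mask).
causal = T.triu(T.ones(S, S, dtype=T.bool), diagonal=1)

# src_key_padding_mask, shape (N, S), one row per sentence.
# Hypothetical example: sentence 1 has only 2 real words, so slot 2 is padding.
pad_mask = T.zeros(N, S, dtype=T.bool)
pad_mask[1, 2] = True

oupt = layer(src, src_mask=causal, src_key_padding_mask=pad_mask)
print(tuple(oupt.shape))  # (3, 5, 4)

# Check 1, causal mask only: changing word 2 leaves words 0 and 1 alone
src2 = src.clone()
src2[2] += 1.0
a = layer(src, src_mask=causal)
b = layer(src2, src_mask=causal)
print(T.allclose(a[:2], b[:2]))  # True

# Check 2, padding mask only: changing sentence 1's padding slot
# leaves that sentence's real words (0 and 1) alone
src3 = src.clone()
src3[2, 1] += 1.0
c = layer(src, src_key_padding_mask=pad_mask)
d = layer(src3, src_key_padding_mask=pad_mask)
print(T.allclose(c[:2, 1], d[:2, 1]))  # True
```

Both masks here are boolean on purpose. As far as I can tell, newer PyTorch versions warn when a boolean mask is mixed with a float mask in the same call.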

Next, I’ll look at the parallel TransformerDecoderLayer. I expect it to have similar behavior. But I won’t know for sure until I code up a demo.
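A first guess at that demo, as a sketch: a TransformerDecoderLayer takes two tensors. The first is the target sequence. The second is the encoder output, called memory in the PyTorch forward() signature, and the layer’s second MultiheadAttention consumes it. The shapes below are made up. The target length of 7 is deliberately different from the source length of 3, while the batch size and d_model must match.

```python
import torch as T

dec_layer = T.nn.TransformerDecoderLayer(d_model=4, nhead=2)

memory = T.rand(3, 5, 4)  # stand-in for encoder output: 3 source words, 5 sentences
tgt = T.rand(7, 5, 4)     # target side: 7 words, same 5 sentences, same d_model

oupt = dec_layer(tgt, memory)
print(tuple(oupt.shape))  # (7, 5, 4), the same shape as tgt
```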

The moral of all this is that learning how to use extremely complex software library code is not easy. No shortcuts. Just persistence and relentless determination.


Left: This shortcut on the road to success did not work out well. Right: Canine shortcut gone wrong.

This entry was posted in PyTorch.