Generating PyTorch Transformer Masks

The PyTorch Transformer architecture is incredibly complex. But like anything, if you dissect the topic one piece at a time, the complexity slowly but surely fades away.

One of the literally hundreds of details related to Transformer architecture is the generation and use of masks. While I was exploring the main Transformer example in the PyTorch documentation, I ran across a generate_square_subsequent_mask() function to create a mask:

import torch
import torch.nn as nn

class TransformerModel(nn.Module):
  def __init__(self, ntoken, ninp, nhead, nhid,
    nlayers, dropout=0.5):
    super().__init__()
    # lots of complicated code here

  def generate_square_subsequent_mask(self, sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).\
      transpose(0, 1)
    mask = mask.float().\
      masked_fill(mask == 0, float('-inf')).\
        masked_fill(mask == 1, float(0.0))
    return mask

  def init_weights(self):
    ...  # code here

  def forward(self, src, src_mask):
    ...  # code here
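The mask is presumably handed to an encoder somewhere inside forward(). The following is a minimal sketch, assuming a recent PyTorch version (batch_first needs 1.9 or later) and tiny hypothetical sizes, that passes the documentation mask to a one-layer nn.TransformerEncoder. Dropout is set to 0.0 and the model is put in eval() mode so the outputs are deterministic. The check at the end changes only the inputs at positions 2 and 3 and confirms that the outputs at positions 0 and 1 do not change, which is exactly what a "subsequent" mask is for.

```python
import torch as T
import torch.nn as nn

def generate_square_subsequent_mask(sz):
  # documentation version, one statement per line
  mask = (T.triu(T.ones(sz, sz)) == 1).transpose(0, 1)
  mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  return mask

T.manual_seed(1)
batch, seq_len, d_model = 2, 4, 8  # hypothetical toy sizes

# one encoder layer; batch_first=True means src is (batch, seq, feature)
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2,
  dim_feedforward=16, dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)
encoder.eval()

src = T.randn(batch, seq_len, d_model)
src_mask = generate_square_subsequent_mask(seq_len)

with T.no_grad():
  out1 = encoder(src, mask=src_mask)
  # change only the inputs at positions 2 and 3
  src2 = src.clone()
  src2[:, 2:] = T.randn(batch, seq_len - 2, d_model)
  out2 = encoder(src2, mask=src_mask)

print(out1.shape)
# positions 0 and 1 cannot attend to positions 2 and 3, so their outputs match
print(T.allclose(out1[:, :2], out2[:, :2], atol=1e-5))
```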

My first reaction to generate_square_subsequent_mask() was, “What the Heck?” After examining it for a couple of hours, I came to the conclusion that the function is grotesquely over-engineered. I coded a simplified version.



First of all, the documentation version of generate_square_subsequent_mask() is defined in a class but it’s never used in the class and has no dependencies on anything in the class. So, the function might better be implemented as a standalone function like so:

import torch as T

def generate_square_subsequent_mask(sz):
  mask = (T.triu(T.ones(sz, sz)) == 1). \
    transpose(0, 1)
  mask = mask.float(). \
    masked_fill(mask == 0, float('-inf')). \
      masked_fill(mask == 1, float(0.0))
  return mask
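Splitting the chained calls into separate statements makes it easier to see what each one contributes for sz=4. This is a tracing sketch of the documentation logic, not a replacement for it; the intermediate names ones, upper, and keep are hypothetical, added only for illustration.

```python
import torch as T

sz = 4
ones = T.ones(sz, sz)                 # 4x4 matrix of 1.0 values
upper = T.triu(ones)                  # keep 1.0 on and above the diagonal, 0.0 below
keep = (upper == 1).transpose(0, 1)   # bool: True on and below the diagonal
print(keep)

mask = keep.float()                                # True -> 1.0, False -> 0.0
mask = mask.masked_fill(keep == 0, float('-inf'))  # False cells (above diagonal) -> -inf
mask = mask.masked_fill(keep == 1, float(0.0))     # True cells (on/below diagonal) -> 0.0
print(mask)
```

The triu() plus transpose() combination is just a roundabout way of building a lower-triangular pattern, and the two masked_fill() calls then swap 1.0/0.0 for 0.0/-inf.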

The statements are so long that I had to use several “\” line continuation characters to fit them in the code box of this blog post. Any time I see hideously long statements in code, a red flag goes up in my mind. Another red flag was that I really didn’t know exactly what this function does. An experimental call with sz=4 generated this tensor:

[[0.0, -inf, -inf, -inf],
 [0.0,  0.0, -inf, -inf],
 [0.0,  0.0,  0.0, -inf],
 [0.0,  0.0,  0.0,  0.0]]

In other words, the result is an sz x sz square matrix with 0.0 values on and below the main diagonal, and -inf values above the main diagonal. Why the guy who wrote the documentation version used triu() and ones() and transpose() and float() conversion and masked_fill() is beyond me. Once I knew what the documentation version did, the code sort of made sense, but I still think it’s much too complicated. (Note: float('-inf') is Python’s IEEE 754 negative infinity, a value that compares smaller than any other float. It’s used here because a float mask is added to the attention scores before softmax() is applied, and since exp(-inf) is 0, every cell above the diagonal gets zero attention weight. That prevents position i from attending to any later position j > i.)
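Here is a small sketch of why -inf works, assuming random numbers as stand-ins for real attention scores (the scores tensor is hypothetical). Adding the mask before softmax() drives every above-diagonal weight to exactly 0.0, while each row still sums to 1.

```python
import torch as T

inf = float('inf')
# the sz=4 mask: 0.0 on/below the diagonal, -inf above it
mask = T.tensor([[0., -inf, -inf, -inf],
                 [0.,   0., -inf, -inf],
                 [0.,   0.,   0., -inf],
                 [0.,   0.,   0.,   0.]])

T.manual_seed(0)
scores = T.randn(4, 4)  # hypothetical raw scores: row = query i, column = key j
weights = T.softmax(scores + mask, dim=1)  # exp(-inf) = 0, so masked cells get zero weight
print(weights)
```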

Here’s my simplified version of a function that does the same thing as the documentation version:

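import torch as T

def make_mask(sz):
  # start with an sz x sz matrix of 0.0 values
  mask = T.zeros((sz,sz), dtype=T.float32)
  for i in range(sz):
    for j in range(sz):
      # cells above the main diagonal (j > i) get -inf
      if j > i: mask[i][j] = float('-inf')
  return mask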
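A quick way to confirm that the simplified version really does the same thing is to compare both functions over a range of sizes. This sketch simply repeats the two definitions from this post so it runs on its own; torch.equal() checks shape and every element, and treats -inf == -inf as equal.

```python
import torch as T

def generate_square_subsequent_mask(sz):
  # documentation version, one statement per line
  mask = (T.triu(T.ones(sz, sz)) == 1).transpose(0, 1)
  mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  return mask

def make_mask(sz):
  # simplified version
  mask = T.zeros((sz, sz), dtype=T.float32)
  for i in range(sz):
    for j in range(sz):
      if j > i: mask[i][j] = float('-inf')
  return mask

all_match = all(T.equal(generate_square_subsequent_mask(sz), make_mask(sz))
                for sz in range(1, 9))
print(all_match)
```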

I make a square matrix of 0.0 values, then traverse the matrix and place float('-inf') in the cells above the main diagonal. Simple.

It’s true that explicit for-loops like the ones in my code are relatively slow, especially for large values of sz. But a mask depends only on the sequence length, so it can be built once and reused, and in this scenario the gain in simplicity outweighs the performance cost.
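If speed does matter, for example with long sequences, a single vectorized call can still read simply. This is a sketch, not the documentation code, and the name make_mask_fast() is hypothetical. torch.triu() with diagonal=1 keeps only the cells strictly above the main diagonal and sets everything else to 0.0, so starting from a matrix full of -inf leaves exactly the 0.0/-inf pattern described above.

```python
import torch as T

def make_mask_fast(sz):
  # hypothetical helper: fill an sz x sz matrix with -inf, then
  # triu(..., diagonal=1) keeps the cells strictly above the main
  # diagonal and sets every other cell to 0.0
  return T.triu(T.full((sz, sz), float('-inf')), diagonal=1)

print(make_mask_fast(4))
```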

The moral of the story is that you should never blindly accept code written by someone else as optimal.


In computer science, simpler is almost always better. But in art, the challenge is to balance simplicity and complexity. Artist Stanislaw Krupp (b. 1959) does this balancing act very well in my opinion.

This entry was posted in PyTorch.