The PyTorch view(), reshape(), squeeze(), and flatten() Functions

I was teaching a workshop on PyTorch deep neural networks recently, and I noticed that people got tripped up on some of the details. For example, the differences between view() and reshape(), and between squeeze() and flatten(), are important to understand.

First, view() and reshape() produce the same result; the difference is how they work behind the scenes. The view() function never copies data, so it requires the tensor data to be laid out compatibly in memory (typically, contiguous). The reshape() function also works with non-contiguous data, returning a view when it can and a copy when it must. I’ll limit myself to view(), which is best explained by example:

import torch as T

X = T.tensor([[1,2,3,4],
              [5,6,7,8],
              [9,10,11,12]], dtype=T.float32)
print(X)

X = X.view(1,12)  # or reshape(1,12)
print(X)

X = X.view(3,-1)
print(X)

These lines of code set up a tensor with 3 rows and 4 columns, and the first print() statement displays:

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])

The view() function reshapes the data, so view(1,12) returns a tensor with the same 12 values arranged as 1 row and 12 columns:

[[1 2 3 4 5 6 7 8 9 10 11 12]]
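
The phrase "the same data" can be taken literally. This minimal sketch (the name V is just for illustration) shows that the tensor returned by view() shares memory with the original, so a write through one is visible in the other:

```python
import torch as T

X = T.tensor([[1,2,3,4],
              [5,6,7,8],
              [9,10,11,12]], dtype=T.float32)

V = X.view(1,12)                      # a new view of X's storage, not a copy
print(V.shape)                        # torch.Size([1, 12])

V[0][0] = 99.0                        # write through the view ...
print(X[0][0])                        # ... and the original sees it: tensor(99.)
print(V.data_ptr() == X.data_ptr())   # True: both start at the same memory address
```

If you need an independent copy, call clone() on the result instead of relying on view().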

The special -1 argument tells PyTorch to infer the size of that dimension from the total number of elements, so view(3,-1) returns a (3,4) tensor (the original shape) because 12 / 3 = 4. If calling view() generates an error because the data is not contiguous, you can usually fix it by placing a statement like X = X.contiguous() before calling the view() function, or by calling reshape() instead.
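
Here is a minimal sketch of both the -1 inference and the contiguity issue. It assumes the transpose function t() as a convenient way to produce a non-contiguous tensor; the exact RuntimeError message varies between PyTorch versions, so the code only prints it:

```python
import torch as T

X = T.arange(1, 13, dtype=T.float32).view(3,4)  # same 3x4 values as above
print(X.view(2,-1).shape)     # -1 is inferred as 12 / 2 = 6: torch.Size([2, 6])

Xt = X.t()                    # transpose: shape (4,3), shares X's storage
print(Xt.is_contiguous())     # False: elements are no longer in row order in memory

try:
    Xt.view(12)               # view() cannot flatten this layout without copying
except RuntimeError as err:
    print("view() failed:", err)

A = Xt.contiguous().view(12)  # copy into contiguous memory first, then view
B = Xt.reshape(12)            # reshape() makes that copy automatically when needed
print(T.equal(A, B))          # True
```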

Sometimes you have a tensor where one of the dimensions is 1, usually so the shape is compatible with some other tensor. For example:

X = T.tensor([[1,2,3,4],
              [5,6,7,8],
              [9,10,11,12]], dtype=T.float32)

Y = X.view(1,3,4)
print(Y)

Now because the first dimension has size 1, the only legal index for it is [0]. For example:

print(Y[0][2][1])  # tensor(10.)

So in some sense a dimension with size 1 is useless. The squeeze() function removes every dimension that has size 1. So the code

Y = Y.squeeze()
print(Y)

removes the somewhat-useless first dimension and gives a tensor with shape (3,4), that is, 3 rows and 4 columns.
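
A quick sketch of how squeeze() behaves when there is more than one size-1 dimension, and when you pass it an explicit dimension index (the shape (1,3,1,4) is just an illustrative choice):

```python
import torch as T

X = T.arange(1, 13, dtype=T.float32).view(3,4)  # same 3x4 values as above

Z = X.view(1,3,1,4)            # two size-1 dimensions
print(Z.squeeze().shape)       # torch.Size([3, 4]): every size-1 dim is removed
print(Z.squeeze(0).shape)      # torch.Size([3, 1, 4]): only dim 0 is removed
print(Z.squeeze(1).shape)      # torch.Size([1, 3, 1, 4]): dim 1 has size 3, so no change
```

Passing the dimension explicitly is the safer choice when a size-1 dimension carries meaning, for example a batch that happens to hold a single item, because squeeze() with no argument would remove it too.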

The flatten() function is sort of a cousin to squeeze(): it returns a tensor with the same data but reshaped to a single dimension. So the code

X = T.tensor([[1,2,3,4],
              [5,6,7,8],
              [9,10,11,12]], dtype=T.float32)

X = X.flatten()

gives a one-dimensional, vector-like tensor with shape [12] (not [1,12] as with view(1,12) above):

[1 2 3 4 5 6 7 8 9 10 11 12]
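
The sketch below, using the same import torch as T alias, compares flatten() with view(-1) and with squeeze() to show the "cousin" relationship. The batch tensor and its sizes are hypothetical; start_dim is a documented parameter of flatten():

```python
import torch as T

X = T.arange(1, 13, dtype=T.float32).view(3,4)  # same 3x4 values as above

F = X.flatten()
print(F.shape)                      # torch.Size([12])
print(T.equal(F, X.view(-1)))       # True: same values as view(-1)

# squeeze() only drops size-1 dims; flatten() always collapses to one dim
R = X.view(1,12)
print(R.squeeze().shape)            # torch.Size([12])
print(R.flatten().shape)            # torch.Size([12])
print(X.squeeze().shape)            # torch.Size([3, 4]): nothing to squeeze
print(X.flatten().shape)            # torch.Size([12])

# start_dim=1 keeps the first (batch) dimension and flattens the rest
batch = T.zeros(8, 3, 4)            # hypothetical batch: 8 items, each 3x4
print(batch.flatten(start_dim=1).shape)  # torch.Size([8, 12])
```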

Reshaping PyTorch tensors is not difficult conceptually, but the syntax is a big stumbling block for both beginners and experienced PyTorch users.



Images for squeeze and flatten. Scene from “Swiss Family Robinson” (1960). The girl is having entirely too much fun squeezing that bottle. “Flatland: A Romance of Many Dimensions” by Edwin A. Abbott (1884). Woman + flat tire = trouble.
