The PyTorch scatter() Function Explained

The PyTorch scatter() function is strange. If you have a matrix named “source”, and another matrix of the same shape named “place_at”, and a third matrix named “destination” of the same shape or larger, the scatter() function will use the information in “place_at” to place the values in “source” into “destination”.

Here’s an example, using the values from one of the PyTorch documentation examples:

# scatter_demo.py
import torch as T

source = T.tensor([
  [ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
  [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]], dtype=T.float32)

place_at = T.tensor([
  [0, 1, 2, 0, 0],
  [2, 0, 0, 1, 2]], dtype=T.int64)

destination = T.tensor([
  [0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0]], dtype=T.float32)

# place values from source into destination using place_at
result = destination.scatter(dim=0, index=place_at, src=source)

print(source)
print(place_at)
print(destination)
print(result)

The result is:

[[0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
 [0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
 [0.5735, 0.0000, 0.9044, 0.0000, 0.1732]]
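To make the placement rule concrete, here is a minimal sketch that performs the same dim=0 placement with explicit loops and checks it against the built-in call. The helper name scatter_dim0_loop() is hypothetical. It assumes, as in this example, that no two place_at entries in the same column name the same row, so there are no collisions.

```python
import torch as T

# Same values as in scatter_demo.py
source = T.tensor([
  [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]], dtype=T.float32)
place_at = T.tensor([
  [0, 1, 2, 0, 0],
  [2, 0, 0, 1, 2]], dtype=T.int64)
destination = T.zeros(3, 5, dtype=T.float32)

def scatter_dim0_loop(dest, index, src):
  # hypothetical helper: returns a copy of dest where, for every (i, j),
  # out[index[i][j]][j] = src[i][j]  (the dim=0 rule)
  out = dest.clone()
  nrows, ncols = index.shape
  for i in range(nrows):
    for j in range(ncols):
      out[index[i, j].item(), j] = src[i, j]
  return out

looped = scatter_dim0_loop(destination, place_at, source)
builtin = destination.scatter(dim=0, index=place_at, src=source)
print(T.equal(looped, builtin))  # True
```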

If you compare the source values with the result matrix, you’ll see that every value in source stays in its original column but moves to the row given by the corresponding place_at value. This is because dim=0. If dim=1, every source value stays in its original row but moves to the column given by the corresponding place_at value.
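Here is a small sketch of the dim=1 case. The index matrix col_place_at is hypothetical, separate from the place_at used in the demo, and chosen so that row 0 of source comes out reversed while row 1 stays put. With dim=1 each index value names a destination column: result[i][col_place_at[i][j]] = source[i][j].

```python
import torch as T

source = T.tensor([
  [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]], dtype=T.float32)

# hypothetical index for dim=1: each value is a destination COLUMN
col_place_at = T.tensor([
  [4, 3, 2, 1, 0],   # reverse row 0
  [0, 1, 2, 3, 4]],  # leave row 1 in place
  dtype=T.int64)

destination = T.zeros(2, 5, dtype=T.float32)
result = destination.scatter(dim=1, index=col_place_at, src=source)
print(result)
# row 0 holds 0.6004, 0.4850, 0.9044, 0.2908, 0.3992
# row 1 is identical to source row 1
```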

The place_at matrix is:

[[0, 1, 2, 0, 0],
 [2, 0, 0, 1, 2]]

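Each place_at value supplies only the destination row; the destination column is implied by the value’s own column. So, conceptually, the place_at matrix is really shorthand for these (row, column) pairs: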

[ (0,0)  (1,1)  (2,2)  (0,3)  (0,4)
  (2,0)  (0,1)  (0,2)  (1,3)  (2,4) ]
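The expansion from the compact place_at matrix to explicit (row, column) pairs can be computed directly. This sketch assumes dim=0, so the implied column of entry [i][j] is simply j; the names cols and pairs are illustrative.

```python
import torch as T

place_at = T.tensor([
  [0, 1, 2, 0, 0],
  [2, 0, 0, 1, 2]], dtype=T.int64)

# for dim=0 the target column of entry [i][j] is just j
cols = T.arange(place_at.shape[1]).expand_as(place_at)

# pair each row index with its implied column index -> shape (2, 5, 2)
pairs = T.stack([place_at, cols], dim=-1)
print(pairs[0].tolist())  # [[0, 0], [1, 1], [2, 2], [0, 3], [0, 4]]
print(pairs[1].tolist())  # [[2, 0], [0, 1], [0, 2], [1, 3], [2, 4]]
```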

I have never needed the scatter() function. I encountered scatter() because it is a cousin of a function named gather(), which I do use sometimes. I find the scatter() function very ugly in some sense, and I would use a different approach if I ever encountered a scenario where I needed to place values from one matrix into another, larger matrix. For example, the following custom function is roughly comparable to the built-in scatter() function. The custom function accepts a place_at matrix where the target locations are explicit (row, column) tuples, rather than a row index paired with an implied column index.

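def my_scatter(src, dest, place_at):
  # place_at[i][j] holds an explicit (row, col) target for src[i][j]
  nrows = len(place_at)
  ncols = len(place_at[0])
  for i in range(nrows):
    for j in range(ncols):
      xy = place_at[i][j]  # tuple stored as a 2-element tensor
      x = xy[0].item()     # target row
      y = xy[1].item()     # target column
      dest[x][y] = src[i][j]
  return  # dest is modified in place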

The custom my_scatter() function could be called like:

source = T.tensor([
  [ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
  [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]],
    dtype=T.float32)

place_at = \
  T.tensor([[(0,0), (1,1), (2,2), (0,3), (0,4)],
            [(2,0), (0,1), (0,2), (1,3), (2,4)]],
    dtype=T.int64)

destination = T.tensor([
  [0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0]], dtype=T.float32)

my_scatter(source, destination, place_at)

print("\nNew destination: ")
print(destination)
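One possible alternative to the explicit loops, offered here as an assumption rather than the approach the post has in mind, is PyTorch advanced indexing. Splitting the explicit place_at tuples into a row tensor and a column tensor lets a single assignment place every source value. The sketch also checks the result against the built-in scatter().

```python
import torch as T

source = T.tensor([
  [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]], dtype=T.float32)

# explicit (row, col) targets, same as in the my_scatter() example
place_at = T.tensor([[(0,0), (1,1), (2,2), (0,3), (0,4)],
                     [(2,0), (0,1), (0,2), (1,3), (2,4)]], dtype=T.int64)

destination = T.zeros(3, 5, dtype=T.float32)

rows = place_at[..., 0]  # shape (2, 5): target row of each source value
cols = place_at[..., 1]  # shape (2, 5): target column of each source value

# destination[rows[i][j], cols[i][j]] = source[i][j] for every (i, j)
destination[rows, cols] = source
print(destination)

builtin = T.zeros(3, 5).scatter(dim=0, index=rows, src=source)
print(T.equal(destination, builtin))  # True
```

As with my_scatter(), if two source values target the same cell, which one survives should not be relied upon.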

You can argue that it may not be a good idea to write a custom function when a built-in function exists, but the counterargument is that it’s not a good idea to use a built-in function when that function is overly complex and therefore error prone.
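Whichever approach you use, gather(), the cousin mentioned above, gives a quick sanity check on a scatter() call. With the same dim and index, gather() reads each value back from where scatter() put it: recovered[i][j] = result[place_at[i][j]][j]. This sketch assumes no two entries collide on the same cell, which holds for the demo values.

```python
import torch as T

source = T.tensor([
  [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
  [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]], dtype=T.float32)
place_at = T.tensor([
  [0, 1, 2, 0, 0],
  [2, 0, 0, 1, 2]], dtype=T.int64)

result = T.zeros(3, 5).scatter(dim=0, index=place_at, src=source)

# gather() with the same dim and index undoes the placement
recovered = result.gather(dim=0, index=place_at)
print(T.equal(recovered, source))  # True
```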


I think I’m pretty good at recognizing ugly computer language things like the scatter() function. But I am not at all good at recognizing ugliness in real life. Here are three images returned from an Internet image search for “ugly dress”. They all look good to me, which suggests that I am clearly no judge of fashion ugliness.

This entry was posted in PyTorch.