The PyTorch gather() function can be used to extract values from specified columns of a matrix. I sometimes use the gather() function when I’m working with PyTorch multi-class classification. Specifically, I use gather() when I have computed output probabilities in a matrix and I need to extract one value from each row, where the extracted value corresponds to the target output class.
For example, suppose you have a multi-class classifier with k=3 classes and just four training items. And suppose the classifier network's four computed output probability vectors, and the four target (correct) classes, are:
```
[[0.10, 0.50, 0.40],   # correct
 [0.55, 0.20, 0.25],   # wrong
 [0.60, 0.10, 0.30],   # correct
 [0.15, 0.65, 0.20]]   # correct

[1, 2, 0, 1]           # targets
```
The first output is correct because the largest probability (0.50) at position [1] corresponds to the target [1], but the second output is wrong because the largest probability (0.55) at [0] does not correspond to the target [2].
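Rather than judging correct and wrong by eye, the predicted class for each row can be computed with argmax. Here is a minimal sketch, assuming the same four probability vectors and targets as above; the names preds, correct, and accuracy are just illustrative:

```python
import torch as T

# the four output probability vectors and targets from the example above
softs = T.tensor([[0.10, 0.50, 0.40],
                  [0.55, 0.20, 0.25],
                  [0.60, 0.10, 0.30],
                  [0.15, 0.65, 0.20]], dtype=T.float32)
targets = T.tensor([1, 2, 0, 1], dtype=T.int64)

# index of the largest probability in each row is the predicted class
preds = T.argmax(softs, dim=1)          # [1, 0, 0, 1]
correct = preds == targets              # [True, False, True, True]
accuracy = correct.float().mean().item()

print(preds)
print(correct)
print(accuracy)                         # 3 of 4 correct
```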
When computing cross entropy error, you need to extract the one probability from each output vector that corresponds to the target. So for the values above, you want to extract:
[[0.50], [0.25], [0.60], [0.65]]
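For reference, here is the cross entropy error those four extracted values feed into, assuming the natural log and an average over the n = 4 items, where t_i is the target class of item i and p_{i,t_i} is the extracted probability:

```latex
$$
\mathrm{CE} = -\frac{1}{n}\sum_{i=1}^{n} \ln p_{i,t_i}
= -\tfrac{1}{4}\left(\ln 0.50 + \ln 0.25 + \ln 0.60 + \ln 0.65\right)
\approx \tfrac{1}{4}\left(0.6931 + 1.3863 + 0.5108 + 0.4308\right)
\approx 0.7553
$$
```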
It’s possible to write a little helper function to perform the extraction, or you can use the built-in gather() function:
```python
import torch as T

softs = T.tensor(  # shape [4,3]
  [[0.10, 0.50, 0.40],   # correct
   [0.55, 0.20, 0.25],   # wrong
   [0.60, 0.10, 0.30],   # correct
   [0.15, 0.65, 0.20]],  # correct
  dtype=T.float32)
targets = T.tensor([1, 2, 0, 1], dtype=T.int64)
targets = targets.reshape(4,1)  # index must be 2-D, like softs
probs = softs.gather(dim=1, index=targets)
# dim=1: probs[i][0] = softs[i][targets[i][0]], so every row
# is kept and one column is picked per row; shape [4,1]

print("\nComputed output probabilities: ")
print(softs)
print("\nTargets: ")
print(targets)
print("\nExtracted output probabilities: ")
print(probs)
```
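The general rule behind dim=1 is out[i][j] = input[i][index[i][j]], so the result always takes the shape of the index tensor. The sketch below reuses the example data, checks that rule with an explicit loop, and shows a [4,2] index that pulls two values per row; the two_cols index values are arbitrary ones made up for illustration:

```python
import torch as T

softs = T.tensor([[0.10, 0.50, 0.40],
                  [0.55, 0.20, 0.25],
                  [0.60, 0.10, 0.30],
                  [0.15, 0.65, 0.20]], dtype=T.float32)
targets = T.tensor([1, 2, 0, 1], dtype=T.int64)

# unsqueeze(1) turns shape [4] into [4,1], same effect as reshape(4,1)
index = targets.unsqueeze(1)
probs = softs.gather(dim=1, index=index)   # shape [4,1]

# the rule gather() applies for dim=1: out[i][j] = softs[i][index[i][j]]
manual = T.empty(index.shape, dtype=softs.dtype)
for i in range(index.shape[0]):
  for j in range(index.shape[1]):
    manual[i][j] = softs[i][index[i][j]]
print(T.equal(probs, manual))   # True

# output shape follows the index, so a [4,2] index gives two values per row
two_cols = softs.gather(dim=1,
  index=T.tensor([[0, 1], [0, 2], [0, 0], [2, 1]]))
print(two_cols)

# squeeze(1) drops the size-1 column when a flat vector is more convenient
print(probs.squeeze(1))
```

The unsqueeze(1) and squeeze(1) calls are just convenient ways to add and remove the size-1 column dimension around gather().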
There’s always a tension between using a built-in PyTorch function (which is almost always general-purpose, so it’ll have zillions of parameters and be difficult to understand) and writing a custom helper function (which can be short and specific, but increases the overall size of your program). For example, the following custom function gives the same result as the built-in gather() function:
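```python
def my_gather(src, cols):
  # src: [n,k] matrix; cols: [n,1] (or [n]) target column for each row
  nrows = len(src)
  result = T.zeros((nrows,1), dtype=src.dtype)  # one value per row
  for i in range(nrows):
    idx = cols[i]       # target column for row i
    x = src[i][idx]     # pick that single value
    result[i] = x
  return result
```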
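A third option, somewhere between a hand-written loop and gather(), is PyTorch integer-array (advanced) indexing, which pairs each row number with that row's target column. Here is a minimal sketch, assuming the same example data; it works directly with the 1-D targets vector but returns a 1-D result rather than gather()'s [4,1]:

```python
import torch as T

softs = T.tensor([[0.10, 0.50, 0.40],
                  [0.55, 0.20, 0.25],
                  [0.60, 0.10, 0.30],
                  [0.15, 0.65, 0.20]], dtype=T.float32)
targets = T.tensor([1, 2, 0, 1], dtype=T.int64)

# pair row k with column targets[k]: softs[0,1], softs[1,2], softs[2,0], softs[3,1]
rows = T.arange(softs.shape[0])
probs_flat = softs[rows, targets]          # shape [4]

# gather() needs a 2-D index and returns shape [4,1]
probs_gather = softs.gather(dim=1, index=targets.reshape(-1, 1))

print(T.equal(probs_flat, probs_gather.squeeze(1)))   # True
print(probs_flat.reshape(-1, 1))   # same [4,1] layout as gather()
```

Which form to use is mostly a matter of taste: the indexing version avoids the reshape, while gather() generalizes to picking several columns per row.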
The more I work with the PyTorch library, the more I appreciate how incredibly complex the library is. I see a lot of beauty in the library, but also a lot of ugliness in the form of the complexity that comes with writing general-purpose functions. For most developers and data scientists, I believe that PyTorch is far too complex to be just one of several tools in their skill set. To gain expertise in PyTorch, you must dedicate almost all of your time and effort to learning it.
I’ve always been more attracted to closed systems than to open systems. For example, when I was young I learned to love chess, but I never had any interest in Dungeons and Dragons with its open rules and gameplay. Most of computer science consists of closed systems, such as a programming language like C or a library like PyTorch.

Left: A really creative chess set where the pieces are origami and resemble alien plants. Right: A 3D illustration of a hypothetical chess set.

