The PyTorch Documentation for Weight Initialization is Not Very Good

I spent an entire day exploring PyTorch neural network weight initialization. I like PyTorch a lot, but the documentation for weight initialization is quite bad — unclear, incomplete, misleading, and sometimes incorrect.

If you create a PyTorch neural network from the torch.nn.Module class and the network has a Linear layer, the Linear __init__() function (by way of its reset_parameters() method) calls into the torch.nn.init module and invokes the torch.nn.init.kaiming_uniform_() function to initialize the weights. The biases are initialized separately, from a plain uniform distribution. But the weights are initialized in a completely wacky way, by setting the “a” parameter to sqrt(5), where “a” is described as “the negative slope of the rectifier used after this layer”. A negative slope of about 2.236 doesn't correspond to any rectifier you would normally use. The only justification, given in a comment in the Linear source code, is that “Setting a=sqrt(5) in kaiming_uniform is the same as initializing with uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see github.com/pytorch/pytorch/issues/57109”
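The equivalence claimed in that comment is just arithmetic. Here is a minimal sketch that computes the kaiming_uniform_() bound the same way the function does (gain, then std = gain / sqrt(fan_in), then bound = sqrt(3) * std) and compares it to 1/sqrt(fan_in). It assumes torch.nn.init.calculate_gain() accepts a numeric slope for 'leaky_relu', as it does in current PyTorch. The helper name kaiming_uniform_bound() is my own, hypothetical, and not part of PyTorch.

```python
import math
import torch as T

def kaiming_uniform_bound(fan_in, a):
  # hypothetical helper (not part of PyTorch): the bound b used by
  # kaiming_uniform_() in its default 'fan_in' mode, so values are U(-b, b)
  gain = T.nn.init.calculate_gain('leaky_relu', a)  # sqrt(2 / (1 + a^2))
  std = gain / math.sqrt(fan_in)
  return math.sqrt(3.0) * std  # U(-b, b) has std b / sqrt(3)

# with a = sqrt(5): gain = sqrt(2 / 6) = 1 / sqrt(3), so
# b = sqrt(3) * (1 / sqrt(3)) / sqrt(fan_in) = 1 / sqrt(fan_in)
for fan_in in [3, 10, 100, 784]:
  print(fan_in, kaiming_uniform_bound(fan_in, math.sqrt(5)),
    1.0 / math.sqrt(fan_in))

# a freshly built Linear(3, 4) should have all weights and biases in [-b, b]
T.manual_seed(1)
layer = T.nn.Linear(3, 4)
b = 1.0 / math.sqrt(3)
print(layer.weight.abs().max().item() <= b + 1e-6)  # small float32 slack
print(layer.bias.abs().max().item() <= b + 1e-6)
```

Passing a = 0 instead (the kaiming_uniform_() default) gives b = sqrt(2) for fan_in = 3, which is why calling kaiming_uniform_() with its defaults produces noticeably larger values than the Linear default.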

This is very confusing.

The bottom line is that PyTorch default weight initialization is so wacky that it's probably best just to use the default scheme. Explicitly setting weights is a better idea in principle, but in practice it's just too complicated. For anything other than extremely deep neural networks, most reasonable initialization algorithms will work OK, and other factors, such as optimizer learning rate and batch size, have more impact. Furthermore, in addition to dealing with weight initialization, you have to worry about initializing the biases too. Sheesh.
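If you do decide to set weights and biases explicitly, one common pattern is to overwrite the default values inside the network's __init__() by calling torch.nn.init functions directly. This is a minimal sketch, not a recommendation: the choice of xavier_uniform_() for weights and zeros_() for biases is my assumption for illustration, and the class name ExplicitNet and its 3-4-2 layer sizes are hypothetical.

```python
import math
import torch as T

class ExplicitNet(T.nn.Module):
  # hypothetical 3-4-2 network, used only to illustrate explicit init
  def __init__(self):
    super().__init__()
    self.hid = T.nn.Linear(3, 4)
    self.oupt = T.nn.Linear(4, 2)
    # the default Linear init has already run at this point;
    # the calls below overwrite both weights and biases
    for layer in [self.hid, self.oupt]:
      T.nn.init.xavier_uniform_(layer.weight)  # Glorot uniform weights
      T.nn.init.zeros_(layer.bias)             # all biases set to 0.0

  def forward(self, x):
    z = T.tanh(self.hid(x))
    return self.oupt(z)

T.manual_seed(1)
net = ExplicitNet()
# Glorot bound for the 3-to-4 layer: sqrt(6 / (fan_in + fan_out))
bound = math.sqrt(6.0 / (3 + 4))
print(net.hid.weight)
print(net.hid.bias)
print("Glorot bound =", bound)
```

Note that the weights and biases each need their own call: every torch.nn.init function operates on a single tensor.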



Torch corals, Euphyllia glabrescens, are large polyp stony corals that live in Indo-Pacific reefs. The torch coral has long, flowing, fleshy polyps that extend from a calcified (stony) base. Colorful but kind of creepy.


Here is some demo code. I won’t even begin to try to explain it: there’s far too much going on with fan_in and fan_out, the gain values associated with activation functions, and so on.
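For orientation only, a rough sketch of the two ingredients. A Linear weight tensor has shape (out_features, in_features), so fan_in is the number of inputs and fan_out is the number of outputs. The gain values come from torch.nn.init.calculate_gain(). The Linear(3, 4) size matches the demo but is otherwise arbitrary.

```python
import math
import torch as T

layer = T.nn.Linear(3, 4)          # weight shape is (4, 3)
fan_out, fan_in = layer.weight.shape
print("fan_in =", fan_in, " fan_out =", fan_out)

# gain values PyTorch associates with common activations
for name in ['linear', 'sigmoid', 'tanh', 'relu']:
  print(name, T.nn.init.calculate_gain(name))
print('leaky_relu (default 0.01)', T.nn.init.calculate_gain('leaky_relu'))
print('leaky_relu (sqrt(5))', T.nn.init.calculate_gain('leaky_relu', math.sqrt(5)))

# kaiming uniform bound in 'fan_in' vs 'fan_out' mode with a = 0 (ReLU gain)
g = T.nn.init.calculate_gain('leaky_relu', 0)
print('fan_in bound ', g * math.sqrt(3.0 / fan_in))
print('fan_out bound', g * math.sqrt(3.0 / fan_out))
```

The mode argument of kaiming_uniform_() only decides which of the two fan values goes into the denominator; the gain depends on the activation and, for leaky ReLU, on the slope a.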


import numpy as np
import torch as T
import math

print("\nPyTorch kaiming init experiments \n")

# 1. from a Net class

print("-----------")
print("1. kaiming uniform_ from Linear init ")
T.manual_seed(1)
np.random.seed(1)

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.fc1 = T.nn.Linear(3, 4)  # weight is 4x3
    # T.manual_seed(1)
    # np.random.seed(1)
    # T.nn.init.kaiming_uniform_(self.fc1.weight,
    #   a = math.sqrt(5))

  def forward(self, x):
    z = T.relu(self.fc1(x))
    return z

net = Net()
print(net.fc1.weight)  # [[0.2975, . . 0.0285]]
print("-----------")

# 2. kaiming init from definition, a = sqrt(5)

print("-----------")
print("2. kaiming uniform from algorithm with a = sqrt(5)")
T.manual_seed(1)
np.random.seed(1)
# "the negative slope of the rectifier used after this layer "
a = math.sqrt(5) 
fin = 3
gain = math.sqrt(2.0 / (1 + a**2)) 
bound = gain * math.sqrt(3.0 / fin)
# bound = 0.5773502691896257
lo = -bound
hi = bound
result = (hi - lo) * T.rand((4,3)) + lo
print(result)   # [[0.2975, . . 0.0285]]
print("-----------")

# 3. kaiming init, default a=0, 'fan_in', 'leaky_relu'
print("-----------")
print("3. kaiming uniform_ function with a = default 0")
T.manual_seed(1)
np.random.seed(1)
tnsr = T.zeros((4,3), dtype=T.float32)
T.nn.init.kaiming_uniform_(tnsr)
print(tnsr)  # [[0.7287, . . 0.0698]]  # different result
print("-----------")

# 4. Verify claim:
# Setting a=sqrt(5) in kaiming_uniform is the same as 
# initializing with uniform(-1/sqrt(in_features), 
# 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
print("-----------")
print("4. kaiming uniform definition b = 1/sqrt(3)")
T.manual_seed(1)
np.random.seed(1)
lo = -1.0 / math.sqrt(3.0)
hi = 1.0 / math.sqrt(3.0)
result = (hi - lo) * T.rand((4,3)) + lo
print(result)   # [[0.2975, . . 0.0285]]  # verified
print("-----------")
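The demo verifies the sqrt(5) claim by matching random values under a fixed seed. A seed-independent check is statistical: U(-b, b) has variance b^2 / 3, so if the default bound is b = 1/sqrt(fan_in), the weights of a freshly constructed Linear layer should have variance close to 1 / (3 * fan_in). This is a quick sketch; the 1000-input, 500-output layer size is arbitrary, chosen only so the sample is large.

```python
import math
import torch as T

T.manual_seed(0)
fan_in, fan_out = 1000, 500
layer = T.nn.Linear(fan_in, fan_out)   # default init, no explicit calls

b = 1.0 / math.sqrt(fan_in)            # claimed bound for weights and biases
w = layer.weight.detach()
bias = layer.bias.detach()

print("max |w|    :", w.abs().max().item(), " bound:", b)
print("max |bias| :", bias.abs().max().item())
print("var(w)     :", w.var().item(), " expected:", b * b / 3.0)
```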