I rarely use type hints when I implement PyTorch neural network systems. Most of my colleagues don’t use type hints either. Briefly, the main advantage of using type hints for PyTorch programs is that it provides a good form of documentation, making code easier to read. The main disadvantage is that using type hints clutters up the code, making it more difficult to read.
Here’s a typical auxiliary function that computes the classification accuracy of a PyTorch model:
def accuracy(model, dataset, num_rows):
    dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
        shuffle=False)
    # code here
    # fraction of the counted items that were classified correctly
    acc = (n_correct * 1.0) / (n_correct + n_wrong)
    return acc
The types of the three parameters (model, dataset, num_rows) are easy to guess at but aren’t explicit. You could add comments like so:
def accuracy(model, dataset, num_rows):
    # model: class Net
    # dataset: a PyTorch Dataset
    # num_rows: number of rows to process (int)
    dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
        shuffle=False)
    # code here
    # fraction of the counted items that were classified correctly
    acc = (n_correct * 1.0) / (n_correct + n_wrong)
    return acc
Or you could use type hints:
def accuracy(model: Net, dataset: T.utils.data.Dataset,
        num_rows: int) -> float:
    dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
        shuffle=False)
    # code here
    # fraction of the counted items that were classified correctly
    acc = (n_correct * 1.0) / (n_correct + n_wrong)
    return acc
In theory, you can run a static type-checking program over code that has type hints, but in practice there aren’t any good checkers for PyTorch code (in my opinion), probably because PyTorch is still in very active development mode and changes too quickly. So, for now I use comments to annotate my PyTorch programs.

Three images from an Internet search for the word “hint”. I don’t fully grasp why these images appeared in the search results. Left: The 1962 Ford Cougar concept car. Maybe “hint of Mercedes Benz”? Center: A bottle of Ron Zacapa brand rum. Maybe “hint of vanilla”? Right: A girl or model of some sort. Maybe “hint of Italian”?
Demo code. Wherever the listing shows a quoted placeholder such as “lt” (less-than) or “gt” (greater-than), replace it with the actual symbol (< or >).
# iris_hints.py
# iris example using type hints
# PyTorch 1.9.0-CPU Anaconda3-2020.02 Python 3.7.6
# Windows 10
from typing import Optional

import numpy as np
import torch as T
device = T.device("cpu") # apply to Tensor or Module
# -----------------------------------------------------------
class IrisDataset(T.utils.data.Dataset):
    """Iris data loaded from a text file and held in memory as tensors.

    Each line of the comma-delimited source file holds four numeric
    predictor values followed by an integer species label, for example
    ``5.0, 3.5, 1.3, 0.3, 0``.
    """

    def __init__(self, src_file: str,
                 num_rows: Optional[int] = None) -> None:
        """Read the data file into two tensors.

        Args:
            src_file: path of the comma-delimited data file.
            num_rows: maximum number of rows to read; None reads
                every row.
        """
        # columns 0-3: predictors, stored as float32
        tmp_x = np.loadtxt(src_file, max_rows=num_rows,
                           usecols=range(0, 4), delimiter=",",
                           skiprows=0, dtype=np.float32)
        # column 4: species label, stored as int64
        tmp_y = np.loadtxt(src_file, max_rows=num_rows,
                           usecols=4, delimiter=",",
                           skiprows=0, dtype=np.int64)

        self.x_data = T.tensor(tmp_x, dtype=T.float32)
        self.y_data = T.tensor(tmp_y, dtype=T.int64)

    def __len__(self) -> int:
        """Return the number of items held."""
        return len(self.x_data)

    def __getitem__(self, idx: int) -> dict:
        """Return one item as a dict.

        Args:
            idx: position of the item; a tensor index is also accepted
                and is converted with tolist() before use.

        Returns:
            A dict with key 'predictors' (the predictor values) and
            key 'species' (the label).
        """
        if T.is_tensor(idx):
            idx = idx.tolist()
        preds = self.x_data[idx]
        spcs = self.y_data[idx]
        sample = {'predictors': preds, 'species': spcs}
        return sample
# -----------------------------------------------------------
class Net(T.nn.Module):
    """Fully connected 4-7-3 network: tanh hidden layer, logits out."""

    def __init__(self):
        super().__init__()
        self.hid1 = T.nn.Linear(4, 7)  # input -> hidden
        self.oupt = T.nn.Linear(7, 3)  # hidden -> output
        # Xavier-uniform weights and zero biases, hidden layer first
        for layer in (self.hid1, self.oupt):
            T.nn.init.xavier_uniform_(layer.weight)
            T.nn.init.zeros_(layer.bias)

    def forward(self, x: T.Tensor) -> T.Tensor:
        """Return the raw output-layer values for x; no softmax is applied."""
        hidden = T.tanh(self.hid1(x))
        return self.oupt(hidden)
# -----------------------------------------------------------
def accuracy(model: T.nn.Module,
             dataset: T.utils.data.Dataset) -> float:
    """Compute classification accuracy of model over dataset.

    Items are processed one at a time. Each item must be a dict with a
    'predictors' entry (the model input) and a 'species' entry (the
    correct class index). The predicted class is the index of the
    largest model output value.

    Args:
        model: the network to evaluate; assumed to already be in
            eval() mode.
        dataset: the items to score.

    Returns:
        The fraction of items classified correctly, in [0.0, 1.0].
        An empty dataset yields 0.0 rather than a division by zero.
    """
    dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
                                      shuffle=False)
    n_correct = 0
    n_items = 0
    with T.no_grad():  # scoring only; no gradients needed
        for batch in dataldr:
            X = batch['predictors']
            Y = batch['species']  # one label, because batch_size=1
            oupt = model(X)  # logits form
            big_idx = T.argmax(oupt)
            # compare plain Python ints rather than tensors
            if big_idx.item() == Y.item():
                n_correct += 1
            n_items += 1

    if n_items == 0:
        return 0.0
    return (n_correct * 1.0) / n_items
# -----------------------------------------------------------
def main():
    """Run the Iris demo end to end.

    Loads the train and test files, trains a Net with SGD and
    cross-entropy loss, prints the accuracy on the training data,
    predicts the species for one unseen item, and saves the trained
    model's state dict.
    """
    # 0. get started
    print("\nBegin Iris dataset using PyTorch 1.9 demo \n")
    T.manual_seed(1)
    np.random.seed(1)

    # 1. create Dataset and DataLoader objects
    print("Creating Iris train and test DataLoader ")
    train_file = ".\\Data\\iris_train.txt"
    test_file = ".\\Data\\iris_test.txt"
    train_ds = IrisDataset(train_file, num_rows=120)
    # Not scored below; pass test_ds to accuracy() to evaluate on the
    # test file instead of the training file.
    test_ds = IrisDataset(test_file)

    bat_size = 4
    train_ldr = T.utils.data.DataLoader(train_ds,
                                        batch_size=bat_size, shuffle=True)

    # 2. create network
    net = Net().to(device)

    # 3. train model
    max_epochs = 12
    ep_log_interval = 2
    lrn_rate = 0.05

    loss_func = T.nn.CrossEntropyLoss()  # applies softmax()
    optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

    print("\nbat_size = %3d " % bat_size)
    print("loss = " + str(loss_func))
    print("optimizer = SGD")
    print("max_epochs = %3d " % max_epochs)
    print("lrn_rate = %0.3f " % lrn_rate)

    print("\nStarting training")
    net.train()
    for epoch in range(0, max_epochs):
        epoch_loss = 0  # summed over all batches in this epoch
        for batch in train_ldr:
            # keep the inputs on the same device as the network
            X = batch['predictors'].to(device)  # [bat_size, 4]
            Y = batch['species'].to(device)     # [bat_size]

            optimizer.zero_grad()
            oupt = net(X)
            loss_obj = loss_func(oupt, Y)  # a tensor
            epoch_loss += loss_obj.item()  # accumulate
            loss_obj.backward()
            optimizer.step()

        if epoch % ep_log_interval == 0:
            print("epoch = %4d loss = %0.4f" % (epoch, epoch_loss))
    print("Done ")

    # 4. evaluate model accuracy
    print("\nComputing model accuracy")
    net.eval()
    acc = accuracy(net, train_ds)  # item-by-item
    print("Accuracy on train data = %0.4f" % acc)

    # 5. make a prediction
    print("\nPredicting species for [6.1, 3.1, 5.1, 1.1]: ")
    unk = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
    unk = T.tensor(unk, dtype=T.float32).to(device)
    with T.no_grad():
        logits = net(unk)  # values do not sum to 1.0
    probs = T.softmax(logits, dim=1)
    T.set_printoptions(precision=4)
    print(probs)

    # 6. save model (state_dict approach)
    # To reuse the saved state later: create a Net() and call
    # load_state_dict(T.load(fn)) on it.
    print("\nSaving trained model state")
    fn = ".\\Models\\iris_model.pth"
    T.save(net.state_dict(), fn)

    print("\nEnd Iris demo")
# Run the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

.NET Test Automation Recipes
Software Testing
SciPy Programming Succinctly
Keras Succinctly
R Programming
2026 Visual Studio Live
2025 Summer MLADS Conference
2026 DevIntersection Conference
2025 Machine Learning Week
2025 Ai4 Conference
2026 G2E Conference
2026 iSC West Conference
You must be logged in to post a comment.