Some Thoughts About Using Type Hints With PyTorch

I rarely use type hints when I implement PyTorch neural network systems. Most of my colleagues don’t use type hints either. Briefly, the main advantage of using type hints in PyTorch programs is that they provide a good form of documentation, making code easier to read. The main disadvantage is that type hints clutter up the code, making it more difficult to read.

Here’s a typical auxiliary function that computes the classification accuracy of a PyTorch model:

def accuracy(model, dataset, num_rows):
  dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
    shuffle=False)
  # code here (elided): loop over dataldr, counting n_correct and n_wrong
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

The types of the three parameters (model, dataset, num_rows) are easy to guess but aren’t explicit. You could add comments like so:

def accuracy(model, dataset, num_rows):
  # model: class Net
  # dataset: a PyTorch Dataset
  # num_rows: number of rows to process (int)
  dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
    shuffle=False)
  # code here (elided): loop over dataldr, counting n_correct and n_wrong
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

Or you could use type hints:

def accuracy(model: Net, dataset: T.utils.data.Dataset,
  num_rows: int) -> float:
  dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
    shuffle=False)
  # code here (elided): loop over dataldr, counting n_correct and n_wrong
  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

In theory, you can check type hints with a static type checker such as mypy, but in practice there aren’t any good type checkers for PyTorch code (in my opinion), probably because PyTorch is still in very active development and changes too quickly. So, for now I use comments to annotate my PyTorch programs.



Three images from an Internet search for the word “hint”. I don’t fully understand why these images appeared in the search results. Left: The 1962 Ford Cougar concept car. Maybe “hint of Mercedes-Benz”? Center: A bottle of Ron Zacapa brand rum. Maybe “hint of vanilla”? Right: A girl or model of some sort. Maybe “hint of Italian”?


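Demo code. Replace “lt” (less-than), “gt” (greater-than), etc. with the corresponding symbols; for example, -"gt" in the accuracy() signature should read ->.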

# iris_hints.py
# iris example using type hints
# PyTorch 1.9.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10 

import os
from typing import Dict, Optional

import numpy as np
import torch as T
# Device for the network and the single prediction input; main() calls
# .to(device) on those.
# NOTE(review): training and accuracy batches are never moved to this
# device, so only "cpu" works as written.
device = T.device("cpu")  # apply to Tensor or Module

# -----------------------------------------------------------

class IrisDataset(T.utils.data.Dataset):
  """Iris data read from a comma-delimited text file.

  Each line holds four numeric predictors followed by an integer
  species label, e.g. "5.0, 3.5, 1.3, 0.3, 0".
  """

  def __init__(self, src_file: str, num_rows: Optional[int] = None) -> None:
    """Load predictors and labels from src_file.

    Args:
      src_file: path to the comma-delimited data file.
      num_rows: maximum number of rows to read; None reads every row.
    """
    # ndmin keeps x 2-D and y 1-D even when exactly one row is read;
    # without it loadtxt squeezes a single row to shape (4,) / (), so
    # len() would report 4 and indexing would return scalars.
    tmp_x = np.loadtxt(src_file, max_rows=num_rows,
      usecols=range(0,4), delimiter=",", skiprows=0,
      dtype=np.float32, ndmin=2)
    tmp_y = np.loadtxt(src_file, max_rows=num_rows,
      usecols=4, delimiter=",", skiprows=0,
      dtype=np.int64, ndmin=1)

    self.x_data = T.tensor(tmp_x, dtype=T.float32)  # [n_rows, 4]
    self.y_data = T.tensor(tmp_y, dtype=T.int64)    # [n_rows]

  def __len__(self) -> int:
    """Return the number of rows loaded."""
    return len(self.x_data)

  def __getitem__(self, idx: int) -> Dict[str, T.Tensor]:
    """Return {'predictors': [4] float tensor, 'species': 0-d int64 tensor}."""
    if T.is_tensor(idx):
      idx = idx.tolist()
    preds = self.x_data[idx]
    spcs = self.y_data[idx]
    sample = { 'predictors' : preds, 'species' : spcs }
    return sample

# -----------------------------------------------------------

class Net(T.nn.Module):
  """4-7-3 feed-forward classifier with a tanh hidden layer."""

  def __init__(self):
    super().__init__()
    self.hid1 = T.nn.Linear(4, 7)  # input -> hidden
    self.oupt = T.nn.Linear(7, 3)  # hidden -> output

    # Glorot-uniform weights and zero biases, hidden layer first
    for layer in (self.hid1, self.oupt):
      T.nn.init.xavier_uniform_(layer.weight)
      T.nn.init.zeros_(layer.bias)

  def forward(self, x: T.Tensor) -> T.Tensor:
    """Return raw logits; softmax is left to CrossEntropyLoss()."""
    hidden = T.tanh(self.hid1(x))
    return self.oupt(hidden)

# -----------------------------------------------------------

def accuracy(model: "Net", dataset: T.utils.data.Dataset) -> float:
  """Return the fraction of items in dataset that model classifies correctly.

  Items are scored one at a time (batch_size=1, no shuffling). Each item
  must be a dict with a 'predictors' tensor and an integer 'species'
  label tensor. The caller is expected to have called model.eval(); this
  function does not change the model's mode.

  Args:
    model: callable mapping a batch of one predictor row to a row of logits.
    dataset: map-style dataset (supports len() and indexing).

  Returns:
    n_correct / n_items as a float in [0.0, 1.0].

  Raises:
    ValueError: if dataset contains no items.
  """
  dataldr = T.utils.data.DataLoader(dataset, batch_size=1,
    shuffle=False)
  n_correct = 0; n_wrong = 0
  with T.no_grad():  # inference only: no autograd bookkeeping needed
    for batch in dataldr:
      X = batch['predictors']
      Y = batch['species']  # already 1-D; one label per batch
      oupt = model(X)       # logits form
      # compare plain Python ints rather than relying on tensor truthiness
      if T.argmax(oupt).item() == Y.item():
        n_correct += 1
      else:
        n_wrong += 1

  n_total = n_correct + n_wrong
  if n_total == 0:
    raise ValueError("accuracy() called on an empty dataset")
  return n_correct / n_total

# -----------------------------------------------------------

def main():
  """Train, evaluate, use, and save a 4-7-3 Iris classifier.

  Reads ./Data/iris_train.txt and ./Data/iris_test.txt and writes the
  trained weights to ./Models/iris_model.pth (creating ./Models if needed).
  """
  # 0. get started
  print("\nBegin Iris dataset using PyTorch 1.9 demo \n")
  T.manual_seed(1)
  np.random.seed(1)

  # 1. create Dataset and DataLoader objects
  print("Creating Iris train and test DataLoader ")

  # os.path.join instead of hard-coded "\\" so the demo runs on any OS
  train_file = os.path.join(".", "Data", "iris_train.txt")
  test_file = os.path.join(".", "Data", "iris_test.txt")

  train_ds = IrisDataset(train_file, num_rows=120)
  test_ds = IrisDataset(test_file)  # reads every row in the file

  bat_size = 4
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)
  # no test DataLoader: accuracy() builds its own item-by-item loader

  # 2. create network
  net = Net().to(device)

  # 3. train model
  max_epochs = 12
  ep_log_interval = 2
  lrn_rate = 0.05

  loss_func = T.nn.CrossEntropyLoss()  # applies softmax() internally
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  net.train()
  for epoch in range(0, max_epochs):
    epoch_loss = 0.0  # sum of per-batch losses over one full epoch

    for batch in train_ldr:
      X = batch['predictors']  # up to bat_size rows of 4 predictors
      Y = batch['species']     # already 1-D, no flatten needed
      optimizer.zero_grad()
      oupt = net(X)
      loss_obj = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_obj.item()  # accumulate
      loss_obj.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %4d   loss = %0.4f" % (epoch, epoch_loss))
  print("Done ")

  # 4. evaluate model accuracy on both splits
  print("\nComputing model accuracy")
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("Accuracy on train data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds)  # item-by-item
  print("Accuracy on test data = %0.4f" % acc_test)

  # 5. make a prediction
  print("\nPredicting species for [6.1, 3.1, 5.1, 1.1]: ")
  unk = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
  unk = T.tensor(unk, dtype=T.float32).to(device)

  with T.no_grad():
    logits = net(unk)  # values do not sum to 1.0
  probs = T.softmax(logits, dim=1)  # same device as logits already
  T.set_printoptions(precision=4)
  print(probs)

  # 6. save model (state_dict approach)
  print("\nSaving trained model state")
  model_dir = os.path.join(".", "Models")
  os.makedirs(model_dir, exist_ok=True)  # make sure the target dir exists
  fn = os.path.join(model_dir, "iris_model.pth")
  T.save(net.state_dict(), fn)

  # To reload later:
  #   saved_model = Net()
  #   saved_model.load_state_dict(T.load(fn))
  # then use saved_model to make prediction(s)

  print("\nEnd Iris demo")

if __name__ == "__main__":
  main()
This entry was posted in PyTorch. Bookmark the permalink.