PyTorch Post-Training Dynamic Quantization Example

Neural network quantization is a technique that reduces the size of a trained model and can sometimes speed up predictions made with the reduced-size model. The idea is to replace 32-bit floating-point weights and biases with 8-bit integers.
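The basic arithmetic can be sketched in a few lines of NumPy. This is a simplified illustration, not PyTorch's internal code: it assumes a symmetric, per-tensor scheme where the largest-magnitude weight maps to 127. Every other weight is divided by the same scale factor and rounded to the nearest integer. The helper names quantize_int8() and dequantize() are hypothetical.

```python
import numpy as np

def quantize_int8(w):
  # symmetric per-tensor scheme (an assumption for illustration):
  # map the largest |w| to 127 and round everything else onto that grid
  max_abs = float(np.max(np.abs(w)))
  scale = max_abs / 127.0 if max_abs > 0 else 1.0
  q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
  return q, scale

def dequantize(q, scale):
  # recover approximate 32-bit values from the 8-bit integer codes
  return q.astype(np.float32) * scale

w = np.array([-3.0, -0.5, 0.0, 0.01, 1.2, 3.0], dtype=np.float32)
q, scale = quantize_int8(w)
w_approx = dequantize(q, scale)
print(q)                              # the 8-bit integer codes
print(np.max(np.abs(w - w_approx)))   # rounding error, at most scale / 2
```

Each weight now needs one byte instead of four, plus one shared float scale per tensor. The price is a rounding error of up to half the scale factor per value.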

I encountered neural network quantization while I was exploring large language models (LLMs) and the LM Studio tool. The LM Studio tool loads LLMs from Microsoft (Orca-2), Meta/Facebook (LLaMA-2), and other sources, many of which have been reduced in size using quantization. See my post at https://jamesmccaffreyblog.com/2023/12/04/a-quick-look-at-the-lm-studio-tool-for-exploring-large-language-models/.

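Quantization is a surprisingly complicated topic, and there are many, many ways to implement it. By far the simplest approach is called post-training dynamic quantization (PTDQ). You create and train a 32-bit model as usual, then create an 8-bit quantized model from the trained 32-bit model. The weights are converted to 8-bit integers ahead of time, and the activations (the values flowing between layers) are quantized on the fly during inference, which is why the technique is called "dynamic".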
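Here is a rough sketch of what "dynamic" means for a single linear layer. The weights are quantized once, after training. The input's scale factor is computed fresh from each batch at prediction time. The multiply-adds are done on integers (accumulated in 32 bits), and the result is rescaled back to floating point. This is a simplified illustration under stated assumptions: it uses a symmetric scheme for both weights and activations, and the helpers quantize_sym() and dynamic_linear() are hypothetical. PyTorch's real kernels differ in details such as zero points and optimized integer routines.

```python
import numpy as np

def quantize_sym(t):
  # hypothetical helper: symmetric per-tensor int8 quantization
  max_abs = float(np.max(np.abs(t)))
  scale = max_abs / 127.0 if max_abs > 0 else 1.0
  q = np.clip(np.round(t / scale), -127, 127).astype(np.int8)
  return q, scale

def dynamic_linear(x, w_q, w_scale, b):
  # x: (batch, in) float32, w_q: (out, in) int8, b: (out,) float32
  x_q, x_scale = quantize_sym(x)    # activation scale computed at run time
  acc = x_q.astype(np.int32) @ w_q.astype(np.int32).T   # integer matmul
  return acc.astype(np.float32) * (x_scale * w_scale) + b  # back to float

rng = np.random.default_rng(1)
w = rng.uniform(-1.0, 1.0, size=(3, 6)).astype(np.float32)
b = np.zeros(3, dtype=np.float32)
x = rng.uniform(-1.0, 1.0, size=(2, 6)).astype(np.float32)

w_q, w_scale = quantize_sym(w)      # done once, right after training
y_float = x @ w.T + b               # ordinary 32-bit result
y_q8 = dynamic_linear(x, w_q, w_scale, b)
print(np.max(np.abs(y_float - y_q8)))  # small approximation error
```

Because the activation scale is recomputed for every input, no calibration data is needed. That is the main reason PTDQ is so simple to apply.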

PyTorch has a built-in function to do PTDQ. I coded up a demo. I used one of my standard synthetic datasets. The tab-delimited data looks like:

-1   0.27   0  1  0   0.7610   2
+1   0.19   0  0  1   0.6550   0
. . .

The fields are sex (-1 = male, +1 = female), age (divided by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), income (divided by $100,000) and political leaning (0 = conservative, 1 = moderate, 2 = liberal). The goal is to predict political leaning from sex, age, State, and income.
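As a concrete check of the encoding, the short sketch below turns one person into the 6-value input vector. The helper encode_person() and the STATES dictionary are hypothetical and are not part of the demo program. The result matches the hard-coded input the demo uses for a 30-year-old male from Oklahoma who makes $50,000.

```python
import numpy as np

# hypothetical helper mirroring the encoding described above
STATES = {"michigan": [1.0, 0.0, 0.0],
          "nebraska": [0.0, 1.0, 0.0],
          "oklahoma": [0.0, 0.0, 1.0]}

def encode_person(sex, age, state, income):
  sx = -1.0 if sex == "M" else 1.0           # male = -1, female = +1
  row = [sx, age / 100.0] + STATES[state.lower()] + [income / 100_000.0]
  return np.array([row], dtype=np.float32)   # shape (1, 6): a batch of one

x = encode_person("M", 30, "Oklahoma", 50_000)
print(x.shape)  # (1, 6)
```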

The key statements in the demo are:

import torch as T

net = Net().to(device)  # 32-bit model
# train the 32-bit model as usual
net_q8 = T.ao.quantization.quantize_dynamic(
  net,              # original trained model
  { T.nn.Linear },  # layer types to dynamically quantize
  dtype=T.qint8)    # quantized 8-bit type
# now use net_q8 as usual
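To see what quantize_dynamic() actually does, the sketch below quantizes a small stand-in model. The Sequential model here is hypothetical, not the demo's Net class. The Linear layers are swapped for dynamic quantized versions, the Tanh and LogSoftmax modules are left alone, and the original model is not modified because quantize_dynamic() returns a copy by default (its inplace parameter defaults to False). The output comparison shows the quantized model gives close, but not identical, results.

```python
import torch as T

# a small stand-in model (hypothetical), similar in shape to the demo's Net
net = T.nn.Sequential(
  T.nn.Linear(6, 10), T.nn.Tanh(),
  T.nn.Linear(10, 3), T.nn.LogSoftmax(dim=1))
net.eval()

net_q8 = T.ao.quantization.quantize_dynamic(
  net, {T.nn.Linear}, dtype=T.qint8)  # returns a quantized copy

print(net_q8)          # the Linear entries are now dynamic quantized modules
print(type(net[0]))    # the original model still has a float Linear layer

x = T.rand(4, 6)
with T.no_grad():
  diff = (net(x) - net_q8(x)).abs().max().item()
print("max abs difference = %0.6f" % diff)
```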

When I defined the 32-bit network, I used class-style activation functions (such as T.nn.Tanh) instead of functional-style calls (such as T.tanh()):

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid1_act = T.nn.Tanh()  # class-style
. . .
  def forward(self, x):
    # z = T.tanh(self.hid1(x))  # functional style
    z = self.hid1_act(self.hid1(x))  # class-style
. . .

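The idea is that for many types of quantization, the quantization code needs to know which activation functions are being used, and class-style definitions are easier for that code to analyze. For simple PTDQ, which only replaces the layer types listed in the quantize_dynamic() call (here T.nn.Linear), I could have used functional-style activations.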
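A quick way to confirm that PTDQ does not care about the activation style is to quantize a functional-style network. The sketch below uses a hypothetical class NetFunctional with the same 6-(10-10)-3 shape as the demo network. The dynamic quantized Linear layers accept and return ordinary float tensors, so T.tanh() and T.log_softmax() keep working unchanged.

```python
import torch as T

class NetFunctional(T.nn.Module):
  # hypothetical functional-style variant of the demo network
  def __init__(self):
    super().__init__()
    self.hid1 = T.nn.Linear(6, 10)
    self.hid2 = T.nn.Linear(10, 10)
    self.oupt = T.nn.Linear(10, 3)

  def forward(self, x):
    z = T.tanh(self.hid1(x))    # functional-style activation
    z = T.tanh(self.hid2(z))
    return T.log_softmax(self.oupt(z), dim=1)

net = NetFunctional().eval()
net_q8 = T.ao.quantization.quantize_dynamic(
  net, {T.nn.Linear}, dtype=T.qint8)
with T.no_grad():
  out = net_q8(T.rand(5, 6))   # float in, float out
print(out.shape)  # torch.Size([5, 3])
```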

When I trained the 32-bit model, I applied weight and bias clamping:

. . .
  loss_val.backward()
  optimizer.step()
  with T.no_grad():
    for param in net.parameters():
      param.clamp_(-3.0, 3.0)
. . .

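The idea here is to keep all the weights and biases in the same, limited range. Quantization typically maps each tensor's range of values onto the 256 available 8-bit values, so a single large weight forces a coarse step size on all the other weights in that tensor. Clamping is optional, but it usually improves the accuracy of the quantized model.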
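The effect of an outlier on the step size can be measured directly. The sketch below assumes the same symmetric per-tensor scheme as a simplification; quant_error() is a hypothetical helper. It quantizes 1,000 ordinary weights plus one large weight, with and without clamping to [-3.0, 3.0], and compares the rounding error on the ordinary weights only. Note that clamping also changes the outlier itself, which is why it is applied during training, so the rest of the network can adapt.

```python
import numpy as np

def quant_error(w):
  # symmetric per-tensor int8 round trip; returns per-value error and scale
  scale = float(np.max(np.abs(w))) / 127.0
  w_hat = np.clip(np.round(w / scale), -127, 127) * scale
  return np.abs(w - w_hat), scale

rng = np.random.default_rng(0)
w = rng.uniform(-1.0, 1.0, size=1000).astype(np.float32)
w_outlier = w.copy()
w_outlier[0] = 12.0                        # one unusually large weight
w_clamped = np.clip(w_outlier, -3.0, 3.0)  # like param.clamp_(-3.0, 3.0)

err_a, scale_a = quant_error(w_outlier)
err_b, scale_b = quant_error(w_clamped)
# compare the error on the ordinary weights only (skip index 0)
print("no clamp: scale = %0.5f  mean err = %0.5f" % (scale_a, err_a[1:].mean()))
print("clamped:  scale = %0.5f  mean err = %0.5f" % (scale_b, err_b[1:].mean()))
```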

For my demo neural network, quantization isn't really useful because the network is tiny: the 6-(10-10)-3 architecture has only 213 weights and biases. But for huge networks such as large language models, which have billions of weights, quantization can be very useful.
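The size benefit only shows up when the layers are large. The sketch below builds a hypothetical model with two 1024-by-1024 Linear layers, quantizes it, and compares the size of the two serialized state_dicts. The helper state_dict_bytes() is also hypothetical. Because each int8 weight takes one byte instead of four, the quantized version should come out at roughly a quarter of the original size. The exact byte counts depend on the PyTorch version and serialization overhead, so no specific numbers are claimed here.

```python
import io
import torch as T

def state_dict_bytes(model):
  # serialize the state_dict to an in-memory buffer and measure its size
  buf = io.BytesIO()
  T.save(model.state_dict(), buf)
  return len(buf.getvalue())

# hypothetical larger model, just to make the size difference visible
big = T.nn.Sequential(
  T.nn.Linear(1024, 1024), T.nn.Tanh(),
  T.nn.Linear(1024, 1024))
big_q8 = T.ao.quantization.quantize_dynamic(
  big, {T.nn.Linear}, dtype=T.qint8)

size_f32 = state_dict_bytes(big)
size_q8 = state_dict_bytes(big_q8)
print("float32: %d bytes   int8: %d bytes" % (size_f32, size_q8))
```

For a network as small as the demo's, the fixed serialization overhead can outweigh the savings, which is consistent with quantization not being worthwhile there.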



The old Twilight Zone TV series (1959-1964) had several memorable stories that featured shrinking / quantized people.

Left: In “The Invaders” (1961), an old woman living by herself finds a tiny alien spacecraft with tiny aliens on her roof. In the end, it turns out the aliens are actually U.S. astronauts and the old woman is the alien.

Center: In “The Fear” (1964), a woman and a policeman in an isolated cabin see footprints and other evidence of enormous 100-foot-tall aliens. It turns out there are aliens, but they’re tiny and just created the illusion of giants.

Right: In “Four O’Clock” (1962), an evil man plans to shrink his enemies (who are all good) through force of will. His plan backfires, and his pet parrot is very hungry.


Demo code.

# people_politics_quantization.py
# predict politics type from sex, age, state, income
# PyTorch 1.12.1-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex  age    state    income   politics
  # -1   0.27   0  1  0   0.7610   2
  # +1   0.19   0  0  1   0.6550   0
 
  def __init__(self, src_file):
    all_xy = np.loadtxt(src_file, usecols=range(0,7),
      delimiter="\t", comments="#", dtype=np.float32)
    tmp_x = all_xy[:,0:6]   # cols [0,6) = [0,5]
    tmp_y = all_xy[:,6]     # 1-D

    self.x_data = T.tensor(tmp_x, 
      dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y,
      dtype=T.int64).to(device)  # 1-D

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return preds, trgts  # as a Tuple

# -----------------------------------------------------------

class Net(T.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.hid1 = T.nn.Linear(6, 10)  # 6-(10-10)-3
    self.hid1_act = T.nn.Tanh()
    self.hid2 = T.nn.Linear(10, 10)
    self.hid2_act = T.nn.Tanh()
    self.oupt = T.nn.Linear(10, 3)
    self.oupt_act = T.nn.LogSoftmax(dim=1)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = self.hid1_act(self.hid1(x))
    z = self.hid2_act(self.hid2(z))
    z = self.oupt_act(self.oupt(z))  # NLLLoss() 
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  # item-by-item version
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)  # 0 1 or 2, 1D
    with T.no_grad():
      oupt = model(X)  # log-probabilities (LogSoftmax output)

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin People predict politics quantization ")
  T.manual_seed(1)
  np.random.seed(1)
  np.set_printoptions(precision=4, suppress=True,
    floatmode='fixed')
  T.set_printoptions(precision=4, sci_mode=False)

  
  # 1. create DataLoader objects
  print("\nCreating People Datasets ")

  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)    # 40 rows

  bat_size = 10
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

# -----------------------------------------------------------

  # 2. create standard 32-bit network
  print("\nCreating 6-(10-10)-3 neural network ")
  net = Net().to(device)

# -----------------------------------------------------------

  # 3. train model
  net.train()  # set mode

  max_epochs = 1000
  ep_log_interval = 200
  lrn_rate = 0.01

  loss_func = T.nn.NLLLoss()  # assumes log_softmax()
  optimizer = T.optim.SGD(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = SGD")
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  print("\nStarting training")
  for epoch in range(0, max_epochs):
    # T.manual_seed(epoch+1)  # checkpoint reproducibility
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]  # inputs
      Y = batch[1]  # correct class/label/politics

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()
      with T.no_grad():
        for param in net.parameters():
          param.clamp_(-3.0, 3.0)

    if epoch % ep_log_interval == 0:
      print("epoch = %5d  |  loss = %10.4f" % \
        (epoch, epoch_loss))

  print("Done ")
  
# -----------------------------------------------------------

  # 4. evaluate model accuracy
  net.eval()
  acc_train = accuracy(net, train_ds)  # item-by-item
  print("\nAccuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(net, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  # 5. make a prediction
  print("\nPredicting for M  30  oklahoma  $50,000: ")
  X = np.array([[-1, 0.30,  0,0,1,  0.5000]],
    dtype=np.float32)
  X = T.tensor(X, dtype=T.float32).to(device) 

  with T.no_grad():
    logits = net(X)  # log-probabilities, do not sum to 1.0
  probs = T.exp(logits)  # sum to 1.0
  probs = probs.numpy()  # numpy vector prints better
  np.set_printoptions(precision=4, suppress=True,
    floatmode='fixed')
  print(probs)

# -----------------------------------------------------------

  # QUANTIZATION
  print("\nCreating 8-bit PTDQ quantized model ")
  net_q8 = T.ao.quantization.quantize_dynamic(
    net,  # original trained model
    { T.nn.Linear },  # layers to dynamically quantize
    dtype=T.qint8)
  print("Completed ")

  net_q8.eval()
  acc_train = accuracy(net_q8, train_ds)  # item-by-item
  print("\nAccuracy on training data = %0.4f" % acc_train)
  acc_test = accuracy(net_q8, test_ds) 
  print("Accuracy on test data = %0.4f" % acc_test)

  print("\nPredicting M 30 oklahoma $50,000 quantized ")
  with T.no_grad():
    logits = net_q8(X)
  probs = T.exp(logits)
  probs = probs.numpy()
  print(probs)

  # 6. save model (state_dict approach)
  # print("\nSaving trained model state")
  # fn = ".\\Models\\people_model.pt"
  # T.save(net.state_dict(), fn)

  # saved_model = Net()  # requires class definition
  # saved_model.load_state_dict(T.load(fn))
  # use saved_model to make prediction(s)

  print("\nEnd People predict politics demo ")

if __name__ == "__main__":
  main()
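The demo's accuracy() function processes one item at a time, which is easy to follow but slow for large datasets. A vectorized alternative is sketched below. The function accuracy_batch() is hypothetical. It assumes the dataset exposes x_data and y_data tensors, as PeopleDataset does, and that the model is already in eval() mode. It should work the same way for net and for net_q8, for example accuracy_batch(net_q8, test_ds).

```python
import torch as T

def accuracy_batch(model, ds):
  # hypothetical vectorized alternative to the demo's accuracy();
  # assumes ds has x_data and y_data tensors and model is in eval() mode
  with T.no_grad():
    oupt = model(ds.x_data)          # shape (n, 3): one row per item
  preds = T.argmax(oupt, dim=1)      # predicted class for each row
  return (preds == ds.y_data).float().mean().item()

# tiny stand-in dataset and an identity "model", for illustration only
class TinyDS:
  x_data = T.tensor([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1],
                     [0.2, 0.3, 0.5], [0.0, 1.0, 0.0]])
  y_data = T.tensor([1, 0, 2, 0])

acc = accuracy_batch(lambda x: x, TinyDS())
print(acc)  # 0.75
```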

Training data. Replace space-space with tab character.

# people_train.txt
# sex (M=-1, F=1)  age  state (michigan, 
# nebraska, oklahoma) income
# politics (conservative, moderate, liberal)
#
1  0.24  1  0  0  0.2950  2
-1  0.39  0  0  1  0.5120  1
1  0.63  0  1  0  0.7580  0
-1  0.36  1  0  0  0.4450  1
1  0.27  0  1  0  0.2860  2
1  0.50  0  1  0  0.5650  1
1  0.50  0  0  1  0.5500  1
-1  0.19  0  0  1  0.3270  0
1  0.22  0  1  0  0.2770  1
-1  0.39  0  0  1  0.4710  2
1  0.34  1  0  0  0.3940  1
-1  0.22  1  0  0  0.3350  0
1  0.35  0  0  1  0.3520  2
-1  0.33  0  1  0  0.4640  1
1  0.45  0  1  0  0.5410  1
1  0.42  0  1  0  0.5070  1
-1  0.33  0  1  0  0.4680  1
1  0.25  0  0  1  0.3000  1
-1  0.31  0  1  0  0.4640  0
1  0.27  1  0  0  0.3250  2
1  0.48  1  0  0  0.5400  1
-1  0.64  0  1  0  0.7130  2
1  0.61  0  1  0  0.7240  0
1  0.54  0  0  1  0.6100  0
1  0.29  1  0  0  0.3630  0
1  0.50  0  0  1  0.5500  1
1  0.55  0  0  1  0.6250  0
1  0.40  1  0  0  0.5240  0
1  0.22  1  0  0  0.2360  2
1  0.68  0  1  0  0.7840  0
-1  0.60  1  0  0  0.7170  2
-1  0.34  0  0  1  0.4650  1
-1  0.25  0  0  1  0.3710  0
-1  0.31  0  1  0  0.4890  1
1  0.43  0  0  1  0.4800  1
1  0.58  0  1  0  0.6540  2
-1  0.55  0  1  0  0.6070  2
-1  0.43  0  1  0  0.5110  1
-1  0.43  0  0  1  0.5320  1
-1  0.21  1  0  0  0.3720  0
1  0.55  0  0  1  0.6460  0
1  0.64  0  1  0  0.7480  0
-1  0.41  1  0  0  0.5880  1
1  0.64  0  0  1  0.7270  0
-1  0.56  0  0  1  0.6660  2
1  0.31  0  0  1  0.3600  1
-1  0.65  0  0  1  0.7010  2
1  0.55  0  0  1  0.6430  0
-1  0.25  1  0  0  0.4030  0
1  0.46  0  0  1  0.5100  1
-1  0.36  1  0  0  0.5350  0
1  0.52  0  1  0  0.5810  1
1  0.61  0  0  1  0.6790  0
1  0.57  0  0  1  0.6570  0
-1  0.46  0  1  0  0.5260  1
-1  0.62  1  0  0  0.6680  2
1  0.55  0  0  1  0.6270  0
-1  0.22  0  0  1  0.2770  1
-1  0.50  1  0  0  0.6290  0
-1  0.32  0  1  0  0.4180  1
-1  0.21  0  0  1  0.3560  0
1  0.44  0  1  0  0.5200  1
1  0.46  0  1  0  0.5170  1
1  0.62  0  1  0  0.6970  0
1  0.57  0  1  0  0.6640  0
-1  0.67  0  0  1  0.7580  2
1  0.29  1  0  0  0.3430  2
1  0.53  1  0  0  0.6010  0
-1  0.44  1  0  0  0.5480  1
1  0.46  0  1  0  0.5230  1
-1  0.20  0  1  0  0.3010  1
-1  0.38  1  0  0  0.5350  1
1  0.50  0  1  0  0.5860  1
1  0.33  0  1  0  0.4250  1
-1  0.33  0  1  0  0.3930  1
1  0.26  0  1  0  0.4040  0
1  0.58  1  0  0  0.7070  0
1  0.43  0  0  1  0.4800  1
-1  0.46  1  0  0  0.6440  0
1  0.60  1  0  0  0.7170  0
-1  0.42  1  0  0  0.4890  1
-1  0.56  0  0  1  0.5640  2
-1  0.62  0  1  0  0.6630  2
-1  0.50  1  0  0  0.6480  1
1  0.47  0  0  1  0.5200  1
-1  0.67  0  1  0  0.8040  2
-1  0.40  0  0  1  0.5040  1
1  0.42  0  1  0  0.4840  1
1  0.64  1  0  0  0.7200  0
-1  0.47  1  0  0  0.5870  2
1  0.45  0  1  0  0.5280  1
-1  0.25  0  0  1  0.4090  0
1  0.38  1  0  0  0.4840  0
1  0.55  0  0  1  0.6000  1
-1  0.44  1  0  0  0.6060  1
1  0.33  1  0  0  0.4100  1
1  0.34  0  0  1  0.3900  1
1  0.27  0  1  0  0.3370  2
1  0.32  0  1  0  0.4070  1
1  0.42  0  0  1  0.4700  1
-1  0.24  0  0  1  0.4030  0
1  0.42  0  1  0  0.5030  1
1  0.25  0  0  1  0.2800  2
1  0.51  0  1  0  0.5800  1
-1  0.55  0  1  0  0.6350  2
1  0.44  1  0  0  0.4780  2
-1  0.18  1  0  0  0.3980  0
-1  0.67  0  1  0  0.7160  2
1  0.45  0  0  1  0.5000  1
1  0.48  1  0  0  0.5580  1
-1  0.25  0  1  0  0.3900  1
-1  0.67  1  0  0  0.7830  1
1  0.37  0  0  1  0.4200  1
-1  0.32  1  0  0  0.4270  1
1  0.48  1  0  0  0.5700  1
-1  0.66  0  0  1  0.7500  2
1  0.61  1  0  0  0.7000  0
-1  0.58  0  0  1  0.6890  1
1  0.19  1  0  0  0.2400  2
1  0.38  0  0  1  0.4300  1
-1  0.27  1  0  0  0.3640  1
1  0.42  1  0  0  0.4800  1
1  0.60  1  0  0  0.7130  0
-1  0.27  0  0  1  0.3480  0
1  0.29  0  1  0  0.3710  0
-1  0.43  1  0  0  0.5670  1
1  0.48  1  0  0  0.5670  1
1  0.27  0  0  1  0.2940  2
-1  0.44  1  0  0  0.5520  0
1  0.23  0  1  0  0.2630  2
-1  0.36  0  1  0  0.5300  2
1  0.64  0  0  1  0.7250  0
1  0.29  0  0  1  0.3000  2
-1  0.33  1  0  0  0.4930  1
-1  0.66  0  1  0  0.7500  2
-1  0.21  0  0  1  0.3430  0
1  0.27  1  0  0  0.3270  2
1  0.29  1  0  0  0.3180  2
-1  0.31  1  0  0  0.4860  1
1  0.36  0  0  1  0.4100  1
1  0.49  0  1  0  0.5570  1
-1  0.28  1  0  0  0.3840  0
-1  0.43  0  0  1  0.5660  1
-1  0.46  0  1  0  0.5880  1
1  0.57  1  0  0  0.6980  0
-1  0.52  0  0  1  0.5940  1
-1  0.31  0  0  1  0.4350  1
-1  0.55  1  0  0  0.6200  2
1  0.50  1  0  0  0.5640  1
1  0.48  0  1  0  0.5590  1
-1  0.22  0  0  1  0.3450  0
1  0.59  0  0  1  0.6670  0
1  0.34  1  0  0  0.4280  2
-1  0.64  1  0  0  0.7720  2
1  0.29  0  0  1  0.3350  2
-1  0.34  0  1  0  0.4320  1
-1  0.61  1  0  0  0.7500  2
1  0.64  0  0  1  0.7110  0
-1  0.29  1  0  0  0.4130  0
1  0.63  0  1  0  0.7060  0
-1  0.29  0  1  0  0.4000  0
-1  0.51  1  0  0  0.6270  1
-1  0.24  0  0  1  0.3770  0
1  0.48  0  1  0  0.5750  1
1  0.18  1  0  0  0.2740  0
1  0.18  1  0  0  0.2030  2
1  0.33  0  1  0  0.3820  2
-1  0.20  0  0  1  0.3480  0
1  0.29  0  0  1  0.3300  2
-1  0.44  0  0  1  0.6300  0
-1  0.65  0  0  1  0.8180  0
-1  0.56  1  0  0  0.6370  2
-1  0.52  0  0  1  0.5840  1
-1  0.29  0  1  0  0.4860  0
-1  0.47  0  1  0  0.5890  1
1  0.68  1  0  0  0.7260  2
1  0.31  0  0  1  0.3600  1
1  0.61  0  1  0  0.6250  2
1  0.19  0  1  0  0.2150  2
1  0.38  0  0  1  0.4300  1
-1  0.26  1  0  0  0.4230  0
1  0.61  0  1  0  0.6740  0
1  0.40  1  0  0  0.4650  1
-1  0.49  1  0  0  0.6520  1
1  0.56  1  0  0  0.6750  0
-1  0.48  0  1  0  0.6600  1
1  0.52  1  0  0  0.5630  2
-1  0.18  1  0  0  0.2980  0
-1  0.56  0  0  1  0.5930  2
-1  0.52  0  1  0  0.6440  1
-1  0.18  0  1  0  0.2860  1
-1  0.58  1  0  0  0.6620  2
-1  0.39  0  1  0  0.5510  1
-1  0.46  1  0  0  0.6290  1
-1  0.40  0  1  0  0.4620  1
-1  0.60  1  0  0  0.7270  2
1  0.36  0  1  0  0.4070  2
1  0.44  1  0  0  0.5230  1
1  0.28  1  0  0  0.3130  2
1  0.54  0  0  1  0.6260  0
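The "replace space-space with tab" step can be scripted instead of done by hand. The sketch below uses a hypothetical helper, spaces_to_tabs(). It splits each data line on any run of whitespace and joins the fields with tabs. Comment lines that start with "#" and blank lines are passed through unchanged. The result is then parsed with the same np.loadtxt() settings that PeopleDataset uses.

```python
import io
import numpy as np

def spaces_to_tabs(text):
  # hypothetical helper: tab-delimit every data line, splitting on any
  # run of whitespace; "#" comment lines and blank lines are kept as-is
  out = []
  for line in text.splitlines():
    if line.startswith("#") or line.strip() == "":
      out.append(line)
    else:
      out.append("\t".join(line.split()))
  return "\n".join(out) + "\n"

sample = ("# people_train.txt\n"
          "1  0.24  1  0  0  0.2950  2\n"
          "-1  0.39  0  0  1  0.5120  1\n")
converted = spaces_to_tabs(sample)

# the result parses with the same loadtxt() settings the demo uses
arr = np.loadtxt(io.StringIO(converted), usecols=range(0, 7),
  delimiter="\t", comments="#", dtype=np.float32)
print(arr.shape)  # (2, 7)
```

To convert a whole file, read its text, pass it through spaces_to_tabs(), and write the result to the people_train.txt or people_test.txt file that the demo loads.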

Test data. Replace space-space with tab character.

-1  0.51  1  0  0  0.6120  1
-1  0.32  0  1  0  0.4610  1
1  0.55  1  0  0  0.6270  0
1  0.25  0  0  1  0.2620  2
1  0.33  0  0  1  0.3730  2
-1  0.29  0  1  0  0.4620  0
1  0.65  1  0  0  0.7270  0
-1  0.43  0  1  0  0.5140  1
-1  0.54  0  1  0  0.6480  2
1  0.61  0  1  0  0.7270  0
1  0.52  0  1  0  0.6360  0
1  0.30  0  1  0  0.3350  2
1  0.29  1  0  0  0.3140  2
-1  0.47  0  0  1  0.5940  1
1  0.39  0  1  0  0.4780  1
1  0.47  0  0  1  0.5200  1
-1  0.49  1  0  0  0.5860  1
-1  0.63  0  0  1  0.6740  2
-1  0.30  1  0  0  0.3920  0
-1  0.61  0  0  1  0.6960  2
-1  0.47  0  0  1  0.5870  1
1  0.30  0  0  1  0.3450  2
-1  0.51  0  0  1  0.5800  1
-1  0.24  1  0  0  0.3880  1
-1  0.49  1  0  0  0.6450  1
1  0.66  0  0  1  0.7450  0
-1  0.65  1  0  0  0.7690  0
-1  0.46  0  1  0  0.5800  0
-1  0.45  0  0  1  0.5180  1
-1  0.47  1  0  0  0.6360  0
-1  0.29  1  0  0  0.4480  0
-1  0.57  0  0  1  0.6930  2
-1  0.20  1  0  0  0.2870  2
-1  0.35  1  0  0  0.4340  1
-1  0.61  0  0  1  0.6700  2
-1  0.31  0  0  1  0.3730  1
1  0.18  1  0  0  0.2080  2
1  0.26  0  0  1  0.2920  2
-1  0.28  1  0  0  0.3640  2
-1  0.59  0  0  1  0.6940  2
This entry was posted in PyTorch.