PyTorch Neural Network Distillation Using the Teacher-Student Technique

The DeepSeek R1 large language model, announced on January 20 of this year, shocked the deep learning community because it was much less expensive to create and is much less expensive to use, compared to all other current LLMs.

According to early analyses, the three key factors that contribute to the success of R1 are 1.) chain-of-thought reasoning, 2.) reinforcement learning, and 3.) model distillation. This information spurred me to take a fresh look at model distillation.

Model distillation is the process of starting with a large trained model and producing a much smaller model that has nearly the same accuracy but is faster, and therefore less expensive to use. There are several ways to perform model distillation, but one of the most common is called the teacher-student technique.

In a nutshell, the teacher model is large, and is trained using standard training data with known, correct target values. The student model is trained by using the predictions of the teacher model. Put another way, the student model learns to mimic the teacher model.
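The core idea can be shown without PyTorch. The following is a minimal plain-Python sketch, not the demo's code: `teacher()` is a hypothetical fixed function standing in for a trained large network, and the student is a two-parameter linear model. The student never sees true labels. Its training targets are the teacher's outputs, and it is trained with squared error and SGD, much as the demo's student is trained with MSELoss on the teacher's outputs.

```python
import random

def teacher(x):
    # Hypothetical stand-in for a trained "large" model. Any fixed
    # function works; the student only ever sees its outputs.
    return 3.0 * x + 1.0

def distill(n_steps=2000, lr=0.05, seed=0):
    rng = random.Random(seed)
    w, b = 0.0, 0.0                  # student model: y = w * x + b
    for _ in range(n_steps):
        x = rng.uniform(-1.0, 1.0)   # an input; no true label is needed
        target = teacher(x)          # the teacher's prediction is the target
        err = (w * x + b) - target   # gradient of 0.5 * err**2 w.r.t. the prediction
        w -= lr * err * x            # SGD update for the weight
        b -= lr * err                # SGD update for the bias
    return w, b

w, b = distill()
print("student: y = %0.4f * x + %0.4f" % (w, b))
```

Here the student has the same form as the teacher, so it can match it exactly. When the teacher is more complex than the student, the student can only approximate it, which is the usual distillation trade-off: a much smaller model in exchange for slightly lower fidelity.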

For my demo, I used one of my standard synthetic datasets. The data looks like:

 1, 0.24, 1, 0, 0, 0.29500, 2
-1, 0.39, 0, 0, 1, 0.51200, 1
 1, 0.63, 0, 1, 0, 0.75800, 0
. . .

Each line represents a person. The fields are sex (male = -1, female = +1), age (divided by 100), State (one-hot encoded: Michigan = 1 0 0, Nebraska = 0 1 0, Oklahoma = 0 0 1), income (divided by $100,000), and political leaning (conservative = 0, moderate = 1, liberal = 2). The goal is to predict a person’s political leaning from sex, age, State, and income. There are 200 training items and 40 test items.
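To make the encoding concrete, here is a small sketch that turns a raw person description into the six predictor values used by the network. The function name `encode_person()` and its argument conventions ("M"/"F", age in years, income in dollars) are hypothetical; the scaling and one-hot order follow the description above and the demo's prediction input for (M, 36, Michigan, $44,500).

```python
STATES = ["Michigan", "Nebraska", "Oklahoma"]  # one-hot order used in the data files

def encode_person(sex, age, state, income):
    # Hypothetical helper: "M" -> -1, "F" -> +1
    if sex not in ("M", "F"):
        raise ValueError("sex must be 'M' or 'F'")
    if state not in STATES:
        raise ValueError("unknown state: " + state)
    sex_val = -1.0 if sex == "M" else 1.0
    state_vals = [1.0 if state == s else 0.0 for s in STATES]  # one-hot
    # age divided by 100, income divided by $100,000
    return [sex_val, age / 100.0] + state_vals + [income / 100_000.0]

x = encode_person("M", 36, "Michigan", 44500)
print(x)
```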

The first part of the demo output is:

Creating teacher network Datasets
Creating 6-(100-100)-3 teacher network
Training teacher network

bat_size =  10
loss = NLLLoss
optimizer = SGD
max_epochs = 2000
lrn_rate = 0.005
epoch =    0   loss = 21.4943
epoch =  500   loss = 18.8513
epoch = 1000   loss = 12.6165
epoch = 1500   loss = 9.6615
Done

Computing teacher model accuracy
Teacher accuracy on training data = 0.8350
Teacher accuracy on test data = 0.6750

Predict politics for M, 36, Michigan, $44,500
tensor([[0.0855, 0.8515, 0.0629]])
1

The 6-(100-100)-3 teacher neural network has (6 * 100) + (100 * 100) + (100 * 3) + 100 + 100 + 3 = 11,103 weights and biases. The teacher scores 83.50% accuracy on the training data (167 out of 200 correct) and 67.50% accuracy on the test data (27 out of 40 correct). The teacher model predicts that a person who is (M, 36, Michigan, $44,500) has political leaning 1 (moderate) with a pseudo-probability of 0.8515.
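TeacherNet.forward() ends with log_softmax(), so the network emits log-probabilities (the form that NLLLoss expects), and the demo applies T.exp() to turn them back into the pseudo-probabilities it prints. A plain-Python sketch of that round trip, using hypothetical output-layer values rather than numbers from the demo run:

```python
import math

def log_softmax(z):
    # Numerically stable log-softmax: subtract the max before exponentiating
    m = max(z)
    log_sum = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_sum for v in z]

def argmax(v):
    return max(range(len(v)), key=lambda i: v[i])

z = [0.2, 2.5, -0.1]                       # hypothetical raw output-layer values
log_probs = log_softmax(z)                 # what TeacherNet.forward() returns
probs = [math.exp(v) for v in log_probs]   # what T.exp() recovers for display
print([round(p, 4) for p in probs], argmax(log_probs))
```

Because log_softmax() preserves the ordering of its inputs, argmax() gives the same predicted class whether it is applied to the raw values, the log-probabilities, or the probabilities.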

The next part of the demo output uses the teacher model to train a small 6-8-3 student neural network:

Creating 6-8-3 student NN
Training student network using teacher

bat_size =  10
loss = MSE
optimizer = SGD
max_epochs = 1200
lrn_rate = 0.010
epoch =    0   loss = 186.2557
epoch =  300   loss = 28.9133
epoch =  600   loss = 15.7231
epoch =  900   loss = 8.4816
Done

Computing student model accuracy
Student accuracy on training data = 0.7850
Student accuracy on test data = 0.6500

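The 6-8-3 student neural network has just (6 * 8) + (8 * 3) + 8 + 3 = 83 weights and biases, less than 1% of the teacher's 11,103. Even so, the student scores nearly as well as the teacher: 78.50% accuracy on the training data and 65.00% on the test data, compared to 83.50% and 67.50% for the teacher.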
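The weight-and-bias counts for both networks follow one rule: a fully connected layer from a nodes to b nodes has a * b weights plus b biases. A short sketch (the function name `count_params()` is hypothetical) that reproduces both totals:

```python
def count_params(layer_sizes):
    # Sum a*b weights plus b biases over each consecutive pair of layers
    return sum(a * b + b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))

teacher_n = count_params([6, 100, 100, 3])   # 6-(100-100)-3 teacher
student_n = count_params([6, 8, 3])          # 6-8-3 student
print(teacher_n, student_n, "%0.4f" % (student_n / teacher_n))
```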

The next part of the demo fine-tunes the student network. The training data is fed to the student network, and the items it predicts incorrectly (21.50% of 200 items = 43 items) are saved to a mistakes.txt file. That data, with its true class labels, is then used to tune the student network for 8 more epochs:
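The selection step amounts to keeping only the items whose argmax prediction differs from the true label, which is what create_error_data() does before writing them to file. A plain-Python sketch; `find_mistakes()`, `toy_predict()`, and the four data items are hypothetical, made up only to exercise the logic:

```python
def argmax(v):
    return max(range(len(v)), key=lambda i: v[i])

def find_mistakes(xs, ys, predict):
    # Keep only the (x, y) items the model gets wrong; these become the tuning set
    return [(x, y) for x, y in zip(xs, ys) if argmax(predict(x)) != y]

def toy_predict(x):
    # Hypothetical "model": pretend the inputs are already class scores
    return list(x)

xs = [[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4], [0.6, 0.2, 0.2]]
ys = [0, 1, 1, 2]
mistakes = find_mistakes(xs, ys, toy_predict)
print(len(mistakes), "of", len(xs), "items mispredicted")
```

Tuning on only the mistakes concentrates the extra training on the items the student currently gets wrong, at the risk of overfitting to them if it runs for too many epochs, which is presumably why the demo uses only a few epochs and a small learning rate.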

Tuning the student network
epoch =    0   loss = 8.7479
epoch =    2   loss = 8.0818
epoch =    4   loss = 7.2917
epoch =    6   loss = 7.4056
Done

Computing tuned student model accuracy
Student accuracy on training data = 0.8500
Student accuracy on test data = 0.6750

Predict politics for M, 36, Michigan, $44,500
tensor([[0.1421, 0.7870, 0.0709]])
1

End Teacher-Student NN demo

The tuned student neural network scores better than the teacher network on the training data (85.00% vs. 83.50% accuracy) and scores the same as the teacher network on the test data (67.50% accuracy).

Good fun.



I’ve always been fascinated by 1930s-era electro-mechanical pinball machines. Many of them were gambling games somewhat similar to slot machines.

Left: The “Tycoon” game from Mills Novelty Company was produced in 1936. You insert one or more 5-cent nickels into slots numbered 1 to 7. Then you press a lever that shoots a ball to the top of the playing field. The ball falls through a lane that indicates the payout multiplier value (usually 2 but as high as 30) and then falls into a receiving lane. For example, if you put two nickels in slot #3, and the ball rolls through multiplier 4 and then lands in receiver #3, you win 4 * 2 nickels = 40 cents. There are no flippers like you’d see in later pinball machines.

Right: An advertising brochure for the “Rocket” game from Bally Manufacturing, produced in 1933. A player gets 10 balls for 5 cents. The scoring holes have trap doors that close as soon as a ball enters, and then the corresponding dial near the bottom of the playfield advances. There’s an early circular tilt-detection device in the lower left.


Demo program:

# people_teacher_student.py
# predict politics from sex, age, state, income
# teacher-student technique to distill to smaller network
# PyTorch 2.3.1-CPU Anaconda3-2023.09-0  Python 3.11.5
# Windows 10/11 

import numpy as np
import torch as T
device = T.device('cpu')  # apply to Tensor or Module

# -----------------------------------------------------------

class PeopleDataset(T.utils.data.Dataset):
  # sex age   state   income  politics
  # -1, 0.27, 0,1,0,  0.76100,  2
  # +1, 0.19, 0,0,1,  0.65500,  0
  # sex: -1 = male, +1 = female
  # state: Michigan, Nebraska, Oklahoma
  # politics: conservative, moderate, liberal

  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, usecols=range(0,6),
      delimiter=",", dtype=np.float32)
    tmp_y = np.loadtxt(src_file, usecols=6,
      delimiter=",", dtype=np.int64)   # 1d required

    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)
    self.y_data = T.tensor(tmp_y, dtype=T.int64).to(device) 

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx]
    trgts = self.y_data[idx] 
    return (preds, trgts)  # as Tuple

# -----------------------------------------------------------

class TeacherNet(T.nn.Module):
  def __init__(self):
    super(TeacherNet, self).__init__()
    self.hid1 = T.nn.Linear(6, 100)  # 6-(100-100)-3
    self.hid2 = T.nn.Linear(100, 100)
    self.oupt = T.nn.Linear(100, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.hid2.weight)
    T.nn.init.zeros_(self.hid2.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = T.tanh(self.hid2(z))
    z = T.log_softmax(self.oupt(z), dim=1)  # NLLLoss() 
    return z

# -----------------------------------------------------------

class StudentNet(T.nn.Module):
  def __init__(self):
    super(StudentNet, self).__init__()
    self.hid1 = T.nn.Linear(6, 8)  # 6-8-3
    self.oupt = T.nn.Linear(8, 3)

    T.nn.init.xavier_uniform_(self.hid1.weight)
    T.nn.init.zeros_(self.hid1.bias)
    T.nn.init.xavier_uniform_(self.oupt.weight)
    T.nn.init.zeros_(self.oupt.bias)

  def forward(self, x):
    z = T.tanh(self.hid1(x))
    z = self.oupt(z)  # no activation for MSELoss() 
    return z

# -----------------------------------------------------------

def accuracy(model, ds):
  # assumes model.eval()
  n_correct = 0; n_wrong = 0
  for i in range(len(ds)):
    X = ds[i][0].reshape(1,-1)  # make it a batch
    Y = ds[i][1].reshape(1)     # 0 1 or 2
    with T.no_grad():
      oupt = model(X)  # raw outputs (log-probs for teacher)

    big_idx = T.argmax(oupt)  # 0 or 1 or 2
    if big_idx == Y:
      n_correct += 1
    else:
      n_wrong += 1

  acc = (n_correct * 1.0) / (n_correct + n_wrong)
  return acc

# -----------------------------------------------------------

def create_error_data(model, ds, fn):
  # write each incorrectly predicted item to file fn, using
  # the same comma-delimited format as the source data files
  with open(fn, 'w') as f:
    for i in range(len(ds)):
      X = ds[i][0].reshape(1,-1)  # make it a batch
      Y = ds[i][1].reshape(1)     # 0 1 or 2
      with T.no_grad():
        oupt = model(X)  # raw outputs

      big_idx = T.argmax(oupt)  # 0 or 1 or 2
      if big_idx != Y:  # incorrect
        x = X.numpy().flatten()
        y = Y.item()
        for j in range(len(x)):  # j, not i, to avoid shadowing
          f.write(str(x[j]) + ", ")
        f.write(str(y) + "\n")

# -----------------------------------------------------------

def train(model, teacher, ds, bs, lr, opt, lf, me, le):
  # teacher ('None' = train on true labels), dataset, bat_size,
  # lrn_rate, optimizer, loss_function, max_epochs, log interval

  train_ldr = T.utils.data.DataLoader(ds, batch_size=bs,
    shuffle=True)

  if opt == 'SGD':
    optimizer = T.optim.SGD(model.parameters(), lr=lr)
  elif opt == 'Adam':
    optimizer = T.optim.Adam(model.parameters(), lr=lr)

  if lf == 'MSE': loss_func = T.nn.MSELoss()
  elif lf == 'NLLLoss': loss_func = T.nn.NLLLoss()

  # print("\nStart training")
  for epoch in range(0, me):
    epoch_loss = 0  # for one full epoch
    for (batch_idx, batch) in enumerate(train_ldr):
      X = batch[0]
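      if teacher == 'None': Y = batch[1]  # true class labels
      else:
        with T.no_grad():  # teacher is fixed; no gradients needed
          Y = teacher(X)   # teacher log-probs are the targets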
      optimizer.zero_grad()
      oupt = model(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % le == 0:
      print("epoch = %4d   loss = %0.4f" \
        % (epoch, epoch_loss))
  # print("Done ")

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin Teacher-Student NN distillation demo ")
  T.manual_seed(0)
  np.random.seed(0)
  
  # 1. create datasets objects
  print("\nCreating teacher network Datasets ")

  train_file = ".\\Data\\people_train.txt"
  train_ds = PeopleDataset(train_file)  # 200 rows

  test_file = ".\\Data\\people_test.txt"
  test_ds = PeopleDataset(test_file)  # 40 rows

  # 2. create network
  print("\nCreating 6-(100-100)-3 teacher network ")
  teacher = TeacherNet().to(device)
  teacher.train()  # set mode

  # 3. train the teacher NN
  bat_size = 10
  max_epochs = 2000
  ep_log_interval = 500
  lrn_rate = 0.005
  opt = 'SGD'
  loss_func = 'NLLLoss'

  print("\nTraining teacher network ")
  print("\nbat_size = %3d " % bat_size)
  print("loss = " + loss_func)
  print("optimizer = " + opt)
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  train(teacher, 'None', train_ds, bat_size, lrn_rate,
    opt, loss_func, max_epochs, ep_log_interval)
  print("Done ")

  # 4. evaluate teacher model accuracy
  print("\nComputing teacher model accuracy")
  teacher.eval()
  acc_train = accuracy(teacher, train_ds)  # item-by-item
  print("Teacher accuracy on training data = %0.4f" \
    % acc_train)
  acc_test = accuracy(teacher, test_ds)  # item-by-item
  print("Teacher accuracy on test data = %0.4f" \
    % acc_test)

  # 4b. use teacher model
  print("\nPredict politics for M, 36, Michigan, $44,500 ")
  x = np.array([[-1, 0.36, 1,0,0, 0.44500]], dtype=np.float32)
  x = T.tensor(x, dtype=T.float32)
  with T.no_grad():
    pred_y_logits = teacher(x)
  print(T.exp(pred_y_logits))
  pred_y = T.argmax(pred_y_logits).item()
  print(pred_y)

  # 5. create Student NN
  print("\nCreating 6-8-3 student NN")
  student = StudentNet().to(device)
  student.train()  # set mode

  # 6. train the student NN
  bat_size = 10
  max_epochs = 1200
  ep_log_interval = 300
  lrn_rate = 0.01
  opt = 'SGD'
  loss_func = 'MSE'

  print("\nTraining student network using teacher ")
  print("\nbat_size = %3d " % bat_size)
  print("loss = " + loss_func)
  print("optimizer = " + opt)
  print("max_epochs = %3d " % max_epochs)
  print("lrn_rate = %0.3f " % lrn_rate)

  train(student, teacher, train_ds, bat_size, lrn_rate,
    opt, loss_func, max_epochs, ep_log_interval)
  print("Done ")

  # 7. evaluate student model accuracy
  print("\nComputing student model accuracy")
  student.eval()
  acc_train = accuracy(student, train_ds)  # item-by-item
  print("Student accuracy on training data = %0.4f" \
    % acc_train)
  acc_test = accuracy(student, test_ds)  # item-by-item
  print("Student accuracy on test data = %0.4f" \
    % acc_test)

  # 8. fine-tune student model
  print("\nTuning the student network ")
  mistakes_file = ".\\Data\\mistakes.txt"
  create_error_data(student, train_ds, mistakes_file)
  tune_ds = PeopleDataset(mistakes_file)

  bat_size = 10
  max_epochs = 8
  ep_log_interval = 2
  lrn_rate = 0.001
  opt = 'SGD'
  loss_func = 'NLLLoss'  # student outputs approximate log-probs
  train(student, 'None', tune_ds, bat_size, lrn_rate,
    opt, loss_func, max_epochs, ep_log_interval)
  print("Done ")

  # 9. evaluate tuned student model accuracy
  print("\nComputing tuned student model accuracy")
  student.eval()
  acc_train = accuracy(student, train_ds)  # item-by-item
  print("Student accuracy on training data = %0.4f" \
    % acc_train)
  acc_test = accuracy(student, test_ds)  # item-by-item
  print("Student accuracy on test data = %0.4f" \
    % acc_test)

  # 9b. use tuned student model
  print("\nPredict politics for M, 36, Michigan, $44,500 ")
  x = np.array([[-1, 0.36, 1,0,0, 0.44500]], dtype=np.float32)
  x = T.tensor(x, dtype=T.float32)
  with T.no_grad():
    pred_y_logits = student(x)
  print(T.nn.functional.softmax(pred_y_logits, dim=1))
  pred_y = T.argmax(pred_y_logits).item()
  print(pred_y)

  print("\nEnd Teacher-Student NN demo")

if __name__ == "__main__":
  main()
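At step 9b the demo applies softmax() to the tuned student's outputs, even though the student's final layer has no activation. The student was trained with MSE to reproduce the teacher's log_softmax() outputs, and softmax() is unchanged when the same constant is added to every element. Since log_softmax(z) is just z minus log(sum(exp(z))), softmax(log_softmax(z)) equals softmax(z): a student that matched the teacher's log-probabilities exactly would yield exactly the teacher's pseudo-probabilities. A quick plain-Python check with hypothetical values:

```python
import math

def softmax(z):
    m = max(z)                               # subtract max for stability
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def log_softmax(z):
    m = max(z)
    log_sum = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - log_sum for v in z]

z = [1.3, -0.4, 0.8]                    # hypothetical teacher output-layer values
teacher_log_probs = log_softmax(z)      # the targets the student learns via MSE
recovered = softmax(teacher_log_probs)  # softmax of (ideal) student outputs
print([round(p, 4) for p in recovered])
print([round(p, 4) for p in softmax(z)])
```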

Training data:

# people_train.txt
# sex (M = -1, F = 1)  age,  state (Michigan, 
# Nebraska, Oklahoma), income
# politics (conservative, moderate, liberal)
#
 1, 0.24, 1, 0, 0, 0.29500, 2
-1, 0.39, 0, 0, 1, 0.51200, 1
 1, 0.63, 0, 1, 0, 0.75800, 0
-1, 0.36, 1, 0, 0, 0.44500, 1
 1, 0.27, 0, 1, 0, 0.28600, 2
 1, 0.50, 0, 1, 0, 0.56500, 1
 1, 0.50, 0, 0, 1, 0.55000, 1
-1, 0.19, 0, 0, 1, 0.32700, 0
 1, 0.22, 0, 1, 0, 0.27700, 1
-1, 0.39, 0, 0, 1, 0.47100, 2
 1, 0.34, 1, 0, 0, 0.39400, 1
-1, 0.22, 1, 0, 0, 0.33500, 0
 1, 0.35, 0, 0, 1, 0.35200, 2
-1, 0.33, 0, 1, 0, 0.46400, 1
 1, 0.45, 0, 1, 0, 0.54100, 1
 1, 0.42, 0, 1, 0, 0.50700, 1
-1, 0.33, 0, 1, 0, 0.46800, 1
 1, 0.25, 0, 0, 1, 0.30000, 1
-1, 0.31, 0, 1, 0, 0.46400, 0
 1, 0.27, 1, 0, 0, 0.32500, 2
 1, 0.48, 1, 0, 0, 0.54000, 1
-1, 0.64, 0, 1, 0, 0.71300, 2
 1, 0.61, 0, 1, 0, 0.72400, 0
 1, 0.54, 0, 0, 1, 0.61000, 0
 1, 0.29, 1, 0, 0, 0.36300, 0
 1, 0.50, 0, 0, 1, 0.55000, 1
 1, 0.55, 0, 0, 1, 0.62500, 0
 1, 0.40, 1, 0, 0, 0.52400, 0
 1, 0.22, 1, 0, 0, 0.23600, 2
 1, 0.68, 0, 1, 0, 0.78400, 0
-1, 0.60, 1, 0, 0, 0.71700, 2
-1, 0.34, 0, 0, 1, 0.46500, 1
-1, 0.25, 0, 0, 1, 0.37100, 0
-1, 0.31, 0, 1, 0, 0.48900, 1
 1, 0.43, 0, 0, 1, 0.48000, 1
 1, 0.58, 0, 1, 0, 0.65400, 2
-1, 0.55, 0, 1, 0, 0.60700, 2
-1, 0.43, 0, 1, 0, 0.51100, 1
-1, 0.43, 0, 0, 1, 0.53200, 1
-1, 0.21, 1, 0, 0, 0.37200, 0
 1, 0.55, 0, 0, 1, 0.64600, 0
 1, 0.64, 0, 1, 0, 0.74800, 0
-1, 0.41, 1, 0, 0, 0.58800, 1
 1, 0.64, 0, 0, 1, 0.72700, 0
-1, 0.56, 0, 0, 1, 0.66600, 2
 1, 0.31, 0, 0, 1, 0.36000, 1
-1, 0.65, 0, 0, 1, 0.70100, 2
 1, 0.55, 0, 0, 1, 0.64300, 0
-1, 0.25, 1, 0, 0, 0.40300, 0
 1, 0.46, 0, 0, 1, 0.51000, 1
-1, 0.36, 1, 0, 0, 0.53500, 0
 1, 0.52, 0, 1, 0, 0.58100, 1
 1, 0.61, 0, 0, 1, 0.67900, 0
 1, 0.57, 0, 0, 1, 0.65700, 0
-1, 0.46, 0, 1, 0, 0.52600, 1
-1, 0.62, 1, 0, 0, 0.66800, 2
 1, 0.55, 0, 0, 1, 0.62700, 0
-1, 0.22, 0, 0, 1, 0.27700, 1
-1, 0.50, 1, 0, 0, 0.62900, 0
-1, 0.32, 0, 1, 0, 0.41800, 1
-1, 0.21, 0, 0, 1, 0.35600, 0
 1, 0.44, 0, 1, 0, 0.52000, 1
 1, 0.46, 0, 1, 0, 0.51700, 1
 1, 0.62, 0, 1, 0, 0.69700, 0
 1, 0.57, 0, 1, 0, 0.66400, 0
-1, 0.67, 0, 0, 1, 0.75800, 2
 1, 0.29, 1, 0, 0, 0.34300, 2
 1, 0.53, 1, 0, 0, 0.60100, 0
-1, 0.44, 1, 0, 0, 0.54800, 1
 1, 0.46, 0, 1, 0, 0.52300, 1
-1, 0.20, 0, 1, 0, 0.30100, 1
-1, 0.38, 1, 0, 0, 0.53500, 1
 1, 0.50, 0, 1, 0, 0.58600, 1
 1, 0.33, 0, 1, 0, 0.42500, 1
-1, 0.33, 0, 1, 0, 0.39300, 1
 1, 0.26, 0, 1, 0, 0.40400, 0
 1, 0.58, 1, 0, 0, 0.70700, 0
 1, 0.43, 0, 0, 1, 0.48000, 1
-1, 0.46, 1, 0, 0, 0.64400, 0
 1, 0.60, 1, 0, 0, 0.71700, 0
-1, 0.42, 1, 0, 0, 0.48900, 1
-1, 0.56, 0, 0, 1, 0.56400, 2
-1, 0.62, 0, 1, 0, 0.66300, 2
-1, 0.50, 1, 0, 0, 0.64800, 1
 1, 0.47, 0, 0, 1, 0.52000, 1
-1, 0.67, 0, 1, 0, 0.80400, 2
-1, 0.40, 0, 0, 1, 0.50400, 1
 1, 0.42, 0, 1, 0, 0.48400, 1
 1, 0.64, 1, 0, 0, 0.72000, 0
-1, 0.47, 1, 0, 0, 0.58700, 2
 1, 0.45, 0, 1, 0, 0.52800, 1
-1, 0.25, 0, 0, 1, 0.40900, 0
 1, 0.38, 1, 0, 0, 0.48400, 0
 1, 0.55, 0, 0, 1, 0.60000, 1
-1, 0.44, 1, 0, 0, 0.60600, 1
 1, 0.33, 1, 0, 0, 0.41000, 1
 1, 0.34, 0, 0, 1, 0.39000, 1
 1, 0.27, 0, 1, 0, 0.33700, 2
 1, 0.32, 0, 1, 0, 0.40700, 1
 1, 0.42, 0, 0, 1, 0.47000, 1
-1, 0.24, 0, 0, 1, 0.40300, 0
 1, 0.42, 0, 1, 0, 0.50300, 1
 1, 0.25, 0, 0, 1, 0.28000, 2
 1, 0.51, 0, 1, 0, 0.58000, 1
-1, 0.55, 0, 1, 0, 0.63500, 2
 1, 0.44, 1, 0, 0, 0.47800, 2
-1, 0.18, 1, 0, 0, 0.39800, 0
-1, 0.67, 0, 1, 0, 0.71600, 2
 1, 0.45, 0, 0, 1, 0.50000, 1
 1, 0.48, 1, 0, 0, 0.55800, 1
-1, 0.25, 0, 1, 0, 0.39000, 1
-1, 0.67, 1, 0, 0, 0.78300, 1
 1, 0.37, 0, 0, 1, 0.42000, 1
-1, 0.32, 1, 0, 0, 0.42700, 1
 1, 0.48, 1, 0, 0, 0.57000, 1
-1, 0.66, 0, 0, 1, 0.75000, 2
 1, 0.61, 1, 0, 0, 0.70000, 0
-1, 0.58, 0, 0, 1, 0.68900, 1
 1, 0.19, 1, 0, 0, 0.24000, 2
 1, 0.38, 0, 0, 1, 0.43000, 1
-1, 0.27, 1, 0, 0, 0.36400, 1
 1, 0.42, 1, 0, 0, 0.48000, 1
 1, 0.60, 1, 0, 0, 0.71300, 0
-1, 0.27, 0, 0, 1, 0.34800, 0
 1, 0.29, 0, 1, 0, 0.37100, 0
-1, 0.43, 1, 0, 0, 0.56700, 1
 1, 0.48, 1, 0, 0, 0.56700, 1
 1, 0.27, 0, 0, 1, 0.29400, 2
-1, 0.44, 1, 0, 0, 0.55200, 0
 1, 0.23, 0, 1, 0, 0.26300, 2
-1, 0.36, 0, 1, 0, 0.53000, 2
 1, 0.64, 0, 0, 1, 0.72500, 0
 1, 0.29, 0, 0, 1, 0.30000, 2
-1, 0.33, 1, 0, 0, 0.49300, 1
-1, 0.66, 0, 1, 0, 0.75000, 2
-1, 0.21, 0, 0, 1, 0.34300, 0
 1, 0.27, 1, 0, 0, 0.32700, 2
 1, 0.29, 1, 0, 0, 0.31800, 2
-1, 0.31, 1, 0, 0, 0.48600, 1
 1, 0.36, 0, 0, 1, 0.41000, 1
 1, 0.49, 0, 1, 0, 0.55700, 1
-1, 0.28, 1, 0, 0, 0.38400, 0
-1, 0.43, 0, 0, 1, 0.56600, 1
-1, 0.46, 0, 1, 0, 0.58800, 1
 1, 0.57, 1, 0, 0, 0.69800, 0
-1, 0.52, 0, 0, 1, 0.59400, 1
-1, 0.31, 0, 0, 1, 0.43500, 1
-1, 0.55, 1, 0, 0, 0.62000, 2
 1, 0.50, 1, 0, 0, 0.56400, 1
 1, 0.48, 0, 1, 0, 0.55900, 1
-1, 0.22, 0, 0, 1, 0.34500, 0
 1, 0.59, 0, 0, 1, 0.66700, 0
 1, 0.34, 1, 0, 0, 0.42800, 2
-1, 0.64, 1, 0, 0, 0.77200, 2
 1, 0.29, 0, 0, 1, 0.33500, 2
-1, 0.34, 0, 1, 0, 0.43200, 1
-1, 0.61, 1, 0, 0, 0.75000, 2
 1, 0.64, 0, 0, 1, 0.71100, 0
-1, 0.29, 1, 0, 0, 0.41300, 0
 1, 0.63, 0, 1, 0, 0.70600, 0
-1, 0.29, 0, 1, 0, 0.40000, 0
-1, 0.51, 1, 0, 0, 0.62700, 1
-1, 0.24, 0, 0, 1, 0.37700, 0
 1, 0.48, 0, 1, 0, 0.57500, 1
 1, 0.18, 1, 0, 0, 0.27400, 0
 1, 0.18, 1, 0, 0, 0.20300, 2
 1, 0.33, 0, 1, 0, 0.38200, 2
-1, 0.20, 0, 0, 1, 0.34800, 0
 1, 0.29, 0, 0, 1, 0.33000, 2
-1, 0.44, 0, 0, 1, 0.63000, 0
-1, 0.65, 0, 0, 1, 0.81800, 0
-1, 0.56, 1, 0, 0, 0.63700, 2
-1, 0.52, 0, 0, 1, 0.58400, 1
-1, 0.29, 0, 1, 0, 0.48600, 0
-1, 0.47, 0, 1, 0, 0.58900, 1
 1, 0.68, 1, 0, 0, 0.72600, 2
 1, 0.31, 0, 0, 1, 0.36000, 1
 1, 0.61, 0, 1, 0, 0.62500, 2
 1, 0.19, 0, 1, 0, 0.21500, 2
 1, 0.38, 0, 0, 1, 0.43000, 1
-1, 0.26, 1, 0, 0, 0.42300, 0
 1, 0.61, 0, 1, 0, 0.67400, 0
 1, 0.40, 1, 0, 0, 0.46500, 1
-1, 0.49, 1, 0, 0, 0.65200, 1
 1, 0.56, 1, 0, 0, 0.67500, 0
-1, 0.48, 0, 1, 0, 0.66000, 1
 1, 0.52, 1, 0, 0, 0.56300, 2
-1, 0.18, 1, 0, 0, 0.29800, 0
-1, 0.56, 0, 0, 1, 0.59300, 2
-1, 0.52, 0, 1, 0, 0.64400, 1
-1, 0.18, 0, 1, 0, 0.28600, 1
-1, 0.58, 1, 0, 0, 0.66200, 2
-1, 0.39, 0, 1, 0, 0.55100, 1
-1, 0.46, 1, 0, 0, 0.62900, 1
-1, 0.40, 0, 1, 0, 0.46200, 1
-1, 0.60, 1, 0, 0, 0.72700, 2
 1, 0.36, 0, 1, 0, 0.40700, 2
 1, 0.44, 1, 0, 0, 0.52300, 1
 1, 0.28, 1, 0, 0, 0.31300, 2
 1, 0.54, 0, 0, 1, 0.62600, 0

Test data:

# people_test.txt
#
-1, 0.51, 1, 0, 0, 0.61200, 1
-1, 0.32, 0, 1, 0, 0.46100, 1
 1, 0.55, 1, 0, 0, 0.62700, 0
 1, 0.25, 0, 0, 1, 0.26200, 2
 1, 0.33, 0, 0, 1, 0.37300, 2
-1, 0.29, 0, 1, 0, 0.46200, 0
 1, 0.65, 1, 0, 0, 0.72700, 0
-1, 0.43, 0, 1, 0, 0.51400, 1
-1, 0.54, 0, 1, 0, 0.64800, 2
 1, 0.61, 0, 1, 0, 0.72700, 0
 1, 0.52, 0, 1, 0, 0.63600, 0
 1, 0.30, 0, 1, 0, 0.33500, 2
 1, 0.29, 1, 0, 0, 0.31400, 2
-1, 0.47, 0, 0, 1, 0.59400, 1
 1, 0.39, 0, 1, 0, 0.47800, 1
 1, 0.47, 0, 0, 1, 0.52000, 1
-1, 0.49, 1, 0, 0, 0.58600, 1
-1, 0.63, 0, 0, 1, 0.67400, 2
-1, 0.30, 1, 0, 0, 0.39200, 0
-1, 0.61, 0, 0, 1, 0.69600, 2
-1, 0.47, 0, 0, 1, 0.58700, 1
 1, 0.30, 0, 0, 1, 0.34500, 2
-1, 0.51, 0, 0, 1, 0.58000, 1
-1, 0.24, 1, 0, 0, 0.38800, 1
-1, 0.49, 1, 0, 0, 0.64500, 1
 1, 0.66, 0, 0, 1, 0.74500, 0
-1, 0.65, 1, 0, 0, 0.76900, 0
-1, 0.46, 0, 1, 0, 0.58000, 0
-1, 0.45, 0, 0, 1, 0.51800, 1
-1, 0.47, 1, 0, 0, 0.63600, 0
-1, 0.29, 1, 0, 0, 0.44800, 0
-1, 0.57, 0, 0, 1, 0.69300, 2
-1, 0.20, 1, 0, 0, 0.28700, 2
-1, 0.35, 1, 0, 0, 0.43400, 1
-1, 0.61, 0, 0, 1, 0.67000, 2
-1, 0.31, 0, 0, 1, 0.37300, 1
 1, 0.18, 1, 0, 0, 0.20800, 2
 1, 0.26, 0, 0, 1, 0.29200, 2
-1, 0.28, 1, 0, 0, 0.36400, 2
-1, 0.59, 0, 0, 1, 0.69400, 2

3 Responses to PyTorch Neural Network Distillation Using the Teacher-Student Technique

  1. deepdark103 says:

    I’m astounded that the student network would show higher accuracy on the original data than the teacher network does. How does that happen?

    • There are a few reasons why the Student network scores slightly better than the Teacher network. First, the synthetic dataset is very small, and so anything can happen. Second, the Student network had access to the original source training data used by the Teacher, so the fine-tuning process could concentrate on just the data items that were predicted incorrectly, increasing the Student accuracy. Third, the 6-(100-100)-3 architecture of the Teacher is too big for the small 200-item dataset, and the 6-8-3 architecture of the Student is closer to the architecture of the NN that was used to generate the synthetic data (I think it was 6-10-3, but I don’t remember for sure).

  2. Thorsten Kleppe says:

    The world is on fire!

    Never before has my Twitter been flooded with just one term: DeepSeek

    The DeepSeek app beamed its way to number 1 in the US and China. Tech stocks crashed, even NVIDIA, because of the new approach to performing calculations more efficiently. From the outside, Apple seems far away from having its own AI, but it profited easily. The assumption is that the hardware is simply very well suited to the new possibilities.

    Every major competitor has made a statement. It was everywhere in the media. Even POTUS reacted.

    I think the best thing is that we not only get the models, but that they really shared their knowledge.

    Here is something from my bubble:

    this is one of the most comprehensive articles on DeepSeek R1 by @i_amanchadha, it covers:

    Mixture of Experts (MoE)
    Multihead Latent Attention (MLA)
    Multi-Token Prediction (MTP)
    Group Relative Policy Optimization (GRPO)
    emergent reasoning behaviors
    x.com/Hesamation/status/1884532647054987490

    Lots of hot takes on whether it’s possible that DeepSeek made training 45x more efficient, but @doodlestein wrote a very clear explanation of how they did it. Once someone breaks it down, it’s not hard to understand. Rough summary:

    • Use 8 bit instead of 32 bit floating point numbers, which gives massive memory savings
    • Compress the key-value indices which eat up much of the VRAM; they get 93% compression ratios
    • Do multi-token prediction instead of single-token prediction which effectively doubles inference speed
    • Mixture of Experts model decomposes a big model into small models that can run on consumer-grade GPUs
      x.com/snowmaker/status/1883628838070149244

    DeepSeek V3 is being studied like crazy.
    The biggest insight is that Chinese engineers are writing the GPU routines in a low-level assembly language, bypassing Nvidia's CUDA compiler.
    This has enabled them to do micro optimizations not possible otherwise.
    https://x.com/Perpetualmaniac/status/1884542884134944962

    We reproduced DeepSeek R1-Zero in the CountDown game, and it just works
    Through RL, the 3B base LM develops self-verification and search abilities all on its own
    You can experience the Ahah moment yourself for < $30
    Code: github.com/Jiayi-Pan/TinyZero
    Here’s what we learned 🧵
    x.com/jiayi_pirate/status/1882839370505621655

    Combining DeepSeek-R1 with Anthropic Claude Sonnet 3.5 sets a new coding record on Aider benchmark.
    x.com/ZainHasan6/status/1883008844726702526

    There was so much overwhelming input and so much more to say. One recipe to keep in mind: if you take just the reasoning part of R1 and feed it into Sonnet, Sonnet can predict the best possible outputs. But now there are even more models with reasoning and even more possibilities. It remains exciting at this new level.

    Professor Loviscach has tested all the major LLMs with his exams. R1 reached the A-grade range.
    j3l7h.de/blog/2025-01-29_21_56_DeepSeek%20R1%20und%20meine%20Klausuren

    Perhaps even the best student on his best day should soon no longer do without AI.

    The angle goes up in a crazy way!
