Why Distillation Works

  1. It effectively reweights the dataset toward examples the teacher model is confident on, down-weighting data points that are too hard or impossible to fit, which improves optimization.
  2. A large model is more sample-efficient, meaning it finds shortcut circuits more easily. Online learning from the teacher can help the small model generalize.

Forward KL and Reverse KL

The KL divergence between two distributions $P$ and $Q$ is defined as: $$ D_{\mathrm{KL}}(P \,\|\, Q) = \mathbb{E}_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right] $$

In knowledge distillation, $P$ is the output of the teacher model and has no trainable parameters, while $Q_\theta$ is the output of the student model and contains the parameters being optimized. Because the KL divergence is not symmetric, the two orderings give different objectives. The forward KL objective is $$ \arg\min_\theta D_{\mathrm{KL}}(P \,\|\, Q_\theta) $$

The reverse KL objective is $$ \arg\min_\theta D_{\mathrm{KL}}(Q_\theta \,\|\, P) $$ Forward KL is mass-covering (the student spreads probability over everything the teacher supports), while reverse KL is mode-seeking (the student concentrates on modes it can actually fit), which is why reverse KL is often preferred when the student has much less capacity than the teacher.
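As a concrete illustration, here is a minimal sketch of both divergences computed from logits in PyTorch (the logit tensors are random placeholders standing in for teacher and student outputs):

import torch
import torch.nn.functional as F

# Placeholder logits: a batch of 2 positions over a vocabulary of 5 tokens.
teacher_logits = torch.randn(2, 5)
student_logits = torch.randn(2, 5, requires_grad=True)

p = F.softmax(teacher_logits, dim=-1)          # teacher distribution P (fixed)
log_q = F.log_softmax(student_logits, dim=-1)  # student log-distribution log Q_theta

# Forward KL: D_KL(P || Q) = E_{x~P}[log P(x) - log Q(x)]
forward_kl = (p * (p.log() - log_q)).sum(-1).mean()

# Reverse KL: D_KL(Q || P) = E_{x~Q}[log Q(x) - log P(x)]
q = log_q.exp()
reverse_kl = (q * (log_q - p.log())).sum(-1).mean()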

Offline Distillation

The core idea of knowledge distillation is to use the output of a teacher model (typically soft labels, i.e., probability distributions) to guide the training of a student model. Unlike traditional supervised learning, knowledge distillation uses not only the real (hard) labels from the student's training data but also the soft labels generated by the teacher, which carry more information per example. A simple distillation training loss therefore takes the form $$ L = \alpha L_{\text{soft}} + (1 - \alpha) L_{\text{hard}} $$ where the two losses are $$ L_{\text{hard}} = -\sum_i y_i \log q_i, \qquad L_{\text{soft}} = -T^2 \sum_i p_i^{(T)} \log q_i^{(T)} $$ Here $p_i^{(T)}$ is the teacher's probability for class $i$ softened with temperature $T$, $q_i^{(T)}$ is the corresponding temperature-softened student probability, and $q_i$ is the student probability at $T = 1$. Only the cross-entropy term of the KL divergence is kept, since the teacher-entropy term is constant with respect to the student. The soft-label loss is multiplied by $T^2$ to keep its gradient magnitude comparable to that of the hard-label loss.
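A minimal sketch of this combined loss in PyTorch (the function name and the default values of $T$ and $\alpha$ are illustrative choices, not taken from any specific library):

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Hard loss: standard cross entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    # Soft loss: cross entropy between temperature-softened teacher and student
    # distributions, scaled by T^2 to balance gradient magnitudes.
    p_T = F.softmax(teacher_logits / T, dim=-1)
    log_q_T = F.log_softmax(student_logits / T, dim=-1)
    soft = -(p_T * log_q_T).sum(dim=-1).mean() * (T ** 2)
    return alpha * soft + (1 - alpha) * hard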

Online Distillation

In online distillation, both the teacher model and the student model are updated simultaneously, and the whole knowledge distillation framework is end-to-end trainable.
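One common instantiation is mutual (co-)distillation: two models train on the same batches, and each adds a KL term toward the other's detached predictions. A minimal sketch, assuming classification models model_a and model_b with their own optimizers (all names are placeholders):

import torch
import torch.nn.functional as F

def online_distillation_step(model_a, model_b, opt_a, opt_b, x, y, alpha=0.5):
    # Both models see the same batch; each fits the labels and also matches
    # the other's (detached) output distribution.
    logits_a = model_a(x)
    logits_b = model_b(x)

    def combined_loss(own_logits, peer_logits):
        ce = F.cross_entropy(own_logits, y)
        kl = F.kl_div(F.log_softmax(own_logits, dim=-1),
                      F.softmax(peer_logits.detach(), dim=-1),
                      reduction='batchmean')
        return (1 - alpha) * ce + alpha * kl

    loss_a = combined_loss(logits_a, logits_b)
    loss_b = combined_loss(logits_b, logits_a)

    opt_a.zero_grad(); loss_a.backward(); opt_a.step()
    opt_b.zero_grad(); loss_b.backward(); opt_b.step()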

On-Policy Distillation

When we fine-tune a student model (e.g., a smaller LLM) on synthetic data generated by a teacher model (e.g., GPT-4), we are typically doing supervised learning: training is done on the teacher's output distribution. At inference time, however, the student only ever sees its own outputs, a different distribution, so there is a train/inference mismatch. On-policy distillation closes this gap with the following loop (a sketch follows the list):

  • Generate outputs from the student using the same prompts used in the initial SFT step.

  • Take these outputs (even if they’re low-quality) and feed them back into both the teacher and student models. The teacher will compute its own distribution over tokens (logits). The student does the same.

  • Train the student to match the teacher’s token-by-token probability distribution using KL divergence loss.
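A minimal sketch of one such step, assuming Hugging Face-style causal LMs (the generate call, the .logits attribute, and the function name are assumptions of that interface):

import torch
import torch.nn.functional as F

def on_policy_distillation_step(student, teacher, prompt_ids, optimizer, max_new_tokens=64):
    # 1. The student samples its own continuation of the prompt (on-policy data).
    #    `generate` stands in for any autoregressive sampling routine.
    with torch.no_grad():
        sequence = student.generate(prompt_ids, max_new_tokens=max_new_tokens)

    # 2. Score the sampled sequence with both models to get per-token logits.
    #    (Masking out the prompt tokens is omitted for brevity.)
    student_logits = student(sequence).logits
    with torch.no_grad():
        teacher_logits = teacher(sequence).logits

    # 3. Match the teacher's token-level distribution on the student's own
    #    samples, here with reverse KL (the expectation is under the student).
    log_q = F.log_softmax(student_logits, dim=-1)
    log_p = F.log_softmax(teacher_logits, dim=-1)
    reverse_kl = (log_q.exp() * (log_q - log_p)).sum(-1).mean()

    optimizer.zero_grad()
    reverse_kl.backward()
    optimizer.step()
    return reverse_kl.item()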

A nice property of on-policy distillation is that it is compatible with existing RLHF frameworks. We can take the following steps to distill knowledge from the reference (teacher) model into the policy (student) model:

  1. Take an existing RLHF framework.
  2. Turn off the reward maximization term.
  3. Replace the reference SFT policy with a teacher policy (note that the model partitioning may need to change, since the teacher is typically larger). At this point you have implemented on-policy GKD with reverse KL.
  4. Swap the reverse KL for another token-level f-divergence (e.g., JSD, Jeffreys) for best results; a sketch of the generalized JSD follows this list.
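A sketch of one such alternative, the generalized Jensen-Shannon divergence (the interpolation weight $\beta$ is a hyperparameter; $\beta = 0.5$ recovers the standard symmetric JSD):

import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    # Token-level generalized JSD:
    #   JSD_beta(P, Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M),
    # where M = beta * P + (1 - beta) * Q is the mixture distribution.
    p = F.softmax(teacher_logits, dim=-1)
    q = F.softmax(student_logits, dim=-1)
    m = beta * p + (1 - beta) * q
    kl_pm = (p * (p.log() - m.log())).sum(-1)
    kl_qm = (q * (q.log() - m.log())).sum(-1)
    return (beta * kl_pm + (1 - beta) * kl_qm).mean()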

Distillation with Speculative Decoding

Speculative Knowledge Distillation (SKD) introduces an interleaved sampling mechanism between the student and teacher models (a sketch follows the list):

  1. Token Proposal: The student generates a sequence of tokens.
  2. Teacher Evaluation: The teacher scores each proposed token. If a token ranks poorly under the teacher’s distribution, it is replaced with a token sampled from the teacher.
  3. Adaptive Training: This process creates high-quality training data that aligns with the student’s inference-time behavior, facilitating better learning.
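A rough sketch of the interleaved sampling step, assuming Hugging Face-style model outputs with a .logits field and a simple top-k acceptance rule (the exact acceptance criterion in the SKD paper may differ):

import torch
import torch.nn.functional as F

@torch.no_grad()
def interleaved_sample(student, teacher, prompt_ids, max_new_tokens=64, top_k=25):
    # prompt_ids: (1, prompt_len) tensor of token ids.
    seq = prompt_ids
    for _ in range(max_new_tokens):
        # Student proposes the next token.
        s_logits = student(seq).logits[:, -1, :]
        proposal = torch.multinomial(F.softmax(s_logits, dim=-1), 1)

        # Teacher evaluates the proposal; if it falls outside the teacher's
        # top-k tokens, replace it with a token sampled from the teacher.
        t_logits = teacher(seq).logits[:, -1, :]
        topk_ids = t_logits.topk(top_k, dim=-1).indices
        if not (topk_ids == proposal).any():
            proposal = torch.multinomial(F.softmax(t_logits, dim=-1), 1)

        seq = torch.cat([seq, proposal], dim=-1)
    return seq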

Distillation Implementation Using TorchRPC
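The toy example below splits the teacher and student into separate processes with torch.distributed.rpc: the student sends a tensor to the teacher process, which runs its (stand-in) model and returns the output used in the loss.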

import multiprocessing as mp

import torch
import torch.distributed
from torch.distributed import rpc


class Model:
    """Toy stand-in for a real network: just adds 1 to its input."""
    def __call__(self, tensor):
        return tensor + 1


# Global handle to the teacher model, populated in the teacher process.
model = None


def call_model(tensor):
    # Executed on the teacher process via RPC.
    return model(tensor)


def teacher(rank, world_size):
    # Each role initializes its own (single-member) process group on a distinct
    # port; in a real setup this is where intra-role parallelism would live.
    # 'gloo' is used so the example runs on CPU.
    torch.distributed.init_process_group(
        rank=0, world_size=1, backend='gloo',
        init_method=f'tcp://127.0.0.1:{29500 + rank}')
    options = rpc.TensorPipeRpcBackendOptions(init_method='tcp://127.0.0.1:30000')
    global model
    model = Model()
    rpc.init_rpc('teacher', rank=rank, world_size=world_size, rpc_backend_options=options)
    # shutdown() blocks until all RPC workers are done, so the teacher stays
    # alive to serve the student's requests.
    rpc.shutdown()


def student(rank, world_size):
    torch.distributed.init_process_group(
        rank=0, world_size=1, backend='gloo',
        init_method=f'tcp://127.0.0.1:{29500 + rank}')
    options = rpc.TensorPipeRpcBackendOptions(init_method='tcp://127.0.0.1:30000')
    rpc.init_rpc('student', rank=rank, world_size=world_size, rpc_backend_options=options)

    input_ids = torch.randn(4)
    # Ask the teacher process to run its model on our input, asynchronously.
    teacher_probs = rpc.rpc_async('teacher', call_model, args=(input_ids,))
    student_probs = input_ids
    # Toy "distillation loss": difference between teacher and student outputs.
    loss = teacher_probs.wait() - student_probs
    print(loss)
    rpc.shutdown()


def main(rank, world_size):
    # The first half of the ranks act as students, the rest as teachers.
    teacher_offset = world_size // 2
    if rank < teacher_offset:
        student(rank, world_size)
    else:
        teacher(rank, world_size)


if __name__ == '__main__':
    world_size = 2
    ps = [mp.Process(target=main, args=(rank, world_size)) for rank in range(world_size)]

    for p in ps:
        p.start()
    for p in ps:
        p.join()

# When the teacher and student live on different GPUs, the RPC backend needs an
# explicit device map from caller devices to callee devices, e.g.:
# rpc.TensorPipeRpcBackendOptions(init_method='tcp://127.0.0.1:30000', device_maps={'teacher': {0: 1}})

References

  1. Knowledge distillation: A good teacher is patient and consistent
  2. On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
  3. Born Again Neural Networks
  4. https://zhuanlan.zhihu.com/p/10091011992
  5. https://github.com/predibase/llm_distillation_playbook
  6. A Little Help Goes a Long Way: Efficient LLM Training by Leveraging Small LMs
  7. Distilling the Knowledge in a Neural Network
  8. Does Knowledge Distillation Really Work?
  9. DistillSpec: Improving Speculative Decoding via Knowledge Distillation
  10. Large scale distributed neural network training through online distillation
  11. Distillation Scaling Laws
  12. Gemma 2: Improving Open Language Models at a Practical Size
  13. Diversity-Rewarded CFG Distillation
  14. Speculative Knowledge Distillation: Bridging the Teacher-Student Gap Through Interleaved Sampling
  15. https://dibyaghosh.com/blog/probability/kldivergence.html
  16. MiniLLM: Knowledge Distillation of Large Language Models