
thu-nics/mini-opd


MiniOPD: Simple On-Policy Distillation

Minimal framework for on-policy distillation of language models. The student generates rollouts via an SGLang server, is trained on the KL divergence against a frozen teacher, and its updated weights are synced back to the server.

Student generates (SGLang) → Task reward scores → Train on teacher KL → Sync weights → repeat
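The cycle above can be sketched as a plain loop. This is a minimal illustration, not the repo's train_loop: run_opd_loop and its four callables (generate, score, train_step, sync_weights) are hypothetical stand-ins for the SGLang rollout, the task reward, the teacher-KL update, and the weight sync.

```python
from typing import Callable, List, Sequence


def run_opd_loop(
    num_steps: int,
    generate: Callable[[int], Sequence],
    score: Callable[[Sequence], List[float]],
    train_step: Callable[[Sequence, List[float]], float],
    sync_weights: Callable[[int], None],
) -> List[float]:
    """Repeat generate -> score -> train -> sync for num_steps iterations (hypothetical helper)."""
    losses: List[float] = []
    for step in range(num_steps):
        rollouts = generate(step)                     # student samples (SGLang in MiniOPD)
        rewards = score(rollouts)                     # task reward per rollout
        losses.append(train_step(rollouts, rewards))  # update on KL against the frozen teacher
        sync_weights(step)                            # push new student weights to the generator
    return losses
```

Keeping each stage behind a callable means the loop body never changes; only the functions passed in do.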

Branches

| Branch | For | What's in it |
|---|---|---|
| main | Humans | Full repo: core framework + all alternative losses, tasks, and recipes |
| core | Agents & humans | Core only: the minimal loop with one loss (opd_loss), one task (GSM8K), and one script. Start here to build your own. |

The core branch is auto-synced from main on every push via GitHub Actions. The sync strips alt/, the non-core scripts, and shim files; everything that remains still runs standalone.

Setup

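# Install the package in editable mode with all optional extras
uv pip install -e ".[all]"
# Use the repository's bundled git hooks
git config core.hooksPath .githooks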

Quick Start

# Prepare dataset
python scripts/core/prepare_gsm8k.py --output local/data/gsm8k

# Launch SGLang server (separate terminal)
CUDA_VISIBLE_DEVICES=0 python -m sglang.launch_server \
    --model Qwen/Qwen3-0.6B --port 30000

# Train
python scripts/core/train.py scripts/core/recipe.yaml

Design

Core / alt separation. miniopd/core/ is the self-contained framework: training loop, rollout engine, dataset, one loss, and one task. miniopd/alt/ provides alternative choices for each configurable slot. Delete alt/ and everything still works.

Callback interfaces. Every configurable piece is a Protocol-typed callback:

| Callback | Signature | Core default |
|---|---|---|
| StudentForward | (batch) -> (logits, labels) | |
| TeacherForward | (batch) -> logits | |
| LossFn | (student_logits, teacher_logits, labels, **kw) -> (loss, metrics) | opd_loss |
| ForwardFn | (batch) -> (OPDOutput, n_tokens) | opd_forward_step |
| ProcessRewardsFn | (rewards) -> processed | post_process_rewards |
| TaskInterface | .reward() + .eval_fn() | Gsm8kTask |
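A minimal sketch of how such callbacks can be typed with typing.Protocol. The argument names of reward and eval_fn, and the ToyGsm8kTask class, are assumptions for illustration and not the repo's Gsm8kTask; its reward simply compares the last number in the completion with the reference answer after GSM8K's "####" marker.

```python
import re
from typing import Any, Dict, List, Optional, Protocol, Tuple


class LossFn(Protocol):
    # (student_logits, teacher_logits, labels, **kw) -> (loss, metrics)
    def __call__(
        self, student_logits: Any, teacher_logits: Any, labels: Any, **kw: Any
    ) -> Tuple[Any, Dict[str, float]]: ...


class ProcessRewardsFn(Protocol):
    # (rewards) -> processed rewards, e.g. normalized advantages
    def __call__(self, rewards: List[float]) -> List[float]: ...


class TaskInterface(Protocol):
    # Hypothetical argument names; the repo's exact signatures may differ.
    def reward(self, completion: str, answer: str) -> float: ...
    def eval_fn(self, completions: List[str], answers: List[str]) -> Dict[str, float]: ...


_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def _last_number(text: str) -> Optional[str]:
    """Return the last number in text with thousands separators removed."""
    matches = _NUMBER.findall(text)
    return matches[-1].replace(",", "") if matches else None


class ToyGsm8kTask:
    """Hypothetical task that satisfies TaskInterface structurally (no inheritance)."""

    def reward(self, completion: str, answer: str) -> float:
        # GSM8K references end with "#### <answer>"; keep only that part.
        pred = _last_number(completion)
        gold = _last_number(answer.split("####")[-1])
        if pred is None or gold is None:
            return 0.0
        return 1.0 if float(pred) == float(gold) else 0.0

    def eval_fn(self, completions: List[str], answers: List[str]) -> Dict[str, float]:
        scores = [self.reward(c, a) for c, a in zip(completions, answers)]
        return {"accuracy": sum(scores) / len(scores) if scores else 0.0}


task: TaskInterface = ToyGsm8kTask()  # accepted by a type checker via structural typing
```

Because Protocols are structural, a new task or loss only has to match the signature; nothing in the loop needs to import or subclass it.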

Swap any callback to change behavior; the training loop itself has no mode-specific branching.

Macro-batching. RolloutLoader buffers grad_accum micro-batches and fires a single rollout for the full macro-batch. Advantage normalization runs over all samples before the batch is split into micro-batches. Token-level gradient scaling ensures that the choice of micro-batch size does not affect results.
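The sketch below illustrates both properties with plain Python lists standing in for tensors. It assumes plain batch-level mean/std advantage normalization and uses per-token loss values as a proxy for gradients (the accumulation is linear in them); normalize_advantages, split, and the accumulate_* helpers are hypothetical names, not the RolloutLoader API.

```python
from statistics import mean, pstdev
from typing import List


def normalize_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Mean/std-normalize rewards over the whole macro-batch (assumed scheme)."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def split(samples: list, n_micro: int) -> list:
    """Split samples into at most n_micro contiguous micro-batches."""
    size = -(-len(samples) // n_micro)  # ceiling division
    return [samples[i : i + size] for i in range(0, len(samples), size)]


def accumulate_token_scaled(token_losses: List[List[float]], n_micro: int) -> float:
    """Each micro-batch sums its token losses and divides by the macro-batch token count."""
    total_tokens = sum(len(seq) for seq in token_losses)
    return sum(
        sum(sum(seq) for seq in micro) / total_tokens
        for micro in split(token_losses, n_micro)
    )


def accumulate_naive(token_losses: List[List[float]], n_micro: int) -> float:
    """Per-micro-batch token mean, averaged over micro-batches (depends on the split)."""
    micros = split(token_losses, n_micro)
    return sum(mean([x for seq in m for x in seq]) for m in micros) / len(micros)
```

With token losses [[1, 2, 3], [4], [5, 6]], the token-scaled total is 3.5 whether the batch is split into 1, 2, or 3 micro-batches, while the naive average moves to 4.0 with two micro-batches because sequences of different lengths land in different groups.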

Async overlap. Setting buffer_size > 0 runs rollout in a background thread, so rollouts can lag the trainer by one step of weight updates. This 1-step staleness is corrected with an importance-sampling ratio plus a cap on the behavioral weight.
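A minimal sketch of both ideas, assuming the cap is a truncated importance-sampling weight min(pi_current / pi_behavior, cap) per token and that the buffer is a bounded queue fed by one worker thread. truncated_is_weights, prefetch, and the default cap of 2.0 are hypothetical choices, not the repo's implementation.

```python
import math
import queue
import threading
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")


def truncated_is_weights(
    logp_current: List[float], logp_behavior: List[float], cap: float = 2.0
) -> List[float]:
    """Per-token min(pi_current / pi_behavior, cap), computed from log-probabilities."""
    return [min(math.exp(c - b), cap) for c, b in zip(logp_current, logp_behavior)]


def prefetch(make_batch: Callable[[int], T], num_batches: int, buffer_size: int) -> Iterator[T]:
    """Produce batches in a background thread, holding at most buffer_size ready ones."""
    if buffer_size <= 0:
        raise ValueError("buffer_size must be > 0 (queue.Queue(0) would be unbounded)")
    buf: "queue.Queue" = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream

    def worker() -> None:
        for i in range(num_batches):
            buf.put(make_batch(i))  # blocks while the buffer is full
        buf.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while (item := buf.get()) is not done:
        yield item
```

While the trainer consumes one batch, the worker is already generating the next with the weights it last saw, which is the source of the staleness; the capped ratio keeps a few very unlikely tokens from dominating the correction.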

Project Structure

miniopd/
├── core/                       # Self-contained framework
│   ├── training.py             # Protocols, train_loop, opd_forward_step
│   ├── rollout.py              # RolloutEngine, RolloutLoader, make_rollout
│   ├── dataset.py              # PackedSFTDataset, collate_padded
│   ├── loss.py                 # opd_loss
│   ├── tasks.py                # TaskInterface + Gsm8kTask
│   └── utils.py                # Math scoring helpers
├── alt/                        # Alternative choices
│   ├── loss.py                 # topk_kl, teacher_topk_rkl, grpo_kdrl
│   ├── training.py             # grpo_kdrl_forward_step
│   ├── rollout.py              # compute_advantages
│   └── tasks/                  # aime, deepmath, countdown

scripts/
├── core/                       # Standalone: prep + train + recipe
│   ├── prepare_gsm8k.py
│   ├── train.py
│   └── recipe.yaml
├── train/                      # One script per loss mode
│   ├── train_opd.py
│   ├── train_topk_opd.py
│   └── train_grpo_kdrl.py
├── recipe/                     # Matching YAML configs
└── dataset/                    # Data preparation

Data-Parallel Training

DP is a launch-time choice; no recipe changes are needed. Use torchrun to spawn N ranks: each rank loads its own student and teacher, and gradients are all-reduced automatically.

# 2 DP ranks, student + teacher on same GPU each
CUDA_VISIBLE_DEVICES=6,7 torchrun --nproc_per_node=2 \
    scripts/train/train_topk_opd.py scripts/recipe/topk_opd_deepmath.yaml

Only rank 0 runs wandb logging, evaluation, weight sync, and checkpointing. DP helps most when rollout is much faster than training, so that training is the bottleneck.
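Rank-0 gating can be sketched from the environment variables torchrun sets for each worker (RANK, LOCAL_RANK, WORLD_SIZE). DistInfo and dist_info_from_env are hypothetical helpers; the gradient all-reduce itself would normally come from a wrapper such as PyTorch's DistributedDataParallel, which this sketch does not show.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistInfo:
    rank: int        # global rank across all processes
    local_rank: int  # rank on this node; indexes the visible GPUs
    world_size: int  # total number of DP ranks

    @property
    def is_main(self) -> bool:
        # Only rank 0 should log to wandb, evaluate, sync weights, and checkpoint.
        return self.rank == 0

    @property
    def device(self) -> str:
        # With CUDA_VISIBLE_DEVICES=6,7, "cuda:0" is physical GPU 6 and "cuda:1" is GPU 7.
        return f"cuda:{self.local_rank}"


def dist_info_from_env(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read torchrun's per-process variables; the defaults describe a single-process run."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )
```

A training script would then wrap its logging, evaluation, weight-sync, and checkpoint calls in an if info.is_main: check, and the same code runs unchanged without torchrun because the defaults describe rank 0 of 1.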

Loss Modes

| Script | Loss | Description |
|---|---|---|
| core/train.py | opd_loss | REINFORCE with teacher-student log-ratio as advantage |
| train_topk_opd.py | teacher_topk_rkl_loss | Reverse KL on teacher's top-K, renormalized (arXiv:2603.25562) |
| train_grpo_kdrl.py | grpo_kdrl_loss | GRPO clipped surrogate + reverse KL penalty |
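The first two losses can be sketched per token in plain Python. These are illustrative reimplementations under stated assumptions (logits given as one list per position, labels as token ids, mean over positions), not the repo's opd_loss or teacher_topk_rkl_loss; in an autograd framework the advantage in opd_reinforce_loss would be detached so that only log pi_S(y_t) receives gradient.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one position's logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def opd_reinforce_loss(
    student_logits: List[List[float]], teacher_logits: List[List[float]], labels: List[int]
) -> float:
    """-(1/T) * sum_t A_t * log pi_S(y_t), with A_t = log pi_T(y_t) - log pi_S(y_t)."""
    total = 0.0
    for s, t, y in zip(student_logits, teacher_logits, labels):
        logp_s = log_softmax(s)[y]
        logp_t = log_softmax(t)[y]
        advantage = logp_t - logp_s  # would be detached (constant) under autograd
        total += -advantage * logp_s
    return total / len(labels)


def topk_reverse_kl(
    student_logits: List[List[float]], teacher_logits: List[List[float]], k: int
) -> float:
    """Mean over positions of KL(student || teacher), both renormalized on the teacher's top-K."""
    total = 0.0
    for s, t in zip(student_logits, teacher_logits):
        top = sorted(range(len(t)), key=lambda i: t[i], reverse=True)[:k]
        ls = log_softmax([s[i] for i in top])  # student log-probs renormalized over top-K
        lt = log_softmax([t[i] for i in top])  # teacher log-probs renormalized over top-K
        total += sum(math.exp(a) * (a - b) for a, b in zip(ls, lt))
    return total / len(student_logits)
```

Taking log_softmax over only the teacher's top-K logits is what renormalizes both distributions to that K-token support, so student mass placed outside the teacher's top-K does not enter the loss.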

License

Apache 2.0

About

A minimal open-source on-policy-distillation implementation
