whittle-org/whittle-paper

Whittle

This repo supports:

  1. Running a search procedure (evolutionary or Optuna-based) to extract subnet configurations per parameter-count bucket from a base model (Pythia/Qwen).
  2. Pretraining those subnets and/or distilling from a teacher model.

Install

pip install -e .

The package depends on the external whittle, litgpt, and lightning libraries — install whichever versions your environment requires alongside.

All commands below expect PYTHONPATH=.:search so that the local search/ package and the shared search_spaces/ and modules/ folders resolve.

Repository layout

  • pretrain.py, distill.py — top-level entry points.
  • search/evo/evo_search_coarse.py — evolutionary search over the coarse space.
  • search/lib_optuna/optuna_search_coarse.py — Optuna-based search over the coarse space. Uses an ask/tell loop: every counted trial is guaranteed unique and in-bin. Duplicates and out-of-bin suggestions are told back as PRUNED (the sampler still learns from them but they don't consume --n-trials). A SQLite study (optuna.db) is written to the output directory.
  • search/search_spaces/pythia.py, search/search_spaces/qwen.py — sampler classes shared by both search backends. SearchSpaceQwenCoarse extends SearchSpacePythiaCoarse to add a query-groups dimension and enforce GQA legality.
  • search/modules/ — dataloader and evaluation helpers.
  • configs/pretrain/, configs/distillation/ — YAML configs per experiment.
  • launch/ — SLURM submission wrappers (edit PARTITION_PLACEHOLDER and the conda env name before use).

Config file name format

The YAML config files should be named using the following format: exp_{inherit/randominit}_{basemodelname}_{evolutionary/optuna}_search_coarse_bin_{0,1,2}.yaml. For example:

exp_randominit_Pythia-6.9b_evolutionary_search_coarse_bin_2.yaml
exp_inherit_Pythia-12b_optuna_search_coarse_bin_0.yaml
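
As a quick way to check a file name against the format above, here is a minimal sketch of a parser. The helper name parse_config_name and the returned field names are illustrative assumptions, not part of the repository; only the naming pattern itself comes from this README.

```python
import re
from typing import Optional

# Pattern mirrors the naming format stated above. The constant name, helper
# name, and dict keys are hypothetical and exist only for this illustration.
CONFIG_NAME_RE = re.compile(
    r"^exp_(?P<weight_init>inherit|randominit)"
    r"_(?P<base_model>.+)"
    r"_(?P<search>evolutionary|optuna)_search_coarse"
    r"_bin_(?P<bin>[012])\.yaml$"
)


def parse_config_name(filename: str) -> Optional[dict]:
    """Return the fields encoded in a config file name, or None if it does not match."""
    match = CONFIG_NAME_RE.match(filename)
    if match is None:
        return None
    fields = match.groupdict()
    fields["bin"] = int(fields["bin"])  # bin index as an integer
    return fields


print(parse_config_name("exp_inherit_Pythia-12b_optuna_search_coarse_bin_0.yaml"))
```

Names that do not follow the format, such as exp_from_scratch_pythia_410m.yaml, yield None under this sketch.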

Running search

Evolutionary (coarse)

PYTHONPATH=.:search python -m torch.distributed.launch --nproc_per_node=1 --use_env \
  search/evo/evo_search_coarse.py \
    --model_id EleutherAI/pythia-6.9b --seq_len 512 --batch-size 2 \
    --bin_to_search 0 --max-epochs 100 \
    --lower-bounds 350 900 2000 --upper-bounds 410 1500 3500

Optuna (coarse)

PYTHONPATH=.:search python -m torch.distributed.launch --nproc_per_node=1 --use_env \
  search/lib_optuna/optuna_search_coarse.py \
    --model_id EleutherAI/pythia-6.9b --seq_len 512 --batch-size 2 \
    --bin_to_search 0 --n-trials 500 --sampler tpe \
    --lower-bounds 350 900 2000 --upper-bounds 410 1500 3500

Useful flags shared by all search scripts:

  • --lower-bounds / --upper-bounds — per-bin parameter-count limits in millions, indexed by --bin_to_search. Both take nargs="+" (space-separated floats) and must have equal length. Defaults match the Pythia presets [350 900 2000] / [410 1500 3500]. For Qwen3 use [1638 3685 5000] / [2457 4504 7000]. The launcher scripts under launch/search/ auto-select these from MODEL_ID.
  • --bin_to_search {0,1,2} — which parameter-count bucket to search.
  • --no-bf16 — keep the supernet in float32 (bf16 is on by default).
  • --wandb — enable Weights & Biases logging (every evaluated trial is logged as trial/*, and the running best as incumbent/*). Configure with --wandb-project, --wandb-entity, --wandb-run-name, and --wandb-mode {offline,online,disabled} (default offline — logs to <output_dir>/wandb/; sync later via wandb sync <output_dir>/wandb/offline-run-*). Requires wandb installed. In DDP runs only rank 0 logs.
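
To make the relationship between --lower-bounds, --upper-bounds, and --bin_to_search concrete, the following is a minimal sketch of how such flags could be parsed and indexed. It is not the repository's parser: the helper names build_parser and bounds_for_bin and the exact error handling are assumptions; only the flag names, nargs="+" usage, and Pythia defaults come from the list above.

```python
import argparse
from typing import Tuple


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing only the three bound-related flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--bin_to_search", type=int, choices=[0, 1, 2], default=0)
    # Limits are in millions of parameters; defaults are the Pythia presets.
    parser.add_argument("--lower-bounds", type=float, nargs="+", default=[350, 900, 2000])
    parser.add_argument("--upper-bounds", type=float, nargs="+", default=[410, 1500, 3500])
    return parser


def bounds_for_bin(args: argparse.Namespace) -> Tuple[float, float]:
    """Return (lower, upper) in millions for the selected bin."""
    if len(args.lower_bounds) != len(args.upper_bounds):
        raise ValueError("--lower-bounds and --upper-bounds must have equal length")
    if args.bin_to_search >= len(args.lower_bounds):
        raise ValueError("--bin_to_search is out of range for the supplied bounds")
    return args.lower_bounds[args.bin_to_search], args.upper_bounds[args.bin_to_search]


# Example: the Qwen3 presets from the list above, searching the largest bin.
qwen_args = build_parser().parse_args(
    "--bin_to_search 2 --lower-bounds 1638 3685 5000 --upper-bounds 2457 4504 7000".split()
)
print(bounds_for_bin(qwen_args))
```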

Optuna-specific flags:

  • --n-trials — trials to run this invocation (default 500).
  • --timeout — wall-clock budget in seconds (default 6 hours; pass 0 to disable).
  • --sampler {tpe,random} — Optuna sampler (default tpe).
  • --study-name, --storage — override the default SQLite-backed study.
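
The ask/tell accounting used by optuna_search_coarse.py (only unique, in-bin trials count toward --n-trials; duplicates and out-of-bin suggestions are reported as PRUNED) can be illustrated without Optuna itself. The sketch below is a standard-library stand-in under stated assumptions: random sampling replaces the Optuna sampler, the toy search space and parameter-count formula are invented, and all function names are hypothetical. The comments mark where study.ask() and study.tell() would sit in a real ask/tell loop.

```python
import random


def run_counted_trials(suggest, param_count, evaluate, lower, upper, n_trials,
                       max_attempts=10_000):
    """Collect n_trials unique, in-bin configs; everything else is only tallied as pruned."""
    seen = set()
    counted = []  # (config, score) pairs that consume the trial budget
    pruned = 0    # duplicates and out-of-bin suggestions
    attempts = 0
    while len(counted) < n_trials and attempts < max_attempts:
        attempts += 1
        config = suggest()  # stands in for study.ask() plus trial.suggest_* calls
        key = tuple(sorted(config.items()))
        in_bin = lower <= param_count(config) <= upper
        if key in seen or not in_bin:
            # With Optuna: study.tell(trial, state=optuna.trial.TrialState.PRUNED)
            pruned += 1
            continue
        seen.add(key)
        # With Optuna: study.tell(trial, score)
        counted.append((config, evaluate(config)))
    return counted, pruned


# Toy stand-ins (assumptions, not the repository's search space or metric).
rng = random.Random(0)


def toy_suggest():
    return {"layers": rng.randint(1, 8), "width": rng.choice([64, 128, 256])}


def toy_param_count(config):
    return config["layers"] * config["width"]


def toy_score(config):
    return 1.0 / toy_param_count(config)


counted, pruned = run_counted_trials(
    toy_suggest, toy_param_count, toy_score, lower=256, upper=1024, n_trials=5
)
print(len(counted), "counted trials;", pruned, "pruned suggestions")
```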

Both scripts infer the search space (pythia or qwen) from --model_id and error out if neither substring is present.

Both scripts write the running incumbent (bin, trial number, ppl, parameter count, and sub-network dims) at every checkpoint interval, and a final snapshot once the search finishes. Two sidecar files are produced per snapshot:

  • <output_dir>/<stem>_arch_config.yaml — the search-space record.
  • <output_dir>/<stem>_litgpt_config.yaml — a full litgpt.Config dict loadable via Config.from_file to instantiate the same architecture.

<stem> is incumbent for in-progress checkpoints and final_best for the post-search snapshot.

Search output directory

If --output_dir is not supplied, both scripts derive a unique path per model, per bin, per hyperparameter configuration, per run:

  • Evolutionary: out/evolutionary_search_coarse_<pythia|qwen>/<model_id>/bin_<N>/epochs_<E>_pop_<P>_sel_<S>_mut_<M>_cross_<C>_mprob_<mp>_sprob_<sp>_<timestamp>/
  • Optuna: out/optuna_search_coarse_<pythia|qwen>/<model_id>/bin_<N>/n_trials_<T>_timeout_<sec>_sampler_<tpe|random>_<timestamp>/

Pass --output_dir explicitly to override. For Optuna, a fresh timestamped dir also means a fresh SQLite study — to resume a prior study, pass either --output_dir <old_dir> or --storage sqlite:///<old_dir>/optuna.db (optionally together with --study-name).
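
For orientation, here is a minimal sketch of how the Optuna output path and the resume storage URL described above fit together. The function names are hypothetical, the timestamp format is an assumption (the README does not specify it), and the model id is kept verbatim, so an id containing a slash produces a nested directory, which may differ from what the scripts actually do.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def optuna_output_dir(space: str, model_id: str, bin_idx: int, n_trials: int,
                      timeout_s: int, sampler: str,
                      timestamp: Optional[str] = None) -> Path:
    """Build the default Optuna output path following the layout described above."""
    if timestamp is None:
        # Assumed timestamp format; the real scripts may use a different one.
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    run = f"n_trials_{n_trials}_timeout_{timeout_s}_sampler_{sampler}_{timestamp}"
    return Path("out") / f"optuna_search_coarse_{space}" / model_id / f"bin_{bin_idx}" / run


def resume_storage_url(old_dir: str) -> str:
    """SQLite URL to pass as --storage when resuming the study stored in old_dir."""
    return f"sqlite:///{Path(old_dir).as_posix()}/optuna.db"


run_dir = optuna_output_dir("pythia", "EleutherAI/pythia-6.9b", 0, 500, 21600, "tpe",
                            timestamp="20240101-000000")
print(run_dir.as_posix())
print(resume_storage_url(run_dir.as_posix()))
```

Because the timestamp is part of the directory name, two invocations with identical flags still land in different directories, which is why resuming requires pointing at the old directory or its optuna.db explicitly.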

Pretrain and distill

# From scratch
PYTHONPATH=.:search python pretrain.py \
  --config configs/pretrain/exp_from_scratch_pythia_410m.yaml \
  --data LitData --data.data_path ./nemotron_data \
  --data.split_names '["train","val"]' --resume auto

# Pretrain a searched subnet (init_from points at an extracted .pth)
PYTHONPATH=.:search python pretrain.py \
  --config configs/pretrain/exp_inherit_Pythia-6.9b_evolutionary_search_coarse_bin_0.yaml \
  --data LitData --data.data_path data/Nemotron/GPT-NeoX/ --resume auto

# Distill from a teacher checkpoint
PYTHONPATH=.:search python distill.py \
  --config configs/distillation/exp_from_supernet_6.9b_subnet_config_evolutionary_search_coarse_2_100_epochs.yaml \
  --data LitData --data.data_path data/Nemotron/GPT-NeoX/ \
  --teacher_checkpoint_dir checkpoints/EleutherAI/pythia-6.9b/ --resume auto

Submitting subnet pretraining to SLURM

launch/pretrain/pretrain_subnet.py wraps the call to sbatch — it validates arguments (including checking that --data-path exists and that the config's init_from is consistent with --weight-init), resolves the config path from the supplied flags, prints the resolved config (including init_from) and output directory in a table, prompts for confirmation, and then exports the env vars (CONFIG_FILE, OUT_DIR, DATA_PATH, MAX_TOKENS) that launch/pretrain/job.sh expects.

python launch/pretrain/pretrain_subnet.py \
  --weight-init inherit \
  --base-model EleutherAI/Pythia-6.9b \
  --search evolutionary \
  --space coarse \
  --bucket 0 \
  --max-tokens 10000000000 \
  --partition mldlc2_gpu-h200

Arguments:

  • --weight-init {randominit,inherit} — randominit trains from scratch (the config's init_from is expected to be scratch; a warning is printed otherwise). inherit loads the extracted subnet weights from the supernet; the config's init_from must point to an existing .pth file.
  • --base-model — of the form <org>/<model> where <model> starts with Qwen3 or Pythia (e.g. EleutherAI/Pythia-6.9b, Qwen/Qwen3-8B). The <model> part is used in the config filename and output path.
  • --search {optuna,evolutionary} — which search produced the subnet config.
  • --space {coarse} — the search space.
  • --bucket {0,1,2} — parameter-count bin (0=small, 1=mid, 2=large).
  • --max-tokens — training token budget (int); also included in the output dir as <n>B.
  • --partition — SLURM partition.
  • --experiment-name (optional, default pretrain) — prefix for the output directory.
  • --data-path (optional, default ./nemotron_data) — training data dir, exported as DATA_PATH. Must exist.
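
The consistency rule between --weight-init and the config's init_from can be sketched as follows. This is an illustration of the rule as described in the argument list, not the launcher's code: the function name check_init_from, the use of warnings.warn, and the choice of exception types are assumptions.

```python
import warnings
from pathlib import Path


def check_init_from(weight_init: str, init_from: str) -> None:
    """Check that a config's init_from value agrees with the chosen --weight-init."""
    if weight_init == "randominit":
        # Training from scratch: anything other than "scratch" only triggers a warning.
        if init_from != "scratch":
            warnings.warn(
                f"--weight-init randominit expects init_from 'scratch', got {init_from!r}"
            )
    elif weight_init == "inherit":
        # Inherited weights: init_from must be an existing .pth file.
        path = Path(init_from)
        if path.suffix != ".pth" or not path.is_file():
            raise FileNotFoundError(
                f"--weight-init inherit requires an existing .pth file, got {init_from!r}"
            )
    else:
        raise ValueError(f"unknown weight init: {weight_init!r}")


check_init_from("randominit", "scratch")  # consistent: no warning, no error
```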

The script resolves the config file as configs/pretrain/exp_{weight_init}_{model}_{search}_search_{space}_bin_{bucket}_tokens_{max_tokens}B.yaml and the output directory as out/{experiment_name}_{max_tokens}B/{model}/weight_init_{weight_init}/subnet_{search}_search_{space}_bin_{bucket} (where {model} is the part of --base-model after the /).
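
The resolution rule above can be written out as a short sketch. The function name resolve_paths is hypothetical, and rendering --max-tokens as whole billions (10000000000 becomes 10B) is an assumption about how the <n>B suffix is computed.

```python
from pathlib import Path
from typing import Tuple


def resolve_paths(weight_init: str, base_model: str, search: str, space: str,
                  bucket: int, max_tokens: int,
                  experiment_name: str = "pretrain") -> Tuple[Path, Path]:
    """Return (config_path, output_dir) following the templates described above."""
    model = base_model.split("/", 1)[1]   # part of --base-model after the "/"
    tokens = f"{max_tokens // 10**9}B"    # assumed rendering: whole billions
    subnet = f"{search}_search_{space}_bin_{bucket}"
    config = Path("configs/pretrain") / f"exp_{weight_init}_{model}_{subnet}_tokens_{tokens}.yaml"
    out_dir = (Path("out") / f"{experiment_name}_{tokens}" / model
               / f"weight_init_{weight_init}" / f"subnet_{subnet}")
    return config, out_dir


config_path, out_dir = resolve_paths(
    "inherit", "EleutherAI/Pythia-6.9b", "evolutionary", "coarse", 0, 10_000_000_000
)
print(config_path.as_posix())
print(out_dir.as_posix())
```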

The equivalent bash entry point is launch/pretrain/pretrain_subnet.sh (positional arguments in the same order as the Python flags).

SLURM

Searching for subnets

# Launch evolutionary search
MODEL_ID=EleutherAI/pythia-6.9b sbatch launch/search/launch_evo_search_coarse.sh -p <partition_name>
# Launch TPE search with Optuna
MODEL_ID=Qwen/Qwen3-32B sbatch launch/search/launch_optuna_search_coarse.sh -p <partition_name>
# Launch random search with Optuna
MODEL_ID=EleutherAI/pythia-6.9b sbatch launch/search/launch_random_search_coarse.sh -p <partition_name>

The launch scripts require MODEL_ID to contain either /pythia or /Qwen3 (defaults to EleutherAI/pythia-6.9b) and use the substring to auto-select the per-bin parameter-count bounds: Pythia uses [350 900 2000] / [410 1500 3500], Qwen3 uses [1638 3685 5000] / [2457 4504 7000]. Anything else exits with an error.
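
The launchers themselves are shell scripts; the selection logic they apply can be mirrored in Python as a rough sketch. The names BOUNDS_PRESETS and select_bounds are hypothetical, and the case-sensitive substring match is an assumption based on the /pythia and /Qwen3 markers quoted above.

```python
from typing import List, Tuple

# Per-bin (lower, upper) parameter-count bounds in millions, keyed by the
# substring that must appear in MODEL_ID. Values are taken from the text above.
BOUNDS_PRESETS = {
    "/pythia": ([350, 900, 2000], [410, 1500, 3500]),
    "/Qwen3": ([1638, 3685, 5000], [2457, 4504, 7000]),
}


def select_bounds(model_id: str = "EleutherAI/pythia-6.9b") -> Tuple[List[int], List[int]]:
    """Return (lower_bounds, upper_bounds) for a model id, or raise if unsupported."""
    for marker, bounds in BOUNDS_PRESETS.items():
        if marker in model_id:  # assumed case-sensitive match
            return bounds
    raise ValueError(f"MODEL_ID must contain /pythia or /Qwen3, got {model_id!r}")


lower, upper = select_bounds("Qwen/Qwen3-32B")
print(lower, upper)
```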

Pretraining discovered subnets

The config file whose name is derived from the command-line arguments must already exist in ./configs/pretrain/. See the "Config file name format" section above.

python launch/pretrain/pretrain_subnet.py --weight-init inherit --base-model EleutherAI/Pythia-6.9b \
  --search evolutionary --space coarse --bucket 0 \
  --max-tokens 10000000000 --partition mldlc2_gpu-h200

Note that the search launch scripts under launch/search/ run as SLURM job arrays over bins 0-2.
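
One plausible way a job array over bins 0-2 maps onto --bin_to_search is through the SLURM_ARRAY_TASK_ID environment variable that SLURM sets for each array task. The sketch below assumes the task id is used directly as the bin index; the README does not confirm this mapping, and the function name is hypothetical.

```python
import os
from typing import Mapping, Optional


def bin_from_slurm_env(env: Optional[Mapping[str, str]] = None) -> int:
    """Read the bin index from SLURM_ARRAY_TASK_ID, defaulting to 0 outside SLURM."""
    env = os.environ if env is None else env
    task_id = int(env.get("SLURM_ARRAY_TASK_ID", "0"))
    if task_id not in (0, 1, 2):
        raise ValueError(f"expected an array task id in 0-2, got {task_id}")
    return task_id  # assumed to be passed on as --bin_to_search


print(bin_from_slurm_env({"SLURM_ARRAY_TASK_ID": "2"}))
```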

About

The repository for all the experiments in the Whittle paper
