Official code for *Steering Language Models Before They Speak: Logit-Level Interventions*.

SWAI (Statistical Writing-style Aligned Inference) is a training-free, lookup-based logit-steering method for autoregressive language models. Default backbone: Llama-3.1-8B.

SWAI steers an LM toward (or away from) a target attribute, such as difficulty level, politeness, or toxicity, by injecting a precomputed token-level score bias into the logits at every decoding step via a HuggingFace `LogitsProcessor`.

The pipeline is three steps:
- Build score table — collect per-token attribute scores from labeled corpora.
- Steered generation — bias the LM's next-token distribution with the score table.
- Evaluation — judge attribute control with GPT-4o (individual or triplet mode).

No gradient updates. No fine-tuning. No extra forward passes. Just a vocabulary-sized bias vector added to logits.

- 🚫 Training-free. Works on any open LM without parameter updates.
- ⚡ Zero inference overhead. A precomputed bias vector — no extra forward/backward passes.
- 🔌 Plug-and-play. Drops into HuggingFace `generate()` as a standard `LogitsProcessor`.
- 🔄 Bidirectional control. Reuses the same score table to amplify or suppress an attribute.
- 🎯 Interpretable. Per-token statistics — easy to inspect which tokens drive the steering.
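
The core idea of adding a vocabulary-sized bias to the logits, and flipping its sign for bidirectional control, can be sketched in a few lines. This is a toy illustration in plain Python rather than the repository's implementation: the names `steer` and `strength`, the four-token vocabulary, and all numbers are assumptions made up for the example. In the actual code the same addition would happen on tensors inside a `LogitsProcessor`.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def steer(logits, bias, strength):
    """Add a scaled per-token bias to the raw logits (hypothetical helper).

    strength > 0 pushes toward the attribute, strength < 0 pushes away from it.
    """
    if len(logits) != len(bias):
        raise ValueError("bias must have one entry per vocabulary token")
    return [l + strength * b for l, b in zip(logits, bias)]

# Toy vocabulary and made-up numbers, for illustration only.
vocab = ["use", "utilize", "help", "facilitate"]
logits = [2.0, 1.5, 1.8, 1.0]
bias = [1.0, -1.0, 0.8, -1.2]  # hypothetical "easy" scores per token

base = softmax(logits)
toward = softmax(steer(logits, bias, +1.0))  # amplify the attribute
away = softmax(steer(logits, bias, -1.0))    # suppress the attribute

for tok, b, t, a in zip(vocab, base, toward, away):
    print(f"{tok:<11} base={b:.3f} toward={t:.3f} away={a:.3f}")
```

Because the bias is a fixed lookup, the only per-step cost in this sketch is one element-wise addition over the vocabulary.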

Install the dependencies and log in to HuggingFace:

    pip install torch transformers openai tqdm numpy
    huggingface-cli login  # required for gated models like Llama-3.1

A GPU with sufficient VRAM is required (Llama-3.1-8B loads in fp16 by default).
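
As a rough guide to the VRAM requirement, the weights alone can be estimated from the parameter count. The figure of about 8.03 billion parameters is an approximation, the helper name `weight_memory_gib` is made up for this example, and the estimate ignores the KV cache and activations, so actual usage will be higher.

```python
def weight_memory_gib(num_params, bytes_per_param=2):
    """Memory needed for model weights only, in GiB (fp16 = 2 bytes per parameter)."""
    return num_params * bytes_per_param / 1024**3

approx_params = 8.03e9  # approximate parameter count for Llama-3.1-8B (assumption)
estimate = weight_memory_gib(approx_params)
print(f"fp16 weights: about {estimate:.1f} GiB, before KV cache and activations")
```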

Step 1, build the score table:

    python build_scores.py \
      --dataset ose \
      --tokenizer meta-llama/Llama-3.1-8B \
      --output_dir lookup/llama3.1-8b/ose

Produces per-class score files (e.g. `scores_E.json`, `scores_I.json`, `scores_A.json` for OSE).
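
To give a feel for what a per-token score table might contain, here is a minimal sketch that scores each token by its smoothed log-odds of appearing in the target class versus the other classes. This statistic is an assumption chosen for illustration; the formula used by `build_scores.py` may differ. The function name `build_score_table` is hypothetical, and whitespace splitting stands in for the real model tokenizer.

```python
import math
from collections import Counter

def build_score_table(texts_by_class, target, smoothing=1.0):
    """Smoothed log-odds of each token under `target` versus all other classes."""
    target_counts, other_counts = Counter(), Counter()
    for label, texts in texts_by_class.items():
        bucket = target_counts if label == target else other_counts
        for text in texts:
            # Whitespace split is a stand-in for the model tokenizer.
            bucket.update(text.lower().split())

    vocab = set(target_counts) | set(other_counts)
    t_total = sum(target_counts.values()) + smoothing * len(vocab)
    o_total = sum(other_counts.values()) + smoothing * len(vocab)

    scores = {}
    for tok in vocab:
        p_target = (target_counts[tok] + smoothing) / t_total
        p_other = (other_counts[tok] + smoothing) / o_total
        scores[tok] = math.log(p_target / p_other)
    return scores

# Tiny made-up corpus with two difficulty labels.
corpus = {
    "E": ["the dog ran fast", "the cat sat down"],
    "A": ["the canine accelerated rapidly", "the feline reclined"],
}
scores_E = build_score_table(corpus, target="E")
top_tokens = sorted(scores_E, key=scores_E.get, reverse=True)[:3]
print(top_tokens)
```

Tokens that appear mostly in the target class get positive scores, tokens that appear mostly elsewhere get negative ones, and shared tokens such as "the" land near zero, which is what makes the table easy to inspect.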

Step 2, run steered generation:

    python logit_steering.py \
      --dataset ose \
      --experiment rewriting \
      --target E \
      --score_dir lookup/llama3.1-8b/ose \
      --data_path dataset/ose.json \
      --output_dir outputs/ose \
      --max_new_tokens 800 \
      --num_samples 100

Key arguments:

| Argument | Description |
|---|---|
| `--dataset` | `ose`, `wikipol`, `real_tox` |
| `--experiment` | `rewriting` or `open_ended` |
| `--target` | OSE: `E`/`I`/`A` · WikiPol: `I`/`P`/`N` · RealTox: `T`/`NT` |
| `--score_dir` | Directory with score files from Step 1 |
| `--num_samples` | Samples to process (`-1` = all) |
| `--seed` | Random seed for reproducibility |

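
The valid `--target` values depend on `--dataset`. The sketch below shows one way such a pairing could be declared and checked with `argparse`. It is not the parser in `logit_steering.py`: the helper names and the default values for `--num_samples` and `--seed` are assumptions for illustration.

```python
import argparse

# Target labels per dataset, taken from the arguments table.
VALID_TARGETS = {
    "ose": {"E", "I", "A"},
    "wikipol": {"I", "P", "N"},
    "real_tox": {"T", "NT"},
}

def build_parser():
    """Hypothetical parser covering the key arguments listed above."""
    p = argparse.ArgumentParser(description="Steered generation (illustrative sketch)")
    p.add_argument("--dataset", choices=sorted(VALID_TARGETS), required=True)
    p.add_argument("--experiment", choices=["rewriting", "open_ended"], required=True)
    p.add_argument("--target", required=True)
    p.add_argument("--score_dir", required=True)
    p.add_argument("--num_samples", type=int, default=-1)  # assumed default: all samples
    p.add_argument("--seed", type=int, default=0)           # assumed default
    return p

def parse_and_validate(argv):
    """Parse argv and reject targets that do not belong to the chosen dataset."""
    args = build_parser().parse_args(argv)
    if args.target not in VALID_TARGETS[args.dataset]:
        allowed = ", ".join(sorted(VALID_TARGETS[args.dataset]))
        raise ValueError(f"--target {args.target!r} is not valid for {args.dataset}; use one of: {allowed}")
    return args

example_args = parse_and_validate([
    "--dataset", "ose",
    "--experiment", "rewriting",
    "--target", "E",
    "--score_dir", "lookup/llama3.1-8b/ose",
])
print(example_args.dataset, example_args.target, example_args.num_samples)
```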

Step 3, evaluate with GPT-4o:

    export OPENAI_API_KEY="..."
    python judge_triplet.py \
      --input outputs/ose/ose_rewriting_E.json \
      --output eval_results/ose_rewriting_E_judged.json \
      --prompt judge_system_prompts/ose_sys_prompt.txt \
      --use-triplet \
      --max-triplets 500

Two judging modes:
- Triplet (`--use-triplet`, default for OSE): judges E/I/A texts of the same subject together to determine relative difficulty.
- Individual (`--no-triplet`, default for WikiPol/RealTox): judges each text independently.
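
Triplet mode implies that outputs for the same subject are gathered across the three difficulty levels before judging. A minimal sketch of that grouping step is shown below. The record fields `subject`, `level`, and `text` and the function name `group_into_triplets` are hypothetical; the actual JSON layout produced by `logit_steering.py` may differ.

```python
from collections import defaultdict

def group_into_triplets(records, levels=("E", "I", "A")):
    """Group records by subject and keep only subjects that have every level."""
    by_subject = defaultdict(dict)
    for rec in records:
        by_subject[rec["subject"]][rec["level"]] = rec["text"]

    triplets = []
    for subject, texts in by_subject.items():
        # Subjects missing any level cannot be judged as a triplet, so skip them.
        if all(level in texts for level in levels):
            triplets.append({"subject": subject, **{lvl: texts[lvl] for lvl in levels}})
    return triplets

# Made-up records: "volcanoes" is complete, "tides" lacks the A-level text.
records = [
    {"subject": "volcanoes", "level": "E", "text": "Volcanoes are hot mountains."},
    {"subject": "volcanoes", "level": "I", "text": "Volcanoes erupt when magma rises."},
    {"subject": "volcanoes", "level": "A", "text": "Volcanic eruptions result from magmatic pressure."},
    {"subject": "tides", "level": "E", "text": "The sea goes up and down."},
    {"subject": "tides", "level": "I", "text": "Tides follow the moon's pull."},
]
triplets = group_into_triplets(records)
print(len(triplets), triplets[0]["subject"])
```

In individual mode no grouping would be needed, since each text is sent to the judge on its own.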

Repository layout:

    swai/
    ├── build_scores.py      # Step 1: corpus → per-token score table
    ├── logit_steering.py    # Step 2: steered AR generation (LogitsProcessor)
    ├── judge_triplet.py     # Step 3: GPT-4o evaluation (individual / triplet)
    └── utils.py             # Score loading, AttributeSteeringProcessor, seeding

| Dataset | Task | Targets |
|---|---|---|
| OSE (OneStopEnglish) | Difficulty-controlled rewriting / continuation | E, I, A |
| WikiPol | Politeness-controlled paraphrasing | I, P, N |
| RealToxicityPrompts | Detoxified paraphrasing | T, NT |

- Default backbone: `meta-llama/Llama-3.1-8B` (modify in `logit_steering.py` to swap).
- Steering hyperparameters (`K`, `rho`, `delta`, `lambda_max`, `score_clip`) are set in code; see `logit_steering.py`.
- The evaluation step incurs OpenAI API usage costs.

See LICENSE.