hsannn/swai


SWAI

Official code for Steering Language Models Before They Speak: Logit-Level Interventions.

SWAI (Statistical Writing-style Aligned Inference) is a training-free, lookup-based logit-steering method for autoregressive language models. Default backbone: Llama-3.1-8B.

[Figure: SWAI overview]


Overview

SWAI steers an LM toward (or away from) a target attribute — difficulty level, politeness, toxicity, etc. — by injecting a precomputed token-level score bias into the logits at every decoding step, via a HuggingFace LogitsProcessor.

The pipeline has three steps:

  1. Build score table — collect per-token attribute scores from labeled corpora.
  2. Steered generation — bias the LM's next-token distribution with the score table.
  3. Evaluation — judge attribute control with GPT-4o (individual or triplet mode).

No gradient updates. No fine-tuning. No extra forward passes. Just a vocabulary-sized bias vector added to logits.
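The core operation can be illustrated with a minimal, framework-free sketch. It uses plain Python lists in place of the torch tensors a real LogitsProcessor receives, and assumes a single scalar steering strength `lam` (a hypothetical name; the repository's actual hyperparameters are set in logit_steering.py).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def steer_logits(logits, bias, lam=1.0):
    """Add a scaled, vocabulary-sized bias vector to next-token logits.

    `lam` is a hypothetical steering strength used only for illustration.
    """
    if len(logits) != len(bias):
        raise ValueError("bias must have one entry per vocabulary item")
    return [x + lam * b for x, b in zip(logits, bias)]


# Toy 4-token vocabulary: token 2 is "on-attribute", token 3 is "off-attribute".
logits = [2.0, 1.0, 1.0, 1.0]
bias = [0.0, 0.0, 1.5, -1.5]

before = softmax(logits)
after = softmax(steer_logits(logits, bias))
print("before:", [round(p, 3) for p in before])
print("after: ", [round(p, 3) for p in after])
```

Inside a HuggingFace LogitsProcessor, the same step is a single broadcast add on the `scores` tensor of shape (batch_size, vocab_size).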


Why SWAI

  • 🚫 Training-free. Works on any open LM without parameter updates.
  • Negligible inference overhead. The only extra work per decoding step is adding a precomputed bias vector to the logits; there are no extra forward/backward passes.
  • 🔌 Plug-and-play. Drops into HuggingFace generate() as a standard LogitsProcessor.
  • 🔄 Bidirectional control. Reuses the same score table to amplify or suppress an attribute.
  • 🎯 Interpretable. Per-token statistics — easy to inspect which tokens drive the steering.
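One natural reading of bidirectional control is that the sign of the steering strength sets the direction while the score table stays fixed. The sketch below assumes that interpretation; the scalar `lam` and the toy scores are hypothetical and not taken from the repository.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def steered_probs(logits, scores, lam):
    """Apply one score table with a signed strength (assumed scheme).

    lam > 0 amplifies the attribute, lam < 0 suppresses it,
    and lam == 0 leaves the LM's distribution unchanged.
    """
    return softmax([x + lam * s for x, s in zip(logits, scores)])


logits = [1.0, 1.0, 1.0]
scores = [2.0, 0.0, -2.0]  # hypothetical per-token attribute scores

for lam in (-1.0, 0.0, 1.0):
    p = steered_probs(logits, scores, lam)
    print(f"lam={lam:+.1f}  p(on-attribute token)={p[0]:.3f}")
```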

Setup

pip install torch transformers openai tqdm numpy
huggingface-cli login    # required for gated models like Llama-3.1

A GPU with enough VRAM to hold the model is required (Llama-3.1-8B loads in fp16 by default).
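A back-of-the-envelope estimate shows why fp16 matters. This sketch assumes roughly 8 billion parameters (the nominal size in the model name) and counts weights only, ignoring the KV cache and activations.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Memory needed for model weights only, in GiB (excludes KV cache and activations)."""
    return num_params * bytes_per_param / 1024**3


# fp16/bf16 store each parameter in 2 bytes; fp32 uses 4.
for dtype, nbytes in (("fp16", 2), ("fp32", 4)):
    print(dtype, round(weight_memory_gib(8.0e9, nbytes), 1), "GiB")
```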


Pipeline

1. Build the score table

python build_scores.py \
    --dataset ose \
    --tokenizer meta-llama/Llama-3.1-8B \
    --output_dir lookup/llama3.1-8b/ose

This produces one score file per class (e.g., scores_E.json, scores_I.json, and scores_A.json for OSE).
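The README does not specify which per-token statistic build_scores.py computes. As an illustration only, the sketch below uses a smoothed log-ratio of a token's relative frequency inside a class versus in all other classes, and a whitespace tokenizer as a stand-in for the backbone's tokenizer. Both choices are assumptions, not the repository's method.

```python
import json
import math
from collections import Counter


def build_score_tables(corpus, tokenize=str.split, alpha=1.0):
    """corpus: list of (text, label) pairs. Returns {label: {token: score}}.

    Score = smoothed log-ratio of a token's relative frequency inside the class
    versus in all other classes (an assumed statistic, for illustration only).
    """
    counts = {}
    for text, label in corpus:
        counts.setdefault(label, Counter()).update(tokenize(text))
    vocab = set().union(*counts.values())
    tables = {}
    for label, in_counts in counts.items():
        out_counts = Counter()
        for other, c in counts.items():
            if other != label:
                out_counts.update(c)
        # Add-alpha smoothing so unseen tokens get a finite score.
        in_total = sum(in_counts.values()) + alpha * len(vocab)
        out_total = sum(out_counts.values()) + alpha * len(vocab)
        tables[label] = {
            tok: math.log((in_counts[tok] + alpha) / in_total)
            - math.log((out_counts[tok] + alpha) / out_total)
            for tok in vocab
        }
    return tables


corpus = [
    ("the cat sat", "E"),
    ("the feline reposed", "A"),
    ("the cat rested", "I"),
]
tables = build_score_tables(corpus)
for label, table in tables.items():
    print(f"scores_{label}.json", json.dumps(table, sort_keys=True)[:60], "...")
```

A real implementation would tokenize with the tokenizer passed via --tokenizer and key scores by token id, so they line up with the model's logit vector.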

2. Steered generation

python logit_steering.py \
    --dataset ose \
    --experiment rewriting \
    --target E \
    --score_dir lookup/llama3.1-8b/ose \
    --data_path dataset/ose.json \
    --output_dir outputs/ose \
    --max_new_tokens 800 \
    --num_samples 100

Key arguments:

| Argument | Description |
|---|---|
| --dataset | ose, wikipol, real_tox |
| --experiment | rewriting or open_ended |
| --target | OSE: E/I/A · WikiPol: I/P/N · RealTox: T/NT |
| --score_dir | Directory with score files from Step 1 |
| --num_samples | Samples to process (-1 = all) |
| --seed | Random seed for reproducibility |
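To make the per-step nature of the steering concrete, here is a stdlib-only sketch of autoregressive sampling in which the same bias is added before every draw and a fixed seed makes runs repeatable (the purpose of --seed). The toy `next_logits` function stands in for Llama-3.1-8B, and `lam` is a hypothetical strength parameter, not one of the repository's flags.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def generate(next_logits, bias, prompt, max_new_tokens, lam=1.0, seed=0):
    """Autoregressive sampling with the same bias added at every step.

    next_logits(tokens) -> list of vocab-sized logits (stands in for the LM).
    """
    rng = random.Random(seed)  # seeded so repeated runs give identical samples
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        logits = [x + lam * b for x, b in zip(next_logits(tokens), bias)]
        probs = softmax(logits)
        tokens.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return tokens


def toy_lm(tokens):
    """Hypothetical 3-token 'LM' that always prefers token 0."""
    return [2.0, 0.0, 0.0]


bias = [0.0, 3.0, 0.0]  # steer toward token 1
print(generate(toy_lm, bias, prompt=[0], max_new_tokens=10, seed=42))
```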

3. Evaluation (GPT-4o judge)

export OPENAI_API_KEY="..."

python judge_triplet.py \
    --input outputs/ose/ose_rewriting_E.json \
    --output eval_results/ose_rewriting_E_judged.json \
    --prompt judge_system_prompts/ose_sys_prompt.txt \
    --use-triplet \
    --max-triplets 500

Two judging modes:

  • Triplet (--use-triplet, default for OSE) — judges E/I/A texts of the same subject together to determine relative difficulty.
  • Individual (--no-triplet, default for WikiPol/RealTox) — judges each text independently.
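Triplet mode needs the three OSE rewrites of each source text side by side. The grouping sketch below is illustrative: the record keys ("id", "target", "text") and the ordering check are assumptions about the data format, not the schema actually read by judge_triplet.py. The E < I < A ordering follows OneStopEnglish's Elementary, Intermediate, and Advanced levels.

```python
from collections import defaultdict

LEVELS = ("E", "I", "A")  # OSE levels, easiest to hardest


def build_triplets(records, max_triplets=None):
    """Group outputs that share a source id into complete E/I/A triplets.

    records: iterable of dicts with hypothetical keys "id", "target", "text".
    Incomplete groups are dropped; max_triplets caps the result (None = all).
    """
    by_id = defaultdict(dict)
    for r in records:
        by_id[r["id"]][r["target"]] = r["text"]
    triplets = [
        (sid, tuple(group[lvl] for lvl in LEVELS))
        for sid, group in by_id.items()
        if all(lvl in group for lvl in LEVELS)
    ]
    return triplets[:max_triplets]  # slicing with None keeps everything


def ordering_correct(predicted_order):
    """True if a judge's easiest-to-hardest ranking matches E < I < A."""
    return tuple(predicted_order) == LEVELS
```

Capping after grouping mirrors the intent of --max-triplets: only complete triplets count toward the limit.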

Repository Layout

swai/
├── build_scores.py     # Step 1: corpus → per-token score table
├── logit_steering.py   # Step 2: steered AR generation (LogitsProcessor)
├── judge_triplet.py    # Step 3: GPT-4o evaluation (individual / triplet)
└── utils.py            # Score loading, AttributeSteeringProcessor, seeding
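utils.py handles score loading. The sketch below shows one plausible way a per-class score file could become the vocabulary-sized bias vector added to the logits. The JSON layout (token id mapped to score) and treating score_clip as a symmetric clamp are assumptions; check utils.py and logit_steering.py for the actual behavior.

```python
import json


def scores_to_bias(score_json, vocab_size, score_clip=None):
    """Turn a {token_id: score} JSON object into a dense vocabulary-sized bias list.

    Token ids missing from the table get 0.0 (no steering). If score_clip is set,
    each score is clamped to [-score_clip, score_clip] (an assumed meaning).
    """
    table = json.loads(score_json)
    bias = [0.0] * vocab_size
    for tok_id, score in table.items():
        idx = int(tok_id)  # JSON object keys are always strings
        if not 0 <= idx < vocab_size:
            continue  # ignore ids outside the model's vocabulary
        if score_clip is not None:
            score = max(-score_clip, min(score_clip, score))
        bias[idx] = float(score)
    return bias


example = '{"0": 0.5, "3": 4.2, "5": -7.0, "99": 1.0}'
print(scores_to_bias(example, vocab_size=6, score_clip=2.0))
# → [0.5, 0.0, 0.0, 2.0, 0.0, -2.0]
```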

Supported Datasets

| Dataset | Task | Targets |
|---|---|---|
| OSE (OneStopEnglish) | Difficulty-controlled rewriting / continuation | E, I, A (Elementary, Intermediate, Advanced) |
| WikiPol | Politeness-controlled paraphrasing | I, P, N |
| RealToxicityPrompts | Detoxified paraphrasing | T, NT |
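The dataset/target pairs in this table can be captured as a small lookup for checking --dataset and --target before starting a long generation run. This helper is hypothetical and not part of the repository.

```python
# Valid --target values per --dataset, taken from the tables in this README.
VALID_TARGETS = {
    "ose": ("E", "I", "A"),
    "wikipol": ("I", "P", "N"),
    "real_tox": ("T", "NT"),
}


def check_target(dataset, target):
    """Return target unchanged, or raise ValueError for an invalid combination."""
    if dataset not in VALID_TARGETS:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {sorted(VALID_TARGETS)}"
        )
    if target not in VALID_TARGETS[dataset]:
        raise ValueError(
            f"target {target!r} is not valid for {dataset!r}; "
            f"expected one of {VALID_TARGETS[dataset]}"
        )
    return target


print(check_target("ose", "E"))
```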

Notes

  • Default backbone: meta-llama/Llama-3.1-8B (to use a different model, change the model name in logit_steering.py).
  • Steering hyperparameters (K, rho, delta, lambda_max, score_clip) are set in code — see logit_steering.py.
  • The evaluation step incurs OpenAI API usage costs.

License

See LICENSE.
