Official PyTorch implementation for the paper AutoJudge: Judge Decoding Without Manual Annotation
Our approach introduces an algorithm that automatically identifies important token mismatches between the draft and target models' generations. We extract hidden states for these tokens, train a lightweight classifier to detect them, and use it during inference.
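To make the labeling idea concrete, here is a toy sketch of how a single draft/target mismatch could be labeled. It assumes, as a simplification, that a mismatch counts as important when keeping the draft token (and letting generation continue from there) leads to an incorrect final answer. The names `is_important_mismatch`, `toy_continue` and `toy_checker` are hypothetical stand-ins, not this repository's API; the actual mining logic lives in src/find_important_tokens.py.

```python
from typing import Callable, List, Sequence

Token = str


def is_important_mismatch(
    prefix: Sequence[Token],
    draft_token: Token,
    continue_generation: Callable[[List[Token]], List[Token]],
    answer_is_correct: Callable[[List[Token]], bool],
) -> bool:
    """Toy labeling rule (assumption): a mismatch is important if keeping the
    draft token instead of the target token leads to a wrong final answer."""
    # Force the draft token at the mismatch position, then let generation finish.
    completion = continue_generation(list(prefix) + [draft_token])
    return not answer_is_correct(completion)


# Hypothetical stand-ins: a real run would continue with the target model
# and parse the final answer from its output.
def toy_continue(tokens: List[Token]) -> List[Token]:
    return tokens + ["."]


def toy_checker(tokens: List[Token]) -> bool:
    return tokens[-2] in {"4", "four"}  # the correct answer is 4


prefix = ["2", "+", "2", "is"]  # the target model would emit "4" next
print(is_important_mismatch(prefix, "four", toy_continue, toy_checker))  # False
print(is_important_mismatch(prefix, "5", toy_continue, toy_checker))  # True
```

A different spelling of the same answer is a harmless mismatch, while a different number is an important one; only the latter should force the target token.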
To reproduce our results, follow these steps:
- Run the dataset mining script
- Calculate hidden states
- Train the classifier
- Run evaluations
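The steps are chained through files under output/. As a convenience, here is a small hypothetical helper that reports which step to run next. Since later artifacts may come from a download rather than from running the earlier steps, it looks for the latest artifact that exists. The paths are the ones used in the example commands of this README; adjust them if your files live elsewhere.

```python
from pathlib import Path

# Output of each stage, as used by the example commands in this README
# (hypothetical helper; adjust the paths if you change them).
PIPELINE = [
    ("dataset mining", Path("output/important_tokens.pt")),
    ("hidden states", Path("output/important_tokens_with_hiddens.pt")),
    ("classifier training", Path("output/trained_head.pkl")),
]


def next_stage(root: Path = Path(".")) -> str:
    """Name of the stage to run next: the one after the latest existing artifact."""
    names = [name for name, _ in PIPELINE] + ["evaluation"]
    for i in range(len(PIPELINE) - 1, -1, -1):
        if (root / PIPELINE[i][1]).exists():
            return names[i + 1]
    return names[0]


if __name__ == "__main__":
    print(f"next stage: {next_stage()}")
```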
🤗 Mined datasets available: preprocessed GSM8K and LiveCodeBench artifacts for the most compute-intensive stages (dataset mining and hidden-state calculation) are available on Hugging Face at mightyneighbor/Autojudge, so you can skip the first two steps for these setups!
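A minimal sketch for fetching these artifacts with the `huggingface_hub` client. It assumes the files are hosted as a dataset repository (`repo_type="dataset"`); if they are in a model repository, drop that argument. The function name and the local folder name are illustrative choices, not part of this repository.

```python
def download_autojudge_artifacts(local_dir: str = "autojudge_artifacts") -> str:
    """Download the pre-mined AutoJudge artifacts and return the local folder path."""
    # Lazy import so this helper can be defined without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id="mightyneighbor/Autojudge",
        repo_type="dataset",  # assumption: hosted as a dataset repo
        local_dir=local_dir,
    )
```

After calling `download_autojudge_artifacts()`, point `DATA_FILE` or `--data_path` at the downloaded file that matches the stage you are skipping to.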
Install packages from requirements.txt:
```bash
pip install -r requirements.txt
```
Here we provide a small snippet showing how to run dataset mining for GSM8K and LiveCodeBench. For detailed instructions, including multi-GPU runs, please refer to the find_important_tokens_gsm8k.sh and find_important_tokens_lcb.sh scripts.
For GSM8K:
```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"  # draft model, same pair as in the later steps
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"  # target model
export TORCH_DTYPE=auto
export GSM8K_TRAIN=data/train_small.jsonl # replace with data/train.jsonl for a full run
export RANDOM_SEED=42
export MAX_NEW_TOKENS=2048
export OUTPUT_FOLDER=output
export OUTPUT_FILE=important_tokens
export DUMP_FREQ=64
mkdir -p $OUTPUT_FOLDER
# one-GPU run
CUDA_VISIBLE_DEVICES=0 python3 src/find_important_tokens.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --gsm8k_train_path $GSM8K_TRAIN \
    --random_seed $RANDOM_SEED \
    --max_new_tokens $MAX_NEW_TOKENS \
    --output_folder $OUTPUT_FOLDER \
    --output_file $OUTPUT_FILE \
    --dump_freq $DUMP_FREQ
rm output/done*
```
For LiveCodeBench:
```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"  # draft model
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"  # target model
export TORCH_DTYPE=auto
export GSM8K_TRAIN=data/train_small.jsonl # replace with data/train.jsonl for a full run
export RANDOM_SEED=42
export MAX_NEW_TOKENS=2048
export OUTPUT_FOLDER=output
export OUTPUT_FILE=important_tokens_lcb
export DUMP_FREQ=64
export NUM_PROCESS_EVALUATE=64
export N_TASKS=2 # 2 tasks for a short demo; set to 880 for the full LCB release_v5 dataset
export TOTAL_GPUS=1
mkdir -p $OUTPUT_FOLDER
# one-GPU run
CUDA_VISIBLE_DEVICES=0 python3 src/find_important_tokens_lcb.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --random_seed $RANDOM_SEED \
    --max_new_tokens $MAX_NEW_TOKENS \
    --output_folder $OUTPUT_FOLDER \
    --output_file $OUTPUT_FILE \
    --dump_freq $DUMP_FREQ \
    --n_tasks $N_TASKS \
    --num_process_evaluate $NUM_PROCESS_EVALUATE \
    --total_gpus $TOTAL_GPUS
```
🧮 Calculating hidden states ⚙️
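For intuition about what the classifier sees: the vLLM conversion script later in this README keeps only the last `target_hidden_size` entries of the scaler and weights, which suggests that each training example concatenates draft-model features with target-model features, target last. The NumPy sketch below builds such rows from per-token hidden states at the mined positions. The layout, the function name `build_features` and the toy arrays are assumptions for illustration, not the code in src/calc_hiddens.py.

```python
import numpy as np


def build_features(draft_hidden: np.ndarray, target_hidden: np.ndarray, positions) -> np.ndarray:
    """Assumed layout: [draft hidden state ; target hidden state] per mined token.

    draft_hidden:  [seq_len, draft_dim]  hidden states of the draft model
    target_hidden: [seq_len, target_dim] hidden states of the target model
    positions:     indices of the mined mismatch tokens
    """
    positions = np.asarray(positions, dtype=np.int64)
    return np.concatenate([draft_hidden[positions], target_hidden[positions]], axis=1)


# Toy sizes: Llama-3.2-1B has 2048-dim hidden states, Llama-3.1-8B has 4096.
rng = np.random.default_rng(0)
draft = rng.standard_normal((10, 2048)).astype(np.float32)
target = rng.standard_normal((10, 4096)).astype(np.float32)
feats = build_features(draft, target, [3, 7])
print(feats.shape)  # (2, 6144)
```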
Below is a single-GPU example; for the full script, including multi-GPU runs, please refer to the calc_hiddens.sh script.
```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export TORCH_DTYPE=auto
export BATCH_SIZE=8
export DATA_FILE=output/important_tokens.pt
export OUTPUT_PATH=output/important_tokens_with_hiddens
export SAVE_FREQ=128
export N_PROCESSES=1
# single-GPU run
CUDA_VISIBLE_DEVICES=0 python src/calc_hiddens.py \
    --draft_model $MODEL0 \
    --target_model $MODEL1 \
    --torch_dtype $TORCH_DTYPE \
    --batch_size $BATCH_SIZE \
    --data_file $DATA_FILE \
    --output_path $OUTPUT_PATH \
    --save_freq $SAVE_FREQ \
    --n_processes $N_PROCESSES \
    --process_id 0
```
Then train the classifier on the extracted hidden states:
```bash
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
python src/train_head_gsm8k_nirvana.py \
    --random_seed 52 \
    --train_size 0.9 \
    --data_path output/important_tokens_with_hiddens.pt \
    --checkpoint_path output/trained_head.pkl \
    --target_model $MODEL1 \
    --draft_model $MODEL0 \
    --setup DD-DT \
    --train_on_all_data
# --convert_to_vllm: add this option to get a head that can be used with our vLLM patch
```
To convert a trained head to our vLLM format, you can run the following Python script:
```python
import pickle

import pandas as pd

# Load the checkpoint written by the training script (--checkpoint_path above).
checkpoint = pd.read_pickle('output/trained_head.pkl')
head = checkpoint['model']     # fitted linear classifier (coef_, intercept_)
scaler = checkpoint['scaler']  # fitted feature scaler (mean_, scale_)

target_hidden_size = 4096  # 4096 for Llama-3.1-8B-Instruct and 8192 for Llama-3.1-70B-Instruct

# Keep only the parameters that act on the last target_hidden_size features.
head_dict = dict(
    mean=scaler.mean_[-target_hidden_size:],
    scale=scaler.scale_[-target_hidden_size:],
    weights=head.coef_[0][-target_hidden_size:],
    bias=head.intercept_,  # shape (1,) for a binary classifier
    thr=0.25,  # decision threshold stored with the head
)

vllm_checkpoint_path = 'vllm_compatible_head.pkl'
with open(vllm_checkpoint_path, 'wb') as f:
    pickle.dump(head_dict, f)
```
More scripts to be uploaded later.
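For reference, this is how a head stored in the vLLM format above could score one target hidden state: standardize it with `mean` and `scale`, apply `weights` and `bias`, and pass the logit through a sigmoid. If the head is a scikit-learn `LogisticRegression`, as the `coef_` and `intercept_` attributes suggest, this matches its `predict_proba` on the kept features. Treating a score at or above `thr` as an important mismatch is our assumption about how the vLLM patch uses the threshold; `judge_score` and `is_important` are hypothetical names.

```python
import numpy as np


def judge_score(head: dict, target_hidden: np.ndarray) -> float:
    """Probability-like score from a head saved in the vLLM format above."""
    z = (np.asarray(target_hidden) - head["mean"]) / head["scale"]  # standardize
    logit = float(np.dot(head["weights"], z) + np.ravel(head["bias"])[0])
    return float(1.0 / (1.0 + np.exp(-logit)))  # logistic sigmoid


def is_important(head: dict, target_hidden: np.ndarray) -> bool:
    # Assumption: a score at or above `thr` flags an important mismatch.
    return bool(judge_score(head, target_hidden) >= head["thr"])


# Toy 4-dimensional head in the same format (real heads use target_hidden_size).
toy_head = dict(
    mean=np.zeros(4),
    scale=np.ones(4),
    weights=np.array([1.0, -1.0, 0.5, 0.0]),
    bias=np.array([0.0]),
    thr=0.25,
)
print(judge_score(toy_head, np.array([1.0, 1.0, 0.0, 3.0])))  # 0.5
print(is_important(toy_head, np.array([-2.0, 0.0, 0.0, 0.0])))  # False
```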
Here we provide an evaluation example for GSM8K; similar scripts were used to obtain the main results on LiveCodeBench. To run the LiveCodeBench evaluation, please refer to eval/run_lcb_folds.py and eval/run_lcb_topk.py. There, for each threshold (ours) and K (baseline) value, we also vary FOLD_ID, since we use an out-of-fold technique.
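The sketch below illustrates, under stated assumptions, two ideas from the paragraph above: how out-of-fold evaluation keeps each task out of the training data of the head that judges it, and the rule we assume the top-K baseline follows (keep a mismatching draft token if it falls inside the target model's top-K). `n_folds=5` mirrors the five folds (0 to 4) in the LCB commands further below. The function names and the random fold assignment are illustrative; the real logic lives in eval/run_lcb_folds.py, eval/run_lcb_topk.py and eval/run_inference_topk_baseline_task.py.

```python
import numpy as np

# K values swept for the top-K baseline: powers of two up to 65536, plus 128256.
K_VALUES = [2**i for i in range(17)] + [128256]


def out_of_fold_assignment(n_items: int, n_folds: int = 5, seed: int = 0) -> np.ndarray:
    """Randomly assign each evaluation item to one of n_folds roughly equal folds."""
    folds = np.arange(n_items) % n_folds
    np.random.default_rng(seed).shuffle(folds)
    return folds


def train_eval_split(folds: np.ndarray, fold_id: int):
    """The head used for fold_id is trained on the other folds and evaluated on fold_id."""
    return np.flatnonzero(folds != fold_id), np.flatnonzero(folds == fold_id)


def topk_accept(target_logits: np.ndarray, draft_token: int, k: int) -> bool:
    """Assumed baseline rule: keep a mismatching draft token if it is among the
    target model's k highest-scoring tokens."""
    k = min(k, target_logits.shape[-1])
    top_k = np.argpartition(target_logits, -k)[-k:]
    return bool(np.any(top_k == draft_token))


logits = np.array([0.1, 2.0, 1.5, -1.0])  # toy 4-token vocabulary
print(topk_accept(logits, 2, 1), topk_accept(logits, 2, 2))  # False True
train_idx, eval_idx = train_eval_split(out_of_fold_assignment(10), fold_id=0)
print(len(train_idx), len(eval_idx))  # 8 2
```

Under this assumed rule, K = 128256 (the size of the Llama 3 vocabulary) accepts every mismatching draft token, so the largest K in the sweep corresponds to always trusting the draft.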
AutoJudge evaluation:
```bash
export START_IDX=0
export END_IDX=10 # set to 1319 for the full GSM8K test set
export THR_ID=0 # vary from 0 to 25; thresholds for inference are selected automatically by the training scripts
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export NUM_SHOTS=0 # set to 8 for the 8-shot setup
export MAX_NEW_TOKENS=1024
# Running eval on 2 GPUs
CUDA_VISIBLE_DEVICES=0,1 python3 eval/gpu_parallel.py \
    --gpus_per_script 1 \
    --start $START_IDX \
    --end $END_IDX \
    --use_queue \
    --script eval/run_inference_task.py \
    --extra_args "--save_folder output/eval_$THR_ID --gsm8k_test_path data/gsm8k_test.json --torch_dtype auto --window_size 64 --head_path output/trained_head.pkl --setup DD-DT --max_new_tokens $MAX_NEW_TOKENS --num_shots $NUM_SHOTS --head_threshold_idx $THR_ID --draft_model $MODEL0 --target_model $MODEL1"
```
Top-K baseline evaluation:
```bash
export START_IDX=0
export END_IDX=10 # set to 1319 for the full GSM8K test set
export K=2048 # to be varied; we considered [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 128256]
export MODEL0="meta-llama/Llama-3.2-1B-Instruct"
export MODEL1="meta-llama/Llama-3.1-8B-Instruct"
export NUM_SHOTS=0 # set to 8 for the 8-shot setup
export MAX_NEW_TOKENS=1024
# Running eval on 2 GPUs
CUDA_VISIBLE_DEVICES=0,1 python3 eval/gpu_parallel.py \
    --gpus_per_script 1 \
    --start $START_IDX \
    --end $END_IDX \
    --use_queue \
    --script eval/run_inference_topk_baseline_task.py \
    --extra_args "--save_folder output/eval_baseline_$K --gsm8k_test_path data/gsm8k_test.json --torch_dtype auto --window_size 64 --head_path output/trained_head.pkl --setup DD-DT --max_new_tokens $MAX_NEW_TOKENS --num_shots $NUM_SHOTS --K $K --draft_model $MODEL0 --target_model $MODEL1"
```
To make a final report based on the evaluation outputs, you can use the following snippet:
```python
import os

import numpy as np
import pandas as pd


def make_pareto_curve_df(data, group_by_col='thr'):
    """Per setting: mean of all raw_accepts values and mean GSM8K accuracy ('tp')."""
    df = pd.DataFrame(data)
    grouped = df.groupby(group_by_col)
    mean_accept = (
        grouped['raw_accepts']
        .apply(lambda s: np.concatenate(s.tolist()).mean())
        .rename('mean_accept')
        .reset_index()
    )
    gsm_acc = grouped['tp'].mean().rename('gsm8k_acc').reset_index()
    return pd.merge(mean_accept, gsm_acc, on=group_by_col).sort_values(by=[group_by_col])


def load_results(dirs):
    """Read every per-sample result pickle from the given eval folders."""
    data = []
    for d in dirs:
        data.extend(pd.read_pickle(os.path.join(d, f)) for f in os.listdir(d))
    return data


AJ_DIRS = ['output/eval_0', 'output/eval_1']  # add output/eval_2, ..., output/eval_25 for all thresholds
autojudge_df = make_pareto_curve_df(load_results(AJ_DIRS))
print(autojudge_df)

# Baseline results are saved to output/eval_baseline_$K, i.e. one folder per K value
BASELINE_DIRS = ['output/eval_baseline_1', 'output/eval_baseline_2']  # add output/eval_baseline_4, ..., output/eval_baseline_128256
baseline_df = make_pareto_curve_df(load_results(BASELINE_DIRS), group_by_col='k')
print(baseline_df)
```
To run AutoJudge with vLLM, clone the vllm repository and check out commit a83a0f92b56b71855dc38e8e3d9809619e58bcd1.
Copy our patch file to the vllm repo and apply it: git apply vllm_patch.patch.
Install vllm using the precompiled binaries: VLLM_USE_PRECOMPILED=1 pip install -e path/to/vllm/folder.
Run the following commands to evaluate on the GSM8K dataset:
```bash
# Llama-3.1-8B target, Llama-3.2-1B draft
python vllm_gsm8k.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-8B-Instruct \
    --draft_model meta-llama/Llama-3.2-1B-Instruct --judge_path vllm_heads/head_${SHOTS}shot_8b.pkl --judge_threshold $THRESHOLD --shots $SHOTS

# Llama-3.1-70B target, Llama-3.1-8B draft
python vllm_gsm8k.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-70B-Instruct \
    --draft_model meta-llama/Llama-3.1-8B-Instruct --judge_path vllm_heads/head_${SHOTS}shot_70b.pkl --judge_threshold $THRESHOLD --shots $SHOTS
```
SHOTS can be either 0 or 8.
For example, to reproduce the 0-shot results for the 70B/8B pair, run:
```bash
for threshold in 0.03719609313336198 0.07084856680433153 0.09208237305325259 0.13549077699786996 0.2209569576527778; do
    python vllm_gsm8k.py -o results \
        --target_model meta-llama/Llama-3.1-70B-Instruct \
        --draft_model meta-llama/Llama-3.1-8B-Instruct \
        --judge_path vllm_heads/head_0shot_70b.pkl --judge_threshold $threshold --shots 0
done
```
Run the following commands to evaluate on LiveCodeBench (LCB_MAX_MEMORY_BYTES=34359738368 corresponds to 32 GiB):
```bash
# Llama-3.1-8B target, Llama-3.2-1B draft
LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-8B-Instruct \
    --draft_model meta-llama/Llama-3.2-1B-Instruct --judge_path head_lcb_8b.pkl --judge_threshold $THRESHOLD

# Llama-3.1-70B target, Llama-3.1-8B draft
LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-70B-Instruct \
    --draft_model meta-llama/Llama-3.1-8B-Instruct --judge_path head_lcb_70b.pkl --judge_threshold $THRESHOLD
```
If you receive an error while running the multi-GPU 70B/8B evaluation, you can mitigate it by running one fold at a time:
```bash
LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o $RESULTS_FILE --target_model meta-llama/Llama-3.1-70B-Instruct \
    --draft_model meta-llama/Llama-3.1-8B-Instruct --judge_path head_lcb_70b.pkl --judge_threshold $THRESHOLD --fold $fold
```
For example, to reproduce the 70B/8B evaluation on the LCB dataset, run:
```bash
for threshold in 0.0 0.05 0.075 0.1 0.125; do
    for fold in 0 1 2 3 4; do
        LCB_MAX_MEMORY_BYTES=34359738368 python vllm_lcb.py -o results --target_model meta-llama/Llama-3.1-70B-Instruct \
            --draft_model meta-llama/Llama-3.1-8B-Instruct \
            --judge_path head_lcb_70b.pkl --judge_threshold $threshold --fold $fold
    done
done
```
More scripts to be uploaded later.
If you found this work useful, please consider citing:
```bibtex
@misc{garipov2025autojudgejudgedecodingmanual,
      title={AutoJudge: Judge Decoding Without Manual Annotation},
      author={Roman Garipov and Fedor Velikonivtsev and Ruslan Svirschevski and Vage Egiazarian and Max Ryabinin},
      year={2025},
      eprint={2504.20039},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2504.20039},
}
```