Official codebase for the paper "GRAIL: Gradient-Reweighted Advantages for Reinforcement Learning with Verifiable Rewards", authored by Tej Deep Pala, Vernon Toh, and Soujanya Poria at the DeCLaRe Lab, Nanyang Technological University.
Reinforcement learning with verifiable rewards (e.g., GRPO) is a standard paradigm to enhance mathematical and logical reasoning in Large Language Models (LLMs). However, standard GRPO broadcasts a uniform sequence-level advantage scalar to all tokens in a rollout. This uniform credit assignment dilutes the gradient signal: flawed reasoning steps, intermediate derivation errors, and filler words receive the same advantage signal as pivotal, logically sound steps.
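The uniform broadcast described above can be illustrated with a minimal sketch. It assumes the usual group-relative normalization (reward minus group mean, divided by group standard deviation); the function name `grpo_token_advantages`, the use of the population standard deviation, and the `eps` constant are illustrative choices, not taken from this repository.

```python
from statistics import mean, pstdev


def grpo_token_advantages(rewards, lengths, eps=1e-6):
    """Broadcast one group-normalized advantage to every token of each rollout.

    rewards: one scalar outcome reward per rollout in the group.
    lengths: number of generated tokens in each rollout.
    Hypothetical helper for illustration only.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population std over the group
    seq_adv = [(r - mu) / (sigma + eps) for r in rewards]
    # Every token in a rollout receives the same sequence-level scalar.
    return [[a] * n for a, n in zip(seq_adv, lengths)]


rewards = [1.0, 0.0, 0.0, 1.0]   # toy verifiable rewards (correct / wrong)
lengths = [3, 5, 2, 4]           # toy rollout lengths in tokens
token_adv = grpo_token_advantages(rewards, lengths)
for adv in token_adv:
    print(adv)
```

Within each printed row all entries are identical, which is the credit-assignment limitation the paragraph above points out.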
To address this, we introduce Gradient-Reweighted Advantage (GRAIL), an intrinsic, token-wise advantage reweighting method. GRAIL uses gradient-activation saliency to assign a larger advantage signal to tokens to which the final answer is locally sensitive.
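As a rough illustration of token-wise reweighting, the sketch below turns per-token saliency scores into weights and scales a sequence-level advantage with them. The mapping (standardize the saliency, rescale around a neutral weight, then clip) is an assumption inferred from the names of the `grail_mean`, `grail_std`, `grail_w_min`, and `grail_w_max` options in the training config; the actual logic lives in `grail_trainer.py` and may differ. The saliency values are taken as given here, whereas in GRAIL they would come from gradient-activation products.

```python
from statistics import mean, pstdev


def saliency_to_weights(saliency, w_mean=1.0, w_std=0.5,
                        w_min=0.5, w_max=5.0, eps=1e-6):
    """Map per-token saliency scores to clipped advantage weights.

    Assumed mapping (not confirmed by the source):
    w_t = clip(w_mean + w_std * zscore(s_t), w_min, w_max)
    """
    mu = mean(saliency)
    sigma = pstdev(saliency)
    weights = []
    for s in saliency:
        z = (s - mu) / (sigma + eps)      # standardized saliency
        w = w_mean + w_std * z            # spread around the neutral weight
        weights.append(min(max(w, w_min), w_max))
    return weights


def reweight_advantage(advantage, saliency, **kwargs):
    """Scale one sequence-level advantage by per-token weights."""
    return [advantage * w for w in saliency_to_weights(saliency, **kwargs)]


saliency = [0.1, 0.1, 2.0, 0.1]           # toy scores; third token is salient
token_adv = reweight_advantage(1.0, saliency)
print(token_adv)
```

With uniform saliency every weight equals the neutral value, so this sketch reduces to the standard uniform broadcast in that case.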
Across five models (Qwen3-4B/8B, OctoThinker-3B/8B-Short, R1-Distill-Llama-8B) evaluated on six mathematical reasoning benchmarks (MATH500, AIME24, AMC23, MinervaMATH, CollegeMATH, OlympiadBench), GRAIL consistently outperforms standard GRPO, with an average absolute accuracy gain of 3.60% and a Pass@3 improvement of 3.05%, without requiring external process-level supervision such as process reward models (PRMs).
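For readers unfamiliar with Pass@k, the sketch below shows the commonly used unbiased estimator, 1 - C(n-c, k) / C(n, k), for n samples of which c are correct. This README does not state how Pass@3 is computed in `run_eval.py` or how many samples are drawn per problem, so treat this as background rather than a description of the evaluation code.

```python
from math import comb


def pass_at_k(n, c, k):
    """Unbiased Pass@k estimate from n samples with c correct ones.

    Returns the probability that at least one of k samples drawn
    without replacement from the n generations is correct.
    """
    if n - c < k:
        # Fewer than k wrong samples: any draw of k contains a correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


# Toy numbers for illustration, not results from the paper.
print(pass_at_k(n=4, c=1, k=3))
print(pass_at_k(n=3, c=0, k=3))
```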
grail/
├── requirements.txt                        # System dependencies
├── README.md                               # Project documentation
├── src/
│   ├── train/
│   │   ├── grpo.py                         # Main trainer launcher
│   │   ├── grpo.sh                         # Training execution bash script
│   │   ├── grail_trainer.py                # Customized GrailTrainer implementing loss logic
│   │   ├── rewards.py                      # Math outcome-based reward function
│   │   ├── training_configs/               # Directory for YAML configuration files
│   │   │   ├── deepspeed_zero2.yaml
│   │   │   ├── qwen3_grpo.yaml
│   │   │   ├── qwen3_grail.yaml
│   │   │   └── qwen3_oar_g.yaml
│   │   └── utils/                          # Arguments, logs, and datasets utilities
│   └── eval/
│       ├── run_eval.py                     # Main benchmarking execution script
│       ├── run_eval.sh                     # Serving & evaluation pipeline orchestration
│       ├── start_vllm.sh                   # Serving policy models via vLLM
│       ├── eval_results_base.py            # Combines results across datasets
│       ├── compute_checkpoint_stats.py     # Standalone token-level saliency stats
│       ├── token_analysis.sh               # Visualizes checkpoint saliency dynamics
│       ├── aggregate_and_plot.py           # Plotting script for positional analysis
│       ├── convert_to_excel.py             # Generates unified XLSX result reports
│       ├── compile_res.sh                  # Aggregates evaluation results
│       └── eval_data/                      # Benchmark dataset JSONL files (AIME24, MATH, etc.)
Set up a Python environment (Conda recommended) and install dependencies:
conda create -n grail python=3.11 -y
conda activate grail
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
Training uses DeepSpeed ZeRO-2 and the accelerate launcher.
Modify variables (e.g. WANDB_API_KEY, HF_TOKEN) inside grpo.sh and launch:
cd src/train
bash grpo.sh
By default, training uses the config qwen3_grail.yaml, which includes the following key parameters:
# GRAIL Parameters
use_grail: true
grail_std: 0.5 # Spread scaling factor (\sigma_w)
grail_mean: 1.0 # Baseline neutral weight (w_mean)
grail_w_min: 0.5 # Lower bound on the weight (w_min)
grail_w_max: 5.0 # Upper bound on the weight (w_max)
grail_leaf_source: "embeddings" # Saliency leaf source
grail_rollout_symmetry: "wrong" # Apply reweighting on: "all", "correct", or "wrong"
Evaluation covers six reasoning benchmarks (aime24, math, college_math, minerva_math, olympiadbench, amc23) and uses vLLM for high-throughput serving.
Set your model path inside start_vllm.sh and run:
cd src/eval
bash start_vllm.sh
Run evaluations by targeting the served port (configured in run_eval.sh):
cd src/eval
bash run_eval.sh
Compile JSONL evaluation results into a single Excel sheet using compile_res.sh:
bash compile_res.sh
To evaluate how token-level saliency weights change over training checkpoints, execute the token analysis suite:
cd src/eval
bash token_analysis.sh
This pipeline:
- Runs compute_checkpoint_stats.py to extract token gradients and weights post-hoc.
- Runs aggregate_and_plot.py to generate plots showing the U-shape distribution of weights across normalized reasoning spans.
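The positional aggregation behind such a plot can be sketched as follows: each token's weight is assigned to a bin by its relative position in the rollout, and weights are averaged per bin. The function name `bin_weights_by_position`, the binning scheme, and the toy weights are hypothetical; `aggregate_and_plot.py` may normalize spans differently.

```python
def bin_weights_by_position(sequences, num_bins=5):
    """Average per-token weights within bins of normalized position.

    sequences: list of per-token weight lists, one list per rollout.
    Position is normalized to [0, 1] so rollouts of different lengths
    can be compared. Returns one mean per bin (None for empty bins).
    """
    sums = [0.0] * num_bins
    counts = [0] * num_bins
    for weights in sequences:
        n = len(weights)
        for t, w in enumerate(weights):
            pos = t / (n - 1) if n > 1 else 0.0
            # The final token (pos == 1.0) falls into the last bin.
            b = min(int(pos * num_bins), num_bins - 1)
            sums[b] += w
            counts[b] += 1
    return [s / c if c else None for s, c in zip(sums, counts)]


# Made-up weights with larger values near the start and end of each rollout.
seqs = [
    [2.0, 1.0, 0.5, 1.0, 2.0],
    [1.8, 0.6, 0.6, 0.6, 0.6, 1.9],
]
profile = bin_weights_by_position(seqs, num_bins=3)
print(profile)
```

A U-shaped profile shows up as higher means in the first and last bins than in the middle one.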
@misc{pala2026grail,
title={GRAIL: Gradient-Reweighted Advantages for Reinforcement Learning with Verifiable Rewards},
author={Tej Deep Pala and Vernon Toh and Soujanya Poria},
year={2026},
eprint={2606.04889},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2606.04889},
}