SR2: Selection, Reflection and Self-Refinement

SR2 Framework Overview

arXiv PyTorch Python 3.12 CUDA 12.8

📰 News / Updates

  • [2026.01] 🎉 Congratulations! Our work has been accepted to ICLR 2026 as a poster.

📘 Overview

This repository provides the official implementation of the paper:

Selection, Reflection and Self-Refinement: Revisit Reasoning Tasks via a Causal Lens
https://arxiv.org/abs/2510.08222

We introduce a framework called SR² that feeds estimated latent variables back into a selection mechanism, enabling the model to learn dense dependencies among latent representations. The framework consists of the three key modules below.

  1. Reflective Representation Learning: Learns latent variables that capture structured reasoning signals.
  2. Dependency Self-Refinement: Iteratively refines latent dependencies using the selection feedback.
  3. Periodic Intermediate Alignment: Aligns intermediate representations with causal structure to stabilize training.

Experimentally, SR² delivers substantial gains in reasoning accuracy. For example, on Sudoku and Maze tasks, SR² achieves over 10% improvement in performance while using 8× fewer parameters compared with recent strong baselines.

Note

This project is built on top of the HRM repository (sapientinc/HRM). We reuse its attention layer designs, optimizer, embeddings, and most of the hyperparameters, while replacing the model architecture and training procedure with the SR² framework.

🧩 Reasoning Task Illustration

Sudoku Reasoning and Selection Mechanism

Illustration of reasoning tasks and the selection mechanism using Sudoku as an example.

  • (a) A sample $9 \times 9$ Sudoku puzzle with a subset of given clues; the goal is to fill the remaining cells so that each row, column, and $3 \times 3$ subgrid contains the digits $1$–$9$ exactly once.
  • (b) A single unfilled cell $Y_{ij}$ with its row (purple), column (blue), and $3 \times 3$ block (orange) highlighted. The digits within these groups impose constraints that determine the admissible values for $Y_{ij}$.
  • (c) Selection mechanism: a candidate value $Y$ is valid if and only if the validity criteria are satisfied:
    $S^i_{Row} = S^j_{Col} = S^b_{Block} = 1$
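
As a rough illustration of panel (c), the three selection indicators can be computed directly from a grid. This is a minimal sketch and not code from this repository: the names `selection_indicators`, `is_valid`, and `EXAMPLE_GRID`, the 0-for-empty encoding, and the example puzzle are all assumptions made for illustration. Each indicator is read here as "the candidate does not repeat a digit already present in that group".

```python
# Hypothetical helper names; 0 marks an empty cell (an assumption of this sketch).
EXAMPLE_GRID = [
    [5, 3, 0, 0, 7, 0, 0, 0, 0],
    [6, 0, 0, 1, 9, 5, 0, 0, 0],
    [0, 9, 8, 0, 0, 0, 0, 6, 0],
    [8, 0, 0, 0, 6, 0, 0, 0, 3],
    [4, 0, 0, 8, 0, 3, 0, 0, 1],
    [7, 0, 0, 0, 2, 0, 0, 0, 6],
    [0, 6, 0, 0, 0, 0, 2, 8, 0],
    [0, 0, 0, 4, 1, 9, 0, 0, 5],
    [0, 0, 0, 0, 8, 0, 0, 7, 9],
]


def selection_indicators(grid, i, j, y):
    """Return (S_row, S_col, S_block) for placing digit y at cell (i, j)."""
    # Digits already present in row i, excluding the cell itself.
    row = [grid[i][c] for c in range(9) if c != j]
    # Digits already present in column j, excluding the cell itself.
    col = [grid[r][j] for r in range(9) if r != i]
    # Top-left corner of the 3x3 block that contains (i, j).
    r0, c0 = 3 * (i // 3), 3 * (j // 3)
    block = [
        grid[r][c]
        for r in range(r0, r0 + 3)
        for c in range(c0, c0 + 3)
        if (r, c) != (i, j)
    ]
    # Each indicator is 1 when the candidate does not clash with its group.
    return int(y not in row), int(y not in col), int(y not in block)


def is_valid(grid, i, j, y):
    """A candidate is valid only when all three indicators equal 1."""
    return selection_indicators(grid, i, j, y) == (1, 1, 1)
```

For the empty cell in row 0, column 2 of the example puzzle, `is_valid(EXAMPLE_GRID, 0, 2, 4)` holds, while the candidate 5 is rejected because it already appears in both the row and the block. Note that this only checks local consistency with the current clues; several candidates can pass for the same cell.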

🛠️ Environment Setup

Our main experiments were conducted on:

  • Sudoku / Maze: 8 × AMD MI210 (ROCm 6.2)
  • ARC-1 / ARC-2: 8 × NVIDIA H200 NVL (CUDA 12.8)

Important

For reproducing ARC-1 and ARC-2 experiments, we strongly recommend the following setup, which takes full advantage of FlashAttention-3 for efficient training:

  • CUDA ≥ 12.8
  • NVIDIA Hopper-class GPUs

Below we assume that CUDA 12.8 is properly installed and configured.

🔹 1. Clone the repository

git clone https://github.com/dengyl20/SR2.git
cd SR2

🔹 2. Create a new virtual environment

conda create -n sr2 python=3.12
conda activate sr2

🔹 3. Install PyTorch (recommended: 2.5.0 or later)

pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 \
  --index-url https://download.pytorch.org/whl/cu128

Note

Adjust the PyTorch version and index URL if your CUDA setup differs. The above command assumes CUDA 12.8 wheels are available from the official PyTorch index.
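
After installation, it can help to confirm which build was actually picked up. The snippet below is a minimal sketch, not part of this repository: the function name `describe_torch_build` is hypothetical, and it only reads attributes that PyTorch exposes (`torch.__version__`, `torch.version.cuda`, `torch.version.hip`, `torch.cuda.is_available()`, `torch.cuda.device_count()`).

```python
import importlib.util


def describe_torch_build():
    """Summarize the installed PyTorch build, or report that it is missing."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}

    import torch

    return {
        "installed": True,
        "torch": torch.__version__,
        # CUDA version the wheel was built against; None on CPU-only or ROCm builds.
        "cuda_build": torch.version.cuda,
        # HIP version; expected to be set only on ROCm builds.
        "hip_build": getattr(torch.version, "hip", None),
        "gpu_available": torch.cuda.is_available(),
        "gpu_count": torch.cuda.device_count(),
    }


if __name__ == "__main__":
    for key, value in describe_torch_build().items():
        print(f"{key}: {value}")
```

If `cuda_build` does not match the CUDA toolkit on your machine, reinstalling with a different index URL, as described in the note above, is the likely fix.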

🔹 4. Install additional dependencies

pip install -r requirements.txt

🔹 5. (Optional) Install FlashAttention-3 and adam-atan2

These components are optional but strongly recommended for faster and more stable training.

# adam-atan2
pip install packaging ninja wheel setuptools setuptools-scm
pip install --no-cache-dir --no-build-isolation adam-atan2

# Flash Attention 3
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install

Tip

If you are using ROCm or a non-Hopper GPU, you may need to skip FlashAttention-3 or use a backend compatible with your hardware.
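
One way to decide whether to attempt the FlashAttention-3 build is to check the GPU's compute capability first. The following is a heuristic sketch under stated assumptions, not logic taken from this repository: the function names are hypothetical, and the rule "compute capability 9.x on a non-ROCm build means Hopper" is a simplification. The build type is checked first because a ROCm build may also report a 9.x capability for some AMD GPUs.

```python
from typing import Optional, Tuple


def should_use_flash_attention_3(
    capability: Optional[Tuple[int, int]], is_rocm: bool
) -> bool:
    """Heuristic: only NVIDIA Hopper GPUs (compute capability 9.x) qualify."""
    if is_rocm or capability is None:
        return False
    major, _minor = capability
    return major == 9


def detect_gpu() -> Tuple[Optional[Tuple[int, int]], bool]:
    """Return (compute capability of device 0 or None, whether this is a ROCm build)."""
    try:
        import torch
    except ImportError:
        return None, False
    is_rocm = getattr(torch.version, "hip", None) is not None
    if not torch.cuda.is_available():
        return None, is_rocm
    return torch.cuda.get_device_capability(0), is_rocm


if __name__ == "__main__":
    capability, is_rocm = detect_gpu()
    print("Attempt FlashAttention-3 build:",
          should_use_flash_attention_3(capability, is_rocm))
```

Keeping the decision in a pure function makes it easy to check without a GPU, for example `should_use_flash_attention_3((8, 0), False)` for an Ampere-class device.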

🚀 Quick Start

After the environment is configured, you can quickly reproduce the Sudoku and Maze results from the paper.

Note

For ARC-1 and ARC-2, please refer to arc-agi/README.md for dedicated instructions and configuration details.

🔹 1. Build datasets

We directly reuse the data processing scripts and original data from the HRM repository.

# Initialize submodules 
git submodule update --init --recursive

# Sudoku-Extreme (1000 examples with augmentation)
python dataset/build_sudoku_dataset.py \
  --output-dir data/sudoku-extreme-1k-aug-1000 \
  --subsample-size 1000 \
  --num-aug 1000

# Maze (1000 examples)
python dataset/build_maze_dataset.py

Important

Ensure all submodules have been initialized successfully. Missing raw data or scripts will cause dataset preparation to fail.

🔹 2. Configure Weights & Biases (W&B)

We use Weights & Biases (wandb.ai) to log training curves, evaluation metrics, and experiment configurations.

Edit config/cfg_pretrain.yaml and replace the placeholder:

wandb_key: <your_wandb_key>

with your actual W&B API key.

Tip

You can obtain your W&B key from your user settings page on https://wandb.ai. Make sure you are logged in when launching experiments.
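
If you prefer not to edit the file by hand, the placeholder can be filled in with a few lines of Python. This is a sketch under assumptions rather than a script shipped with the repository: the function name `fill_wandb_key` is hypothetical, it assumes the placeholder appears literally as `<your_wandb_key>` in the file, and reading the key from the `WANDB_API_KEY` environment variable is simply a convenient choice for this example.

```python
import os
from pathlib import Path

PLACEHOLDER = "<your_wandb_key>"  # literal placeholder text shown above


def fill_wandb_key(config_path, api_key):
    """Replace the placeholder in the config file; return True if it was found."""
    path = Path(config_path)
    text = path.read_text(encoding="utf-8")
    if PLACEHOLDER not in text:
        return False
    path.write_text(text.replace(PLACEHOLDER, api_key), encoding="utf-8")
    return True


if __name__ == "__main__":
    config = Path("config/cfg_pretrain.yaml")
    key = os.environ.get("WANDB_API_KEY", "")
    if not key:
        print("WANDB_API_KEY is not set; nothing changed")
    elif not config.exists():
        print(f"{config} not found; run this from the repository root")
    else:
        changed = fill_wandb_key(config, key)
        print("updated" if changed else "placeholder not found")
```

Because the key ends up in a tracked file, take care not to commit the edited config.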

🔹 3. Run experiments

# Train on Sudoku-Extreme
bash pretrain_sudoku.sh

# Train on Maze-Hard
bash pretrain_maze.sh

These scripts will:

  • Load the model architecture.
  • Prepare the corresponding dataset splits.
  • Launch training and periodic evaluation.
  • Log all metrics (e.g., all.exact_accuracy) to W&B.

Note

Hyperparameters largely follow HRM settings. Only changes relevant to the SR² framework (architecture and training dynamics) are introduced.

📊 Evaluation

  1. Training curves and evaluation metrics

    • All training and evaluation metrics (e.g., all.exact_accuracy) can be found in the corresponding W&B runs.
    • Use W&B dashboards to inspect convergence, stability, and performance across tasks.
  2. Reporting metrics

    • For SR² and all baselines reported in the paper, we select the peak value of the evaluation curve (best all.exact_accuracy) rather than the last evaluation point.
    • This follows the evaluation protocol used in the paper to fairly compare model capacities.
  3. Reproducibility

    • We provide W&B reports and model checkpoints as reference.
    • Under comparable environment settings (hardware, CUDA/ROCm, and software versions), repeated runs should exhibit variation within approximately 1% absolute accuracy.
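
The reporting rule and the reproducibility check above can be sketched in a few lines. This is only an illustration: the function names are hypothetical, the accuracy values are made up rather than taken from the paper or from any W&B run, and how you export the `all.exact_accuracy` curve from W&B is left open.

```python
def peak_and_last(curve):
    """Return (peak, last) of an evaluation curve given as a list of accuracies."""
    if not curve:
        raise ValueError("evaluation curve is empty")
    return max(curve), curve[-1]


def run_to_run_spread(peaks):
    """Absolute gap between the best and worst peak accuracy across repeated runs."""
    if not peaks:
        raise ValueError("no runs given")
    return max(peaks) - min(peaks)


if __name__ == "__main__":
    # Made-up evaluation curve for one run (fractions, not percentages).
    curve = [0.62, 0.71, 0.78, 0.76]
    peak, last = peak_and_last(curve)
    print(f"reported (peak): {peak:.2f}, last checkpoint: {last:.2f}")

    # Made-up peak accuracies from three repeated runs.
    peaks = [0.780, 0.775, 0.784]
    print(f"spread across runs: {run_to_run_spread(peaks):.3f}")
```

A spread well above 0.01 (1% absolute) under comparable settings would suggest an environment difference of the kind described in the caution below.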

Caution

Differences in GPU architecture, CUDA/ROCm versions, or third-party library implementations (e.g., FlashAttention versions) may introduce minor deviations in the final metrics.

🙏 Acknowledgements

We gratefully acknowledge:

Our project framework is built on top of these two excellent codebases.

📚 Citation

If you find this repository useful in your research, please consider citing:

@inproceedings{deng2026selection,
  title     = {Selection, Reflection and Self-Refinement: Revisit Reasoning Tasks via a Causal Lens},
  author    = {Yunlong Deng and Boyang Sun and Yan Li and Zeyu Tang and Lingjing Kong and Kun Zhang and Guangyi Chen},
  booktitle = {The Fourteenth International Conference on Learning Representations},
  year      = {2026},
  url       = {https://openreview.net/forum?id=0X5moS8KSm}
}
