Genius
Genius: A Generalizable and Purely Unsupervised Self-Training Framework For Advanced Reasoning

[📜 Paper][🤗 HF Models][🐱 GitHub]

Repo for "Genius: A Generalizable and Purely Unsupervised Self-Training Framework For Advanced Reasoning"

🔥 News

  • [2025/05/15] 🔥🔥🔥 Genius is accepted by ACL 2025 (Main Conference)!
  • [2025/05/01] 🔥🔥🔥 The model checkpoints are released on Hugging Face! Please check them out!

🔍 Core Implementation

The core implementation of the Advantage Calibrated Optimization (ACO) loss function is presented below:

import torch
import torch.nn.functional as F
from typing import Tuple

def aco_loss(
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    chosen_weights: torch.FloatTensor,
    rejected_weights: torch.FloatTensor,
    average_weights: torch.FloatTensor,
    beta: float,
    alpha: float = 1.0,
    reference_free: bool = False
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """
    Advantage Calibrated Optimization (ACO) Loss.

    Args:
        policy_chosen_logps (FloatTensor): Log-probs from policy model for chosen responses. Shape: (batch,)
        policy_rejected_logps (FloatTensor): Log-probs from policy model for rejected responses. Shape: (batch,)
        reference_chosen_logps (FloatTensor): Log-probs from reference model for chosen responses. Shape: (batch,)
        reference_rejected_logps (FloatTensor): Log-probs from reference model for rejected responses. Shape: (batch,)
        chosen_weights (FloatTensor): Preference weights for chosen responses.
        rejected_weights (FloatTensor): Preference weights for rejected responses.
        average_weights (FloatTensor): Average preference weights over the sampled responses (not used in this snippet).
        beta (float): Temperature parameter to scale loss sharpness.
        alpha (float): Relaxation scaling factor in adaptive weighting.
        reference_free (bool): If True, ignores reference model by assuming uniform logits.

    Returns:
        Tuple of (losses, chosen_rewards, rejected_rewards), each of shape (batch,)
    """

    if reference_free:
        reference_chosen_logps = torch.zeros_like(policy_chosen_logps)
        reference_rejected_logps = torch.zeros_like(policy_rejected_logps)

    # Log-ratio between policy and reference
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    # Adaptive weighting based on preference gap
    adjustment_factor = torch.max(
        torch.tensor(1.0, device=chosen_logratios.device),
        torch.exp(-(rejected_weights - chosen_weights) / alpha)
    )
    rejected_logratios_weighted = adjustment_factor * rejected_logratios

    logits = chosen_logratios - rejected_logratios_weighted

    # Final ACO loss
    losses = -F.logsigmoid(beta * logits)

    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()

    return losses, chosen_rewards, rejected_rewards
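
For illustration, here is a minimal way to call the function with random dummy tensors. The batch size and beta=0.1 are arbitrary choices for this sketch, not values prescribed by the repository:

# Illustrative usage with random dummy inputs (not part of the repository code).
batch = 4
policy_chosen = torch.randn(batch)
policy_rejected = torch.randn(batch)
ref_chosen = torch.randn(batch)
ref_rejected = torch.randn(batch)
chosen_w = torch.rand(batch)
rejected_w = torch.rand(batch)
avg_w = (chosen_w + rejected_w) / 2  # placeholder for the average preference weights

losses, chosen_rewards, rejected_rewards = aco_loss(
    policy_chosen, policy_rejected,
    ref_chosen, ref_rejected,
    chosen_w, rejected_w, avg_w,
    beta=0.1,
)
print(losses.mean().item())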

🚀 Quick Start

To perform the foresight re-sampling, first use the following command to conduct the stepwise self-exploration:

python foresight-sampling/foresight_sampling_with_vllm.py \
  --model_path <path_to_your_model> \
  --data_path <path_to_your_queries> \
  --output_path <path_to_your_output_file> \
  --step_beam_size 2 \
  --num_rollout 4 \
  --num_foresight 4
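
Conceptually, the stepwise self-exploration interleaves sampling candidate next steps with foresight-based scoring. The sketch below is a simplified, hypothetical illustration of this idea only: sample_next_steps and score_foresight are placeholder helpers, and the actual logic (including score-based re-sampling of the beams) lives in foresight_sampling_with_vllm.py.

# Simplified, hypothetical sketch of stepwise self-exploration.
# `sample_next_steps` and `score_foresight` are placeholders, not repository functions.
def explore(query, step_beam_size=2, num_rollout=4, num_foresight=4):
    beams = [query]   # partial reasoning trajectories kept in the step beam
    records = []      # per-step candidates with their foresight scores
    for _ in range(num_foresight):
        candidates = []
        for prefix in beams:
            # propose `num_rollout` candidate next steps for each beam
            for step in sample_next_steps(prefix, n=num_rollout):
                # score the candidate by looking ahead (foresight) from the extended prefix
                value = score_foresight(prefix + step)
                candidates.append((prefix + step, value))
        records.append(candidates)
        # keep the most promising partial trajectories for the next step
        candidates.sort(key=lambda pair: pair[1], reverse=True)
        beams = [text for text, _ in candidates[:step_beam_size]]
    return records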

After exploration, you can exploit the results and construct the preference data through the re-sampling technique:

python foresight-sampling/construct_foresight_preference_data.py \
  --input_path <path_to_the_exploration_outputs> \
  --output_path <path_to_output_data_file>
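
As a rough illustration only, the sketch below continues the hypothetical exploration sketch above and turns each step's candidates into a (chosen, rejected) pair with normalized foresight scores as weights. The actual construction script and its output format may differ.

# Hypothetical continuation of the exploration sketch above; the real data format
# produced by construct_foresight_preference_data.py may differ.
def build_preference_pairs(records):
    pairs = []
    for candidates in records:
        texts, values = zip(*candidates)
        weights = F.softmax(torch.tensor(values, dtype=torch.float32), dim=0)  # normalized foresight scores
        best, worst = int(weights.argmax()), int(weights.argmin())
        pairs.append({
            "chosen": texts[best],
            "rejected": texts[worst],
            "chosen_weight": weights[best].item(),
            "rejected_weight": weights[worst].item(),
            "average_weight": weights.mean().item(),
        })
    return pairs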

To train the policy model with the Advantage Calibrated Optimization (ACO) loss function, please refer to the following command:

# execute
bash ./aco-training/scripts/aco_train_with_accelerate.sh
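
For orientation, here is an illustrative sketch of how a DPO-style training step could plug in the aco_loss function above. The model wrappers, batch fields, and beta=0.1 are assumptions for this sketch, not the repository's actual training script.

# Illustrative DPO-style training step around `aco_loss`; `logps`, the batch fields,
# and beta=0.1 are assumptions, not the repository's training code.
def training_step(policy_model, reference_model, batch, optimizer, beta=0.1):
    policy_chosen_logps = policy_model.logps(batch["chosen"])        # hypothetical helper returning (batch,) log-probs
    policy_rejected_logps = policy_model.logps(batch["rejected"])
    with torch.no_grad():
        ref_chosen_logps = reference_model.logps(batch["chosen"])
        ref_rejected_logps = reference_model.logps(batch["rejected"])

    losses, _, _ = aco_loss(
        policy_chosen_logps, policy_rejected_logps,
        ref_chosen_logps, ref_rejected_logps,
        batch["chosen_weight"], batch["rejected_weight"], batch["average_weight"],
        beta=beta,
    )
    loss = losses.mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()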

📒 Note

We open-source the model weights and the code. The training scripts are based on open-instruct, and the evaluation is supported and accelerated by the vLLM engine.

Citation

If you find this work helpful, please kindly cite our paper, as well as the inference-time decoding algorithm $\phi$-Decoding:

@article{xu2025genius,
  title={Genius: A Generalizable and Purely Unsupervised Self-Training Framework For Advanced Reasoning},
  author={Xu, Fangzhi and Yan, Hang and Ma, Chang and Zhao, Haiteng and Sun, Qiushi and Cheng, Kanzhi and He, Junxian and Liu, Jun and Wu, Zhiyong},
  journal={arXiv preprint arXiv:2504.08672},
  year={2025}
}
@article{xu2025phi,
  title={$\phi$-Decoding: Adaptive Foresight Sampling for Balanced Inference-Time Exploration and Exploitation},
  author={Xu, Fangzhi and Yan, Hang and Ma, Chang and Zhao, Haiteng and Liu, Jun and Lin, Qika and Wu, Zhiyong},
  journal={arXiv preprint arXiv:2503.13288},
  year={2025}
}
