
Align KTO with DPO: Align compute_ref_log_probs #5852

Merged: albertvillanova merged 4 commits into main from align-kto-dpo-compute_reference_log_probs on May 26, 2026.


albertvillanova (Member) commented on May 26, 2026:

Align KTO with DPO: Align compute_ref_log_probs.

Part of:

Changes

Refactoring and Naming Consistency:

  • Renamed the method compute_reference_log_probs to compute_ref_log_probs and updated all internal calls accordingly for clarity and consistency.

Code Simplification and Parameter Handling:

  • Changed the method parameter from padded_batch to inputs in compute_ref_log_probs, and updated all usages from padded_batch[...] to inputs[...] to streamline the function interface.
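To make the interface change concrete, here is a minimal, self-contained sketch of what a caller sees after the rename: one method, compute_ref_log_probs, that takes a padded batch named inputs and reads every tensor through inputs[...]. This is not TRL's code. The class KTOTrainerSketch, the ref_scorer callable, the calculate_KL flag, and the batch keys (completion_input_ids, completion_labels, KL_completion_input_ids, KL_completion_labels) are hypothetical stand-ins loosely modeled on the KTO trainer, and plain lists replace tensors so the example runs without PyTorch.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical padded batch: key -> list of token-id rows.
Batch = Dict[str, List[List[int]]]
# Hypothetical reference scorer: (input_ids, labels) -> one log-prob per sequence.
RefScorer = Callable[[List[List[int]], List[List[int]]], List[float]]


class KTOTrainerSketch:
    """Toy class illustrating the renamed method's interface; not TRL's implementation."""

    def __init__(self, ref_scorer: RefScorer, calculate_KL: bool = True) -> None:
        self.ref_scorer = ref_scorer
        self.calculate_KL = calculate_KL

    def compute_ref_log_probs(self, inputs: Batch) -> Tuple[List[float], Optional[List[float]]]:
        """Computes reference log probabilities for a single padded batch."""
        # Formerly compute_reference_log_probs(self, padded_batch); lookups now read inputs[...].
        completion_logps = self.ref_scorer(inputs["completion_input_ids"], inputs["completion_labels"])
        KL_logps = None
        if self.calculate_KL:
            # Optional second reference pass over the (hypothetical) KL completions.
            KL_logps = self.ref_scorer(inputs["KL_completion_input_ids"], inputs["KL_completion_labels"])
        return completion_logps, KL_logps
```

Because only the method and parameter names change, any caller that still uses the old name compute_reference_log_probs would fail with an AttributeError, which is why every call site has to be updated in the same change.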

Note

Low Risk
Rename-only refactor of the experimental KTO trainer, with no changes to training or loss logic.

Overview
Renames compute_reference_log_probs to compute_ref_log_probs in the experimental KTO trainer and updates the precompute path to call the new name, matching the DPO trainer API.

The method now takes inputs instead of padded_batch, and all tensor lookups use inputs[...]. Behavior is unchanged: the reference completion forward pass (and the optional KL forward pass) and the get_batch_logps calls are the same.
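The get_batch_logps step that this PR leaves untouched reduces a batch of logits to one reference log-probability per sequence. The following is a minimal pure-Python sketch of that usual computation, not TRL's tensor implementation. It assumes the common Hugging Face convention that masked label positions (prompt and padding) hold -100, and that the logits at position t predict the token at position t + 1, so labels are shifted left by one.

```python
import math
from typing import List

IGNORE_INDEX = -100  # assumed label value for masked (prompt/padding) positions


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def get_batch_logps(logits: List[List[List[float]]], labels: List[List[int]]) -> List[float]:
    """Sum of per-token log-probs of `labels` under `logits`, one value per sequence.

    logits: [batch][seq_len][vocab]; labels: [batch][seq_len].
    The logits at position t score the label at position t + 1.
    """
    results = []
    for seq_logits, seq_labels in zip(logits, labels):
        total = 0.0
        for step_logits, target in zip(seq_logits[:-1], seq_labels[1:]):
            if target == IGNORE_INDEX:
                continue  # masked position contributes nothing
            total += log_softmax(step_logits)[target]
        results.append(total)
    return results
```

With uniform logits over a two-token vocabulary, a sequence with two unmasked labels scores 2 * log(0.5), while a fully masked sequence scores 0.0.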

Reviewed by Cursor Bugbot for commit 8959f80.

cursor (Bot) left a comment:

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

-    def compute_reference_log_probs(self, padded_batch: dict) -> dict:
-        """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
+    def compute_ref_log_probs(self, inputs):
+        """Computes reference log probabilities for a single padded batch."""


Rename not propagated to BCO trainer

Low Severity

The rename from compute_reference_log_probs to compute_ref_log_probs was applied only to the KTO trainer, but the BCO trainer still uses the old name compute_reference_log_probs with the old parameter name padded_batch. The project's AGENTS.md rule states that when modifying duplicated code across trainers, the same change must be applied to all other trainers, and "not propagating a change is a bug."


Triggered by project rule: ../.ai/AGENTS.md


albertvillanova merged commit a25c07e into main on May 26, 2026.
6 checks passed.
albertvillanova deleted the align-kto-dpo-compute_reference_log_probs branch on May 26, 2026 at 20:35.
