Align KTO with DPO: Precompute reference log probs at init #5447

Merged
albertvillanova merged 4 commits into huggingface:main from albertvillanova:align-kto-dpo-precompute_ref_log_probs
Apr 3, 2026

@albertvillanova albertvillanova commented Apr 3, 2026

Member

Align KTO with DPO: Precompute reference log probs at init.

This PR refactors how KTOTrainer precomputes reference log probabilities. The duplicated precomputation code for the training and evaluation datasets is consolidated into a single reusable method.

Part of:

Motivation

KTOTrainer precomputed reference log probs lazily inside get_train_dataloader and get_eval_dataloader overrides, guarded by _precomputed_train_ref_log_probs / _precomputed_eval_ref_log_probs flags.

DPOTrainer already uses the correct pattern: precompute in __init__ via a dedicated _precompute_ref_logps helper, so no dataloader overrides are needed at all.

This PR aligns KTOTrainer with that approach:

  • Remove get_train_dataloader and get_eval_dataloader overrides
  • Remove _precomputed_train_ref_log_probs and _precomputed_eval_ref_log_probs flags
  • Add _precompute_reference_log_probs(dataset, batch_size, desc) -> Dataset helper
  • Call it from __init__ (after super().__init__(), once self.accelerator is available) for both train and eval datasets
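
The control flow described in the list above can be sketched with a toy trainer. This is a minimal illustration, not the TRL implementation: `ToyTrainer`, `fake_ref_logp`, and the list-of-dicts datasets are stand-ins invented for this sketch. Only the helper name `_precompute_reference_log_probs`, its `(dataset, batch_size, desc)` parameters, and the `reference_logps` column name come from this PR.

```python
def fake_ref_logp(row):
    # Hypothetical stand-in for a forward pass through the reference model.
    return -abs(row["x"])


class ToyTrainer:
    """Toy illustration of eager precomputation at init (not the TRL class)."""

    def __init__(self, train_dataset, eval_dataset=None,
                 precompute_ref_log_probs=False, batch_size=2):
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.ref_calls = 0  # counts simulated reference-model calls

        # Eager pattern: pay the cost once here, so no dataloader
        # overrides and no "already precomputed" flags are needed.
        if precompute_ref_log_probs:
            self.train_dataset = self._precompute_reference_log_probs(
                self.train_dataset, batch_size, "Train dataset reference log probs"
            )
            if self.eval_dataset is not None:
                self.eval_dataset = self._precompute_reference_log_probs(
                    self.eval_dataset, batch_size, "Eval dataset reference log probs"
                )

    def _precompute_reference_log_probs(self, dataset, batch_size, desc):
        # `desc` would label a progress bar in a real trainer; unused here.
        out = []
        for start in range(0, len(dataset), batch_size):
            for row in dataset[start:start + batch_size]:
                self.ref_calls += 1
                out.append({**row, "reference_logps": fake_ref_logp(row)})
        return out

    def get_train_dataloader(self):
        # Plain iteration: the dataset already carries the extra column.
        return iter(self.train_dataset)


trainer = ToyTrainer(
    [{"x": 1.0}, {"x": -2.0}, {"x": 3.0}],
    eval_dataset=[{"x": 0.5}],
    precompute_ref_log_probs=True,
)
calls_after_init = trainer.ref_calls
list(trainer.get_train_dataloader())
list(trainer.get_train_dataloader())
```

In this sketch, building dataloaders repeatedly never triggers another reference pass, which is the property the removed flags used to guard.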

Changes

Refactoring and code simplification:

  • Removed the duplicated logic from get_train_dataloader and get_eval_dataloader by introducing a unified _precompute_reference_log_probs method that handles reference log probability computation for any given dataset.
  • The new _precompute_reference_log_probs method is now called during initialization if precompute_ref_log_probs is enabled, updating both train_dataset and eval_dataset as needed.

Cleanup:

  • Removed unused flags and state variables (_precomputed_train_ref_log_probs, _precomputed_eval_ref_log_probs) and related comments, as precomputation is now handled directly and only once during initialization.

Note

Medium Risk
Moves reference log-prob precomputation from lazy dataloader overrides to __init__, changing when datasets are mutated and when the extra compute/memory cost is paid. This could affect training startup time and any code relying on the old dataloader override behavior.

Overview
Aligns KTOTrainer with other trainers by precomputing reference log probabilities during initialization when precompute_ref_log_probs is enabled, instead of doing it lazily inside overridden dataloader methods.

This removes the get_train_dataloader/get_eval_dataloader overrides and the _precomputed_* flags, and introduces a single _precompute_reference_log_probs() helper that adds reference_logps (and reference_KL_logps when applicable) to the train dataset and to eval datasets (including dict-of-datasets eval setups).
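
One way the single-dataset and dict-of-datasets eval cases could be handled is sketched below. This is an illustration under stated assumptions rather than the merged code: `ToyDataset`, `add_reference_columns`, and `precompute_for_eval` are hypothetical names, and the log-prob values are dummy numbers. Only the column names `reference_logps` and `reference_KL_logps` come from the overview above.

```python
from dataclasses import dataclass, field


@dataclass
class ToyDataset:
    """Hypothetical column-oriented stand-in for a real dataset object."""
    columns: dict = field(default_factory=dict)

    def add_column(self, name, values):
        # Return a new dataset instead of mutating in place.
        return ToyDataset({**self.columns, name: list(values)})


def add_reference_columns(dataset, with_kl):
    # Dummy "log probs": a real trainer would run the reference model here.
    prompts = dataset.columns["prompt"]
    dataset = dataset.add_column("reference_logps", [-float(len(p)) for p in prompts])
    if with_kl:
        dataset = dataset.add_column("reference_KL_logps", [0.0] * len(prompts))
    return dataset


def precompute_for_eval(eval_dataset, with_kl=False):
    if eval_dataset is None:
        return None
    if isinstance(eval_dataset, dict):
        # Dict-of-datasets setup: process each named eval set separately.
        return {
            name: add_reference_columns(ds, with_kl)
            for name, ds in eval_dataset.items()
        }
    return add_reference_columns(eval_dataset, with_kl)


single = precompute_for_eval(ToyDataset({"prompt": ["a", "bb"]}))
named = precompute_for_eval(
    {"helpful": ToyDataset({"prompt": ["x"]})}, with_kl=True
)
```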

Written by Cursor Bugbot for commit 34e59c1. This will update automatically on new commits.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Comment thread trl/experimental/kto/kto_trainer.py Outdated
    if self.eval_dataset is not None:
        self.eval_dataset = self._precompute_reference_log_probs(
            self.eval_dataset, self.args.per_device_eval_batch_size, "Eval dataset reference log probs"
        )

Duplicated precomputation pattern not propagated to BCO trainer

Low Severity

The lazy precomputation pattern via get_train_dataloader/get_eval_dataloader with _precomputed_train_ref_log_probs/_precomputed_eval_ref_log_probs flags was shared between KTOTrainer and BCOTrainer. This PR updates the KTO pattern to eager init-time precomputation but does not propagate the same change to BCOTrainer, which still uses the old lazy approach. Per AGENTS.md, when duplicated logic is modified in one trainer, the same change must be applied to all other trainers that share the pattern.

Triggered by project rule: ../.ai/AGENTS.md

Member

That's fine IMO

Member Author

I think I will address this in a separate PR. This one is specifically to align KTO with DPO.

@qgallouedec qgallouedec left a comment

Member

lgtm

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1ea2dadb66

Comment thread trl/experimental/kto/kto_trainer.py Outdated
@albertvillanova albertvillanova merged commit 0b05331 into huggingface:main Apr 3, 2026
5 checks passed
@albertvillanova albertvillanova mentioned this pull request Apr 8, 2026
6 tasks