Align KTO with DPO: Precompute reference log probs at init #5447
Cursor Bugbot has reviewed your changes and found 1 potential issue.
    if self.eval_dataset is not None:
        self.eval_dataset = self._precompute_reference_log_probs(
            self.eval_dataset, self.args.per_device_eval_batch_size, "Eval dataset reference log probs"
        )
Duplicated precomputation pattern not propagated to BCO trainer
Low Severity
The lazy precomputation pattern via get_train_dataloader/get_eval_dataloader with _precomputed_train_ref_log_probs/_precomputed_eval_ref_log_probs flags was shared between KTOTrainer and BCOTrainer. This PR updates the KTO pattern to eager init-time precomputation but does not propagate the same change to BCOTrainer, which still uses the old lazy approach. Per AGENTS.md, when duplicated logic is modified in one trainer, the same change must be applied to all other trainers that share the pattern.
Triggered by project rule: ../.ai/AGENTS.md
I think I will address this in a separate PR. This one is specifically to align KTO with DPO.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1ea2dadb66
Align KTO with DPO: Precompute reference log probs at init.
This PR refactors the logic for precomputing reference log probabilities in the `KTOTrainer` class. The main improvement is the consolidation of duplicated code for precomputing reference log probabilities for both training and evaluation datasets into a single reusable method, which simplifies the codebase and reduces redundancy.
Part of:
Motivation
KTOTrainer precomputed reference log probs lazily inside get_train_dataloader and get_eval_dataloader overrides, guarded by _precomputed_train_ref_log_probs / _precomputed_eval_ref_log_probs flags.
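
For readers unfamiliar with the lazy pattern described above, here is a minimal sketch of it. This is a toy stand-in, not the actual `KTOTrainer` code: the class name `LazyRefLogpTrainer`, the list-of-dicts dataset, and the `_reference_logp` placeholder are all hypothetical. Only the flag name and the `get_train_dataloader` override mirror the description.

```python
class LazyRefLogpTrainer:
    """Toy illustration of flag-guarded lazy precomputation (hypothetical class)."""

    def __init__(self, train_dataset, precompute_ref_log_probs=True):
        self.train_dataset = train_dataset
        self.precompute_ref_log_probs = precompute_ref_log_probs
        # Flag that guards the one-time precomputation.
        self._precomputed_train_ref_log_probs = False

    def _reference_logp(self, example):
        # Hypothetical placeholder for a reference-model forward pass.
        return -0.5 * len(example["completion"])

    def get_train_dataloader(self):
        # The dataset is mutated on first dataloader access, not in __init__.
        if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
            self.train_dataset = [
                {**ex, "reference_logps": self._reference_logp(ex)}
                for ex in self.train_dataset
            ]
            self._precomputed_train_ref_log_probs = True
        return iter(self.train_dataset)


trainer = LazyRefLogpTrainer([{"completion": "ok"}, {"completion": "fine"}])
has_column_before = "reference_logps" in trainer.train_dataset[0]
first_pass = list(trainer.get_train_dataloader())
has_column_after = "reference_logps" in trainer.train_dataset[0]
```

In this sketch the `reference_logps` column only appears after the first call to `get_train_dataloader`, which is the timing behavior the PR changes.
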
DPOTrainer already uses the correct pattern: precompute in `__init__` via a dedicated _precompute_ref_logps helper, so no dataloader overrides are needed at all.
This PR aligns KTOTrainer with that approach:
- Precompute in `__init__` (after `super().__init__()`, once self.accelerator is available) for both train and eval datasets
Changes
Refactoring and code simplification:
- Removed the duplicated precomputation logic from `get_train_dataloader` and `get_eval_dataloader` by introducing a unified `_precompute_reference_log_probs` method that handles reference log probability computation for any given dataset.
- The `_precompute_reference_log_probs` method is now called during initialization if `precompute_ref_log_probs` is enabled, updating both `train_dataset` and `eval_dataset` as needed.
Cleanup:
- Removed the now-unused flags (`_precomputed_train_ref_log_probs`, `_precomputed_eval_ref_log_probs`) and related comments, as precomputation is now handled directly and only once during initialization.
Note
Medium Risk
Moves reference log-prob precomputation from lazy dataloader overrides to `__init__`, changing when datasets are mutated and when the extra compute/memory cost is paid. This could affect training startup time and any code relying on the old dataloader override behavior.
Overview
Aligns `KTOTrainer` with other trainers by precomputing reference log probabilities during initialization when `precompute_ref_log_probs` is enabled, instead of doing it lazily inside overridden dataloader methods.
This removes the `get_train_dataloader`/`get_eval_dataloader` overrides and the `_precomputed_*` flags, and introduces a single `_precompute_reference_log_probs()` helper that adds `reference_logps` (and `reference_KL_logps` when applicable) to the train dataset and to eval datasets (including dict-of-datasets eval setups).
Written by Cursor Bugbot for commit 34e59c1.
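
The eager pattern summarized above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from this PR: `EagerRefLogpTrainer`, the list-of-dicts datasets, the `_reference_logps` placeholder, and the way dict-of-datasets eval setups are handled are all hypothetical. Only the helper name and its `(dataset, batch_size, description)` call shape follow the diff snippet quoted in the review.

```python
class EagerRefLogpTrainer:
    """Toy illustration of init-time precomputation (hypothetical class)."""

    def __init__(self, train_dataset, eval_dataset=None,
                 precompute_ref_log_probs=True, batch_size=2):
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        if precompute_ref_log_probs:
            # Everything is computed once, up front; no dataloader overrides or flags.
            self.train_dataset = self._precompute_reference_log_probs(
                self.train_dataset, batch_size, "Train dataset reference log probs"
            )
            if isinstance(self.eval_dataset, dict):
                # Assumed handling of dict-of-datasets eval setups: process each entry.
                self.eval_dataset = {
                    name: self._precompute_reference_log_probs(
                        ds, batch_size, f"Eval dataset ({name}) reference log probs"
                    )
                    for name, ds in self.eval_dataset.items()
                }
            elif self.eval_dataset is not None:
                self.eval_dataset = self._precompute_reference_log_probs(
                    self.eval_dataset, batch_size, "Eval dataset reference log probs"
                )

    def _reference_logps(self, batch):
        # Hypothetical placeholder for a batched reference-model forward pass.
        return [-0.5 * len(ex["completion"]) for ex in batch]

    def _precompute_reference_log_probs(self, dataset, batch_size, description):
        # `description` would label a progress bar in a real trainer; unused here.
        logps = []
        for start in range(0, len(dataset), batch_size):
            logps.extend(self._reference_logps(dataset[start:start + batch_size]))
        # Return a new dataset with the extra column attached to every example.
        return [{**ex, "reference_logps": lp} for ex, lp in zip(dataset, logps)]


trainer = EagerRefLogpTrainer(
    train_dataset=[{"completion": "ok"}, {"completion": "fine"}, {"completion": "a"}],
    eval_dataset={"val": [{"completion": "yes"}]},
)
```

With this shape, the datasets already carry `reference_logps` as soon as the constructor returns, so the cost is paid at startup, as the risk note above points out.
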