
Align KTO with DPO: Add disable_gradient_checkpointing to ref model forward passes #5900

Merged: albertvillanova merged 3 commits into main from align-kto-dpo-gradient_checkpointing on Jun 1, 2026

albertvillanova commented Jun 1, 2026

Align KTO with DPO: Add disable_gradient_checkpointing to ref model forward passes.

Part of:

This PR improves how KTOTrainer handles gradient checkpointing when computing reference log-probabilities: checkpointing is now disabled while the reference model runs inference. This avoids unnecessary memory usage and potential side effects (such as spurious PyTorch warnings) during these no-grad forward passes.

Solution

This PR adds the disable_gradient_checkpointing context manager to all reference model forward passes in KTOTrainer, matching the pattern already used in DPOTrainer.
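The pattern described above can be illustrated with a small context manager. This is a minimal, hypothetical sketch, not TRL's actual disable_gradient_checkpointing: it assumes the model exposes a Transformers-style is_gradient_checkpointing flag plus gradient_checkpointing_enable(gradient_checkpointing_kwargs=...) and gradient_checkpointing_disable() methods, and the names ToyModel and disable_gradient_checkpointing_sketch are illustrative only.

```python
from contextlib import contextmanager
from typing import Any, Iterator, Optional


class ToyModel:
    """Hypothetical stand-in exposing a Transformers-style checkpointing API."""

    def __init__(self) -> None:
        self.is_gradient_checkpointing = False
        self.checkpointing_kwargs: Optional[dict] = None

    def gradient_checkpointing_enable(
        self, gradient_checkpointing_kwargs: Optional[dict] = None
    ) -> None:
        self.is_gradient_checkpointing = True
        self.checkpointing_kwargs = gradient_checkpointing_kwargs

    def gradient_checkpointing_disable(self) -> None:
        self.is_gradient_checkpointing = False


@contextmanager
def disable_gradient_checkpointing_sketch(
    model: Any, gradient_checkpointing_kwargs: Optional[dict] = None
) -> Iterator[None]:
    """Hypothetical helper: turn checkpointing off for the block, then restore it."""
    was_enabled = getattr(model, "is_gradient_checkpointing", False)
    if was_enabled:
        model.gradient_checkpointing_disable()
    try:
        yield
    finally:
        # Re-enable with the training kwargs so settings such as use_reentrant
        # are not silently reset to defaults; do nothing if it was never on.
        if was_enabled:
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
            )


# Usage: a reference-model forward pass would go inside the with-block.
ref_model = ToyModel()
ref_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
with disable_gradient_checkpointing_sketch(ref_model, {"use_reentrant": False}):
    print(ref_model.is_gradient_checkpointing)  # False
print(ref_model.is_gradient_checkpointing)  # True
```

Restoring inside finally means checkpointing comes back even if the reference forward pass raises, and passing the training gradient_checkpointing_kwargs back in keeps the re-enabled configuration identical to the one used for training.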

Changes

Gradient checkpointing management:

  • Imported disable_gradient_checkpointing from models.utils and included it alongside other utility imports in kto_trainer.py.
  • Wrapped all reference model inference blocks (in compute_ref_log_probs, _compute_loss_liger, and _compute_loss) with the disable_gradient_checkpointing context manager so that checkpointing is turned off during these computations.
    • In _compute_loss_liger, torch.no_grad() was missing entirely, so the reference hidden states were computed with gradients enabled. This was a bug; the reference forward pass there is now also wrapped in torch.no_grad().
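Why the missing torch.no_grad() matters: when the reference parameters require gradients, any output computed from them is attached to the autograd graph, which keeps the tensors needed for backward alive. A toy illustration, assuming a single nn.Linear layer stands in for the reference decoder (this is not the actual Liger code path):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-in for the reference decoder; its parameters require grad by default.
ref_decoder = nn.Linear(8, 8)
inputs = torch.randn(2, 8)

# Before the fix: the reference forward runs with autograd enabled, so the
# hidden states carry a graph back to ref_decoder's weights.
ref_hidden_buggy = ref_decoder(inputs)

# After the fix: the reference forward runs under torch.no_grad(),
# so no graph is recorded for these hidden states.
with torch.no_grad():
    ref_hidden_fixed = ref_decoder(inputs)

print(ref_hidden_buggy.requires_grad)  # True
print(ref_hidden_fixed.requires_grad)  # False
```

Both tensors hold the same values; the only difference is whether autograd records a graph (and retains the activations it needs) for the reference pass.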

Note

Low Risk
Affects only training-path correctness and memory behavior for reference inference; no auth, data, or API surface changes.

Overview
KTOTrainer now matches DPOTrainer for reference-policy inference: all reference forward paths run inside disable_gradient_checkpointing (using the training gradient_checkpointing_kwargs), so checkpointing stays off during no-grad reference log-prob computation. The result is less memory use, fewer PyTorch warnings, and behavior consistent with the other preference trainers.

The wrapper is applied in compute_ref_log_probs, in the on-the-fly reference branch of _compute_loss, and in the Liger path of _compute_loss_liger. On the Liger path, the reference decoder forward pass is also wrapped in torch.no_grad() (previously missing), so reference hidden states are no longer computed with gradients enabled.

Reviewed by Cursor Bugbot for commit d082d3c.

HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

albertvillanova merged commit 882e731 into main on Jun 1, 2026
6 checks passed
albertvillanova deleted the align-kto-dpo-gradient_checkpointing branch on June 1, 2026 at 12:25