Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions trl/experimental/kto/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_size: int) ->
reference_logps = []
reference_KL_logps = []
for padded_batch in tqdm(iterable=data_loader, desc=f"Computing reference log probs for {name} dataset"):
reference_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
reference_logp, reference_KL_logp = self.compute_ref_log_probs(padded_batch)
if self.calculate_KL:
reference_logp, reference_KL_logp = self.accelerator.gather_for_metrics(
(reference_logp, reference_KL_logp)
Expand Down Expand Up @@ -712,42 +712,42 @@ def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_size: int) ->

return dataset

def compute_reference_log_probs(self, padded_batch: dict) -> dict:
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
def compute_ref_log_probs(self, inputs):
"""Computes reference log probabilities for a single padded batch."""

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename not propagated to BCO trainer

Low Severity

The rename of `compute_reference_log_probs` to `compute_ref_log_probs` was applied only to the KTO trainer. The same applies to its parameter, renamed from `padded_batch` to `inputs`. The BCO trainer still uses the old method name `compute_reference_log_probs` and the old parameter name `padded_batch`. The project's AGENTS.md rule says that when duplicated code is modified in one trainer, the same change must be applied to all other trainers, and that "not propagating a change is a bug."

Fix in Cursor Fix in Web

Triggered by project rule: ../.ai/AGENTS.md

Reviewed by Cursor Bugbot for commit 8959f80. Configure here.

with torch.no_grad():
if self.ref_model is None:
with self.null_ref_context():
completion_logits = self.model(
padded_batch["completion_input_ids"],
attention_mask=padded_batch["completion_attention_mask"],
inputs["completion_input_ids"],
attention_mask=inputs["completion_attention_mask"],
).logits

if self.calculate_KL:
KL_logits = self.model(
padded_batch["KL_completion_input_ids"],
attention_mask=padded_batch["KL_completion_attention_mask"],
inputs["KL_completion_input_ids"],
attention_mask=inputs["KL_completion_attention_mask"],
).logits
else:
completion_logits = self.ref_model(
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
inputs["completion_input_ids"], attention_mask=inputs["completion_attention_mask"]
).logits

if self.calculate_KL:
KL_logits = self.ref_model(
padded_batch["KL_completion_input_ids"],
attention_mask=padded_batch["KL_completion_attention_mask"],
inputs["KL_completion_input_ids"],
attention_mask=inputs["KL_completion_attention_mask"],
).logits

completion_logps = self.get_batch_logps(
completion_logits,
padded_batch["completion_labels"],
inputs["completion_labels"],
average_log_prob=False,
)

if self.calculate_KL:
KL_logps = self.get_batch_logps(
KL_logits,
padded_batch["KL_completion_labels"],
inputs["KL_completion_labels"],
average_log_prob=False,
)
else:
Expand Down
Loading