
Conversation

@pramodith
Collaborator

What does this PR do?

The SFTTrainer is currently missing docstrings for the compute_loss_func and compute_metrics constructor args. Both of these are passed on to the parent Trainer class's constructor. This PR adds them, copied over from the Trainer class in the Transformers library.

Fixes # (issue)
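
For context, a hedged sketch of what these two arguments look like at the SFTTrainer call site; the model id, toy datasets, and metric below are placeholders for illustration, not code from this PR:

    from datasets import Dataset
    from trl import SFTTrainer

    def my_metrics(eval_pred):
        # Hypothetical compute_metrics: transformers.Trainer calls it with an
        # EvalPrediction (predictions, label_ids) and logs the returned dict.
        return {"num_eval_labels": int((eval_pred.label_ids != -100).sum())}

    trainer = SFTTrainer(
        model="Qwen/Qwen2.5-0.5B",  # placeholder model id
        train_dataset=Dataset.from_dict({"text": ["Hello world."] * 8}),
        eval_dataset=Dataset.from_dict({"text": ["Goodbye world."] * 4}),
        compute_metrics=my_metrics,  # forwarded to the parent Trainer
    )

compute_loss_func is forwarded the same way; see the discussion below for an example.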

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec
Member

I think compute_loss_func is actually never used in SFTTrainer. I recommend removing it (deprecate and remove in next version)

@pramodith
Collaborator Author

I think compute_loss_func is actually never used in SFTTrainer. I recommend removing it (deprecate and remove in next version)

I think it is used, because it is passed to the transformers.Trainer's constructor, which allows the user to pass in a custom loss function.

super().__init__(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=processing_class,
    compute_loss_func=compute_loss_func,
    # ... (remaining arguments truncated in the quote)
)

The Trainer calls the passed-in compute_loss_func within its own compute_loss method:
https://github.com/huggingface/transformers/blob/25b4a0d8aef515ed309c935607473da319d9291c/src/transformers/trainer.py#L4123-L4125

I actually think having this is a nice way to let users experiment with their own loss functions, so I would refrain from getting rid of it. This in-progress PR also serves as an example of how one can use compute_loss_func for POCs:
#3960
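
For illustration, a minimal hypothetical compute_loss_func with the signature the Trainer invokes it with, compute_loss_func(outputs, labels, num_items_in_batch=...); the shifted cross-entropy here is just an example, not code from this PR:

    import torch.nn.functional as F

    def my_loss(outputs, labels, num_items_in_batch=None):
        # Shift logits and labels so each position predicts the next token.
        logits = outputs.logits[:, :-1, :].contiguous()
        targets = labels[:, 1:].contiguous()
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-100,
            reduction="sum",
        )
        # num_items_in_batch keeps the normalization consistent under
        # gradient accumulation when the Trainer provides it.
        if num_items_in_batch is not None:
            return loss / num_items_in_batch
        return loss / (targets != -100).sum()

    # trainer = SFTTrainer(..., compute_loss_func=my_loss)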

@qgallouedec
Member

compute_loss is overridden in SFTTrainer, so this function is never used

@pramodith
Collaborator Author

pramodith commented Sep 4, 2025

compute_loss is overridden in SFTTrainer, so this function is never used

Yes, but even the overridden version calls the Trainer's implementation of compute_loss, which in turn uses compute_loss_func:

trl/trainer/sft_trainer.py, lines 1041 to 1043 in 0c69fd2:

    (loss, outputs) = super().compute_loss(
        model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
    )

Relevant code in Trainer's compute_loss:

    if self.compute_loss_func is not None:
        loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
    elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
        loss = self.label_smoother(outputs, labels, shift_labels=True)
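
To make the dispatch concrete, a condensed sketch of the call chain implied by the two snippets above (simplified, hypothetical code, not the actual TRL/Transformers implementation):

    class Trainer:
        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            outputs = model(**inputs)
            labels = inputs["labels"]
            if self.compute_loss_func is not None:
                # A user-supplied loss takes precedence over the default.
                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
            else:
                loss = outputs.loss
            return (loss, outputs) if return_outputs else loss

    class SFTTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            # The override delegates back to Trainer.compute_loss, so a
            # user-supplied compute_loss_func is still reached.
            (loss, outputs) = super().compute_loss(
                model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
            )
            return (loss, outputs) if return_outputs else loss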

@qgallouedec
Member

My bad, you're right!

@qgallouedec left a comment
Member

lgtm!

@pramodith pramodith merged commit 19f9b9e into huggingface:main Sep 4, 2025
8 of 10 checks passed
SamY724 pushed a commit to SamY724/trl that referenced this pull request Sep 6, 2025
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
@pramodith pramodith deleted the pramodith/sft_trainer_missing_doc_strings branch September 8, 2025 20:27