SFTTrainer silently breaks datasets that use Dataset.with_transform #6039

@proshian

Description

Reproduction

Use case

with_transform is the natural tool for augmentation that must re-randomize on every access rather than being baked in once. Example: when fine-tuning a function-calling LLM, I want the model to read the tool docs in the system prompt instead of memorizing an intention --> function-name mapping. So each access renames the functions randomly but consistently within the example (schema, call, docs). Applying this once with .map() would just make the model memorize the new fixed names; it only works if every access differs. Identifier renaming, paraphrasing, and noise injection all share this shape.
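A minimal sketch of the "random but consistent within the example" renaming described above. The field names (tool_names, system_prompt, assistant_call) and the helper rename_functions are hypothetical, chosen only for illustration; a real function-calling dataset will have its own schema.

```python
import random
import re


def rename_functions(example, rng):
    """Rename every tool in one example, consistently across docs and call.

    Hypothetical helper: `example` is assumed to hold `tool_names`,
    `system_prompt` and `assistant_call` fields.
    """
    # Draw one fresh alias per original tool name for this access only.
    mapping = {name: f"fn_{rng.randrange(16**6):06x}" for name in example["tool_names"]}
    if not mapping:
        return dict(example)
    # Match whole identifiers only, so "get_weather" does not rewrite "get_weather_v2".
    pattern = re.compile(r"\b(" + "|".join(map(re.escape, mapping)) + r")\b")
    renamed = {
        key: pattern.sub(lambda m: mapping[m.group(1)], example[key])
        for key in ("system_prompt", "assistant_call")
    }
    renamed["tool_names"] = [mapping[name] for name in example["tool_names"]]
    return renamed


example = {
    "tool_names": ["get_weather"],
    "system_prompt": "Tools: get_weather(city) returns the forecast.",
    "assistant_call": 'get_weather(city="Oslo")',
}
rng = random.Random(0)
first = rename_functions(example, rng)   # one realization
second = rename_functions(example, rng)  # a different alias on the next access
```

Within one call the same alias appears in the docs and in the call, which is what lets the model learn to look names up. Across calls the alias changes, which is the part a one-off .map() cannot provide.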

Summary

If train_dataset has a transform set via with_transform(...), there is no configuration in which SFTTrainer behaves correctly:

  1. _prepare_dataset corrupts the data. Each internal prep step (add_eos, tokenize_fn, …) uses Dataset.map, and Dataset.map reads its input rows through the dataset's transform, so the user's augmentation fires once per prep step, each time on top of the previous step's already-transformed output. By the time tokenization runs, input_ids is frozen to this arbitrary stacked realization (in the repro below, the second application even lands after the EOS token).
  2. Default config crashes at train time. Trainer._remove_unused_columns drops the underlying raw column (e.g. text), but the transform survives → KeyError: 'text' on the first batch.
  3. With remove_unused_columns=False, augmentation is silently dead. text keeps re-randomizing on access, so augmentation looks live — but the collator only reads input_ids, which is frozen. Every epoch trains on the same corrupted sample from (1).
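The stacking in (1) and the frozen input_ids in (3) can be illustrated with a toy model. ToyDataset below is a deliberately simplified stand-in written for this sketch, not the datasets implementation: it only assumes, as described above, that map reads rows through the transform, stores what the mapped function returns, and leaves the transform attached. The row-wise functions and the whitespace "tokenizer" are likewise illustrative.

```python
import random


class ToyDataset:
    """Toy stand-in for a dataset with an on-access transform (not the datasets API)."""

    def __init__(self, rows, transform=None):
        self.rows = rows            # stored data
        self.transform = transform  # applied on every read

    def __getitem__(self, i):
        row = dict(self.rows[i])
        return self.transform(row) if self.transform else row

    def map(self, fn):
        new_rows = []
        for i, stored in enumerate(self.rows):
            seen = self[i]  # the read goes through the transform
            new_rows.append({**stored, **fn(seen)})  # returned columns overwrite stored ones
        return ToyDataset(new_rows, self.transform)  # the transform stays attached


rng = random.Random(0)

def augment(row):
    row["text"] += f" <TAG{rng.randint(0, 999)}>"
    return row

def add_eos(row):
    return {"text": row["text"] + "<EOS>"}

def tokenize(row):
    return {"input_ids": row["text"].split()}  # whitespace split, for illustration only

ds = ToyDataset([{"text": "hello world"}], transform=augment)
prepared = ds.map(add_eos).map(tokenize)

stored = prepared.rows[0]        # what is frozen after preparation
a, b = prepared[0], prepared[0]  # two accesses, as two epochs would see them
```

In this toy, stored["text"] already contains one tag followed by the EOS marker, stored["input_ids"] contains a second tag drawn during the tokenize step, and every later access appends yet another tag to text while input_ids stays the same.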

Reproduction — failure modes 1 and 3

import tempfile, random
from datasets import Dataset
from trl import SFTTrainer, SFTConfig

random.seed(0)
ds = Dataset.from_dict({"text": ["hello world"] * 2})

def augment(batch):
    # stand-in for real augmentation, e.g. random function renaming for tool-calling SFT
    batch["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
    return batch

ds = ds.with_transform(augment)

with tempfile.TemporaryDirectory() as tmp_dir:
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=ds,
        args=SFTConfig(output_dir=tmp_dir, report_to="none", use_cpu=True, bf16=False),
    )

tok = trainer.processing_class
for epoch in range(3):
    item = trainer.train_dataset[0]
    print(f"epoch {epoch}: text={item['text']!r}")
    print(f"         input_ids -> {tok.decode(item['input_ids'])!r}")

Output:

epoch 0: text='hello world <TAG911><|im_end|> <TAG654>'
         input_ids -> 'hello world <TAG911><|im_end|> <TAG41>'
epoch 1: text='hello world <TAG911><|im_end|> <TAG114>'
         input_ids -> 'hello world <TAG911><|im_end|> <TAG41>'
epoch 2: text='hello world <TAG911><|im_end|> <TAG25>'
         input_ids -> 'hello world <TAG911><|im_end|> <TAG41>'

Reading the output: <TAG911> got baked into stored text during prep's add_eos map, then <|im_end|> was appended after it. On each access, the transform only appends a new tag to this stored string — so only the trailing <TAG654>/<TAG114>/<TAG25> changes; everything before it, including <TAG911>, is frozen literal data. input_ids was tokenized from a separate draw (<TAG41>) during the tokenize_fn map and never changes afterward, since the transform doesn't touch input_ids.

Reproduction — failure mode 2 (default config crashes)

Same setup, but call next(iter(trainer.get_train_dataloader())) (or trainer.train()) instead of accessing trainer.train_dataset[0]:

import tempfile, random
from datasets import Dataset
from trl import SFTTrainer, SFTConfig

random.seed(0)
ds = Dataset.from_dict({"text": ["hello world"] * 2})

def augment(batch):
    batch["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
    return batch

ds = ds.with_transform(augment)

with tempfile.TemporaryDirectory() as tmp_dir:
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=ds,
        args=SFTConfig(output_dir=tmp_dir, report_to="none", use_cpu=True, bf16=False),
    )
    batch = next(iter(trainer.get_train_dataloader()))

Full traceback:

Traceback (most recent call last):
  File "C:\Users\my-user\Documents\trl-transform-repro\repro_crash.py", line 22, in <module>
    batch = next(iter(trainer.get_train_dataloader()))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\accelerate\data_loader.py", line 585, in __iter__
    current_batch = next(dataloader_iter)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\torch\utils\data\dataloader.py", line 718, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\torch\utils\data\dataloader.py", line 778, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3162, in __getitems__
    batch = self.__getitem__(keys)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3158, in __getitem__
    return self._getitem(key)
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3140, in _getitem
    formatted_output = format_table(
                       ^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\formatting\formatting.py", line 658, in format_table
    return formatter(pa_table, query_type=query_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\formatting\formatting.py", line 415, in __call__
    return self.format_batch(pa_table)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\formatting\formatting.py", line 541, in format_batch
    return self.transform(batch)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\repro_crash.py", line 10, in augment
    batch["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
                                                                    ~~~~~^^^^^^^^
KeyError: 'text'

A transform that returns only the changed column (return {"text": [...]} — the other common idiom) is worse still: a custom transform's output replaces the row, so the prepared dataset's __getitem__ yields no input_ids at all.
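The difference between the two idioms can be seen on a plain dict standing in for the batch a transform receives. This is only a sketch of the transform functions themselves; the example batch and the names augment_partial and augment_full are made up for illustration.

```python
import random


def augment_partial(batch):
    # Returns only the column it changed; any other column in the batch is dropped.
    return {"text": [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]}


def augment_full(batch):
    # Copies the incoming batch first, so columns it does not touch pass through.
    out = dict(batch)
    out["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
    return out


# Stand-in for a prepared batch that already carries tokenized columns.
batch = {"text": ["hello world"], "input_ids": [[1, 2, 3]]}
partial = augment_partial(batch)  # no "input_ids" key in the result
full = augment_full(batch)        # "input_ids" preserved, but still frozen
```

Passing columns through avoids losing input_ids, but it does not address failure mode 3: the preserved input_ids is still the one frozen during preparation.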

Proposed fix — one of:

  • Minimal: detect a custom transform at the top of _prepare_dataset and raise, pointing to dataset_kwargs={"skip_prepare_dataset": True}. Proceeding is never correct today.
  • Proper: when a custom transform is detected, replace the .map() calls with lazy transform chaining (user_transform → add_eos → tokenize_fn → … composed via with_transform), so fresh augmentation reaches input_ids on every access. This requires raising on packing=True (packing can't be lazy) and forcing remove_unused_columns=False (the Arrow-level columns stay raw).
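A rough sketch of both options, not a patch against TRL. The detection assumes that a dataset with a transform set via with_transform reports "custom" under dataset.format["type"]; that detail should be verified against the installed datasets version. The helper names (has_custom_transform, check_no_custom_transform, chain_transforms) and the batch-level steps are hypothetical, and SimpleNamespace objects stand in for real datasets so the snippet runs on its own.

```python
from types import SimpleNamespace


def has_custom_transform(dataset):
    """Best-effort check; assumes format["type"] == "custom" after with_transform."""
    fmt = getattr(dataset, "format", None)
    return isinstance(fmt, dict) and fmt.get("type") == "custom"


def check_no_custom_transform(dataset):
    # "Minimal" option: refuse to prepare a dataset that carries a custom transform.
    if has_custom_transform(dataset):
        raise ValueError(
            "train_dataset has a custom transform set; dataset preparation would "
            'bake it in. Consider dataset_kwargs={"skip_prepare_dataset": True}.'
        )


def chain_transforms(*steps):
    """'Proper' option: compose batch -> batch steps, left to right, into one transform."""
    def chained(batch):
        for step in steps:
            batch = step(batch)
        return batch
    return chained


# Illustrative batch-level steps; the real add_eos / tokenize_fn live inside TRL.
def fake_augment(batch):
    return {**batch, "text": [t + " <TAG>" for t in batch["text"]]}

def add_eos(batch):
    return {**batch, "text": [t + "<EOS>" for t in batch["text"]]}

def tokenize(batch):
    return {**batch, "input_ids": [t.split() for t in batch["text"]]}

lazy = chain_transforms(fake_augment, add_eos, tokenize)
out = lazy({"text": ["hello world"]})

plain = SimpleNamespace(format={"type": None})
transformed = SimpleNamespace(format={"type": "custom"})
```

With chaining, the augmentation runs first on every access and the tokenized columns are derived from that fresh text, so nothing is stored in between. The composed function would then be attached with with_transform in place of the user's original transform.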

Happy to help with a PR for either approach if there's interest.

System Info

  • Platform: Windows-10-10.0.26100-SP0
  • Python version: 3.11.13
  • TRL version: 1.7.0.dev0
  • PyTorch version: 2.12.0
  • accelerator(s): cpu
  • Transformers version: 5.12.0
  • Accelerate version: 1.14.0
  • Accelerate config: not found
  • Datasets version: 5.0.0
  • HF Hub version: 1.19.0
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Liger-Kernel version: not installed
  • PEFT version: not installed
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
