Reproduction
Use case
`with_transform` is the natural tool for augmentation that must re-randomize on every access rather than be baked in once. Example: when fine-tuning a function-calling LLM, I want it to read the tool docs in the system prompt instead of memorizing an intention → function-name mapping. So on each access the functions are renamed randomly but consistently within the example (schema, call, docs). Applying this once with `.map()` would just make the model memorize the new fixed names; it only works if every access differs. Identifier renaming, paraphrasing, and noise injection all share this shape.
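To make that shape concrete, here is a minimal stdlib-only sketch of a per-access renaming augmentation. The example layout (`schema` as a plain list of function names, plus `call` and `docs`) and the helpers `rename_functions`/`augment` are hypothetical simplifications, not the real data format. The point is that every call draws fresh names but applies a single mapping across all three fields.

```python
import random
import re


def rename_functions(example, rng=random):
    """Hypothetical augmentation: give every function in one example a fresh random
    name, using the same mapping in schema, call and docs."""
    names = example["schema"]  # simplified: the schema is just a list of function names
    # sample distinct integers so two functions never collide on the same new name
    fresh = [f"fn_{n:06x}" for n in rng.sample(range(16**6), len(names))]
    mapping = dict(zip(names, fresh))
    # whole-word match, so a name like "get" would not rewrite part of "get_weather"
    pattern = re.compile(r"\b(" + "|".join(map(re.escape, names)) + r")\b")
    return {
        **example,  # keep any other columns untouched
        "schema": fresh,
        "call": mapping[example["call"]],
        "docs": pattern.sub(lambda m: mapping[m.group(1)], example["docs"]),
    }


def augment(batch):
    """Batched wrapper in the dict-of-lists shape that with_transform passes in."""
    rows = [dict(zip(batch, values)) for values in zip(*batch.values())]
    renamed = [rename_functions(row) for row in rows]
    return {key: [row[key] for row in renamed] for key in batch}


example = {
    "schema": ["get_weather", "send_email"],
    "call": "get_weather",
    "docs": "Use get_weather for forecasts; send_email sends mail.",
}
print(rename_functions(example))
print(rename_functions(example))  # a fresh renaming on every call
```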
Summary
If `train_dataset` has a transform set via `with_transform(...)`, there is no configuration in which `SFTTrainer` behaves correctly:
1. `_prepare_dataset` corrupts the data. Each internal prep step (`add_eos`, `tokenize_fn`, …) uses `Dataset.map`, and `Dataset.map` reads its input rows through the dataset's transform. The user's augmentation therefore fires once per prep step, each time on top of the previous step's already-transformed output. By the time tokenization runs, `input_ids` is frozen to this arbitrary stacked realization (in the repro below, the second application even lands after the EOS token).
2. Default config crashes at train time. `Trainer._remove_unused_columns` drops the underlying raw column (e.g. `text`), but the transform survives → `KeyError: 'text'` on the first batch.
3. With `remove_unused_columns=False`, augmentation is silently dead. `text` keeps re-randomizing on access, so augmentation looks live, but the collator only reads `input_ids`, which is frozen. Every epoch trains on the same corrupted sample from (1).
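The mechanism behind all three failure modes can be reproduced without `datasets` or `trl`. Below is a toy, stdlib-only model built on the behaviour described in failure mode 1: every read goes through the transform, `map` feeds its function transformed rows and stores the result, and neither `map` nor column removal drops the transform. `ToyDataset`, `add_eos` and `tokenize` are hypothetical stand-ins, not the real library code, and the toy tokenizer just splits on whitespace.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for a Dataset with a lazy transform (not the datasets API)."""

    def __init__(self, columns, transform=None):
        self.columns = columns      # stored data: column name -> list of values
        self.transform = transform  # applied on every read, like with_transform

    def _read(self, indices):
        batch = {name: [values[i] for i in indices] for name, values in self.columns.items()}
        return self.transform(batch) if self.transform else batch

    def __getitem__(self, i):
        return {name: values[0] for name, values in self._read([i]).items()}

    def map(self, fn):
        # fn sees *transformed* rows; its output overwrites stored columns,
        # and the transform stays attached to the result
        n = len(next(iter(self.columns.values())))
        return ToyDataset({**self.columns, **fn(self._read(range(n)))}, self.transform)

    def remove_columns(self, names):
        kept = {k: v for k, v in self.columns.items() if k not in names}
        return ToyDataset(kept, self.transform)  # the transform survives


def augment(batch):
    batch["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
    return batch


def add_eos(batch):   # stand-in for the prep step of the same name
    return {"text": [t + "<|im_end|>" for t in batch["text"]]}


def tokenize(batch):  # stand-in for tokenize_fn: whitespace "tokens"
    return {"input_ids": [t.split() for t in batch["text"]]}


random.seed(0)
prepared = ToyDataset({"text": ["hello world"]}, transform=augment).map(add_eos).map(tokenize)

print(prepared.columns["text"])  # (1) one augmentation draw is now baked into stored text
for epoch in range(3):
    item = prepared[0]
    print(epoch, item["text"], item["input_ids"])  # (3) text re-randomizes, input_ids never changes

try:
    prepared.remove_columns(["text"])[0]  # (2) what removing the raw column leads to
except KeyError as err:
    print("KeyError:", err)
```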
Reproduction — failure modes 1 and 3
```python
import tempfile, random
from datasets import Dataset
from trl import SFTTrainer, SFTConfig

random.seed(0)
ds = Dataset.from_dict({"text": ["hello world"] * 2})


def augment(batch):
    # stand-in for real augmentation, e.g. random function renaming for tool-calling SFT
    batch["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
    return batch


ds = ds.with_transform(augment)

with tempfile.TemporaryDirectory() as tmp_dir:
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=ds,
        args=SFTConfig(output_dir=tmp_dir, report_to="none", use_cpu=True, bf16=False),
    )
    tok = trainer.processing_class
    for epoch in range(3):
        item = trainer.train_dataset[0]
        print(f"epoch {epoch}: text={item['text']!r}")
        print(f" input_ids -> {tok.decode(item['input_ids'])!r}")
```
Output:
```
epoch 0: text='hello world <TAG911><|im_end|> <TAG654>'
 input_ids -> 'hello world <TAG911><|im_end|> <TAG41>'
epoch 1: text='hello world <TAG911><|im_end|> <TAG114>'
 input_ids -> 'hello world <TAG911><|im_end|> <TAG41>'
epoch 2: text='hello world <TAG911><|im_end|> <TAG25>'
 input_ids -> 'hello world <TAG911><|im_end|> <TAG41>'
```
Reading the output: `<TAG911>` got baked into the stored `text` during prep's `add_eos` map, and `<|im_end|>` was then appended after it. On each access, the transform only appends a new tag to this stored string, so only the trailing `<TAG654>`/`<TAG114>`/`<TAG25>` changes; everything before it, including `<TAG911>`, is frozen literal data. `input_ids` was tokenized from a separate draw (`<TAG41>`) during the `tokenize_fn` map and never changes afterward, since the transform doesn't touch `input_ids`.
Reproduction — failure mode 2 (default config crashes)
Same setup, but call `next(iter(trainer.get_train_dataloader()))` (or `trainer.train()`) instead of accessing `trainer.train_dataset[0]`:
```python
import tempfile, random
from datasets import Dataset
from trl import SFTTrainer, SFTConfig

random.seed(0)
ds = Dataset.from_dict({"text": ["hello world"] * 2})


def augment(batch):
    batch["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
    return batch


ds = ds.with_transform(augment)

with tempfile.TemporaryDirectory() as tmp_dir:
    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        train_dataset=ds,
        args=SFTConfig(output_dir=tmp_dir, report_to="none", use_cpu=True, bf16=False),
    )
    batch = next(iter(trainer.get_train_dataloader()))
```
Full traceback:
```
Traceback (most recent call last):
  File "C:\Users\my-user\Documents\trl-transform-repro\repro_crash.py", line 22, in <module>
    batch = next(iter(trainer.get_train_dataloader()))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\accelerate\data_loader.py", line 585, in __iter__
    current_batch = next(dataloader_iter)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\torch\utils\data\dataloader.py", line 718, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\torch\utils\data\dataloader.py", line 778, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3162, in __getitems__
    batch = self.__getitem__(keys)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3158, in __getitem__
    return self._getitem(key)
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\arrow_dataset.py", line 3140, in _getitem
    formatted_output = format_table(
                       ^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\formatting\formatting.py", line 658, in format_table
    return formatter(pa_table, query_type=query_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\formatting\formatting.py", line 415, in __call__
    return self.format_batch(pa_table)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\.venv\Lib\site-packages\datasets\formatting\formatting.py", line 541, in format_batch
    return self.transform(batch)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\my-user\Documents\trl-transform-repro\repro_crash.py", line 10, in augment
    batch["text"] = [t + f" <TAG{random.randint(0, 999)}>" for t in batch["text"]]
                                                                    ~~~~~^^^^^^^^
KeyError: 'text'
```
A transform that returns only the changed column (`return {"text": [...]}`, the other common idiom) is worse still: a custom transform's output replaces the row, so the prepared dataset's `__getitem__` yields no `input_ids` at all.
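A toy illustration of why the `{"text": [...]}` idiom loses columns, assuming (as stated above) that a custom transform's return value becomes the accessed row. `read_row` is a hypothetical helper modelling a single-row read, not a `datasets` function.

```python
def read_row(stored_row, transform):
    """Hypothetical model of a custom-format read: whatever the transform returns
    becomes the row, so stored columns it omits are simply not there."""
    batch = {name: [value] for name, value in stored_row.items()}
    return {name: values[0] for name, values in transform(batch).items()}


def only_changed(batch):  # returns just the augmented column
    return {"text": [t + " <TAG>" for t in batch["text"]]}


def keep_all(batch):      # updates text but passes every other column through
    return {**batch, "text": [t + " <TAG>" for t in batch["text"]]}


stored = {"text": "hello world<|im_end|>", "input_ids": [1, 2, 3]}
print(read_row(stored, only_changed))  # no input_ids key at all
print(read_row(stored, keep_all))      # input_ids present, but still the stored value
```

Passing the other columns through avoids the missing-key failure, but as failure mode 3 explains, the `input_ids` it passes through are still the frozen ones.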
Proposed fix (one of):
- Minimal: detect a custom transform at the top of `_prepare_dataset` and raise, pointing to `dataset_kwargs={"skip_prepare_dataset": True}`. Proceeding is never correct today.
- Proper: when a custom transform is detected, replace the `.map()` calls with lazy transform chaining (`user_transform → add_eos → tokenize_fn → …` composed via `with_transform`), so fresh augmentation reaches `input_ids` on every access. This requires raising on `packing=True` (packing can't be lazy) and forcing `remove_unused_columns=False` (the Arrow-level columns stay raw).
Happy to help with a PR for either approach if there's interest.
System Info
- Platform: Windows-10-10.0.26100-SP0
- Python version: 3.11.13
- TRL version: 1.7.0.dev0
- PyTorch version: 2.12.0
- accelerator(s): cpu
- Transformers version: 5.12.0
- Accelerate version: 1.14.0
- Accelerate config: not found
- Datasets version: 5.0.0
- HF Hub version: 1.19.0
- bitsandbytes version: not installed
- DeepSpeed version: not installed
- Liger-Kernel version: not installed
- PEFT version: not installed
- vLLM version: not installed