From a647955fc8fdb5053dae4c0f6b81477f067a9368 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Mon, 2 Oct 2023 05:44:29 -0700 Subject: [PATCH] Make packing/padding a training setting --- src/llama_recipes/configs/training.py | 6 +- .../utils.py => data/concatenator.py} | 16 ++-- .../grammar_dataset/grammar_dataset.py | 9 +- src/llama_recipes/datasets/samsum_dataset.py | 3 +- src/llama_recipes/finetuning.py | 28 +++--- src/llama_recipes/utils/config_utils.py | 48 ++++++---- tests/test_finetuning.py | 91 +++++++++++++------ 7 files changed, 123 insertions(+), 78 deletions(-) rename src/llama_recipes/{datasets/utils.py => data/concatenator.py} (96%) diff --git a/src/llama_recipes/configs/training.py b/src/llama_recipes/configs/training.py index 53148773e..354c534eb 100644 --- a/src/llama_recipes/configs/training.py +++ b/src/llama_recipes/configs/training.py @@ -11,6 +11,8 @@ class train_config: low_cpu_fsdp: bool=False run_validation: bool=True batch_size_training: int=4 + batching_strategy: str="packing" #alternative: padding + context_length: int=4096 gradient_accumulation_steps: int=1 num_epochs: int=3 num_workers_dataloader: int=1 @@ -34,7 +36,3 @@ class train_config: dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP save_optimizer: bool=False # will be used if using FSDP use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels - - - - diff --git a/src/llama_recipes/datasets/utils.py b/src/llama_recipes/data/concatenator.py similarity index 96% rename from src/llama_recipes/datasets/utils.py rename to src/llama_recipes/data/concatenator.py index 8ffd5bfd5..611aa6ec5 100644 --- a/src/llama_recipes/datasets/utils.py +++ b/src/llama_recipes/data/concatenator.py @@ -11,7 +11,7 @@ class Concatenator(object): def __init__(self, chunk_size=2048): self.chunk_size=chunk_size self.residual = {"input_ids": [], "attention_mask": []} - + def __call__(self, batch): concatenated_samples = { k: v + list(chain(*batch[k])) for k, v in self.residual.items() @@ -44,26 +44,24 @@ class ConcatDataset(Dataset): def __init__(self, dataset, chunk_size=4096): self.dataset = dataset self.chunk_size = chunk_size - + self.samples = [] - + buffer = { "input_ids": [], "attention_mask": [], "labels": [], } - + for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): buffer = {k: v + sample[k] for k,v in buffer.items()} - + while len(next(iter(buffer.values()))) > self.chunk_size: self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} - + def __getitem__(self, idx): return self.samples[idx] - + def __len__(self): return len(self.samples) - - diff --git a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py index 2948560c5..ac686137d 100644 --- a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py +++ b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py @@ -10,8 +10,6 @@ from torch.utils.data import Dataset -from llama_recipes.datasets.utils import ConcatDataset - class grammar(Dataset): def __init__( @@ -48,10 +46,10 @@ def convert_to_features(self, example_batch): input_ = example_batch["input"] target_ = example_batch["target"] - + prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}" sample = self.tokenizer(prompt) - + return sample def 
__getitem__(self, index): @@ -80,6 +78,5 @@ def get_dataset( tokenizer=tokenizer, csv_name=csv_name, ) - - return dataset + return dataset diff --git a/src/llama_recipes/datasets/samsum_dataset.py b/src/llama_recipes/datasets/samsum_dataset.py index ab8fadbde..6093d0289 100644 --- a/src/llama_recipes/datasets/samsum_dataset.py +++ b/src/llama_recipes/datasets/samsum_dataset.py @@ -5,7 +5,6 @@ import datasets -from llama_recipes.datasets.utils import Concatenator def get_preprocessed_samsum(dataset_config, tokenizer, split): dataset = datasets.load_dataset("samsum", split=split) @@ -24,7 +23,7 @@ def apply_prompt_template(sample): } dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) - + dataset = dataset.map( lambda sample: tokenizer(sample["text"]), remove_columns=list(dataset.features), diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py index 44c74c78e..54600190c 100644 --- a/src/llama_recipes/finetuning.py +++ b/src/llama_recipes/finetuning.py @@ -17,10 +17,11 @@ LlamaForCausalLM, LlamaTokenizer, LlamaConfig, -) +) from transformers.models.llama.modeling_llama import LlamaDecoderLayer from llama_recipes.configs import fsdp_config, train_config +from llama_recipes.data.concatenator import ConcatDataset from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing from llama_recipes.utils import fsdp_auto_wrap_policy @@ -28,7 +29,7 @@ update_config, generate_peft_config, generate_dataset_config, - get_sampler_kwargs, + get_dataloader_kwargs, ) from llama_recipes.utils.dataset_utils import get_preprocessed_dataset @@ -100,25 +101,19 @@ def main(**kwargs): if train_config.enable_fsdp and train_config.use_fast_kernels: """ For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable - using of Flash Attention or Xformer memory-efficient kernels + using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up fine-tuning. """ try: from optimum.bettertransformer import BetterTransformer - model = BetterTransformer.transform(model) + model = BetterTransformer.transform(model) except ImportError: print("Module 'optimum' not found. 
Please install 'optimum' it before proceeding.") - + # Load the tokenizer and add special tokens tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name) - tokenizer.add_special_tokens( - { + tokenizer.pad_token_id = tokenizer.eos_token_id - "pad_token": "", - } - ) - model.resize_token_embeddings(model.config.vocab_size + 1) - print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) # Prepare the model for int8 training if quantization is enabled @@ -180,8 +175,11 @@ def main(**kwargs): if not train_config.enable_fsdp or rank == 0: print(f"--> Validation Set Length = {len(dataset_val)}") - train_dl_kwargs = get_sampler_kwargs(train_config, dataset_train, tokenizer, "train") - val_dl_kwargs = get_sampler_kwargs(train_config, dataset_val, tokenizer, "val") + train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train") + val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val") + + if train_config.batching_strategy == "packing": + dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length) # Create DataLoaders for the training and validation dataset train_dataloader = torch.utils.data.DataLoader( @@ -193,6 +191,8 @@ def main(**kwargs): eval_dataloader = None if train_config.run_validation: + if train_config.batching_strategy == "packing": + dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length) eval_dataloader = torch.utils.data.DataLoader( dataset_val, num_workers=train_config.num_workers_dataloader, diff --git a/src/llama_recipes/utils/config_utils.py b/src/llama_recipes/utils/config_utils.py index 422240d73..51d74e510 100644 --- a/src/llama_recipes/utils/config_utils.py +++ b/src/llama_recipes/utils/config_utils.py @@ -38,49 +38,63 @@ def update_config(config, **kwargs): print(f"Warning: {config_name} does not accept parameter: {k}") elif isinstance(config, train_config): print(f"Warning: unknown parameter {k}") - - + + def generate_peft_config(train_config, kwargs): configs = (lora_config, llama_adapter_config, prefix_config) peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) names = tuple(c.__name__.rstrip("_config") for c in configs) - + assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" - + config = configs[names.index(train_config.peft_method)]() - + update_config(config, **kwargs) params = asdict(config) peft_config = peft_configs[names.index(train_config.peft_method)](**params) - + return peft_config def generate_dataset_config(train_config, kwargs): names = tuple(DATASET_PREPROC.keys()) - + assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" - + dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]() - + update_config(dataset_config, **kwargs) - + return dataset_config -def get_sampler_kwargs(train_config, dataset, tokenizer, mode): +def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): kwargs = {} batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size - if train_config.enable_fsdp: - kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler( + if train_config.batching_strategy == "padding": + if train_config.enable_fsdp: + kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler( + dataset, + batch_size=batch_size, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + shuffle=mode=="train", + ) + else: + kwargs["batch_sampler"] = 
LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train") + kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer) + elif train_config.batching_strategy == "packing": + if train_config.enable_fsdp: + kwargs["batch_sampler"] = DistributedSampler( dataset, - batch_size=batch_size, rank=dist.get_rank(), num_replicas=dist.get_world_size(), shuffle=mode=="train", ) + kwargs["batch_size"] = batch_size + kwargs["drop_last"] = True + kwargs["collate_fn"] = default_data_collator else: - kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train") - kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer) - + raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") + return kwargs diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py index f980249b0..6651a22a2 100644 --- a/tests/test_finetuning.py +++ b/tests/test_finetuning.py @@ -1,14 +1,17 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +import pytest from pytest import approx from unittest.mock import patch from torch.nn import Linear from torch.optim import AdamW from torch.utils.data.dataloader import DataLoader +from torch.utils.data.sampler import BatchSampler from llama_recipes.finetuning import main +from llama_recipes.data.sampler import LengthBasedBatchSampler @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @@ -18,23 +21,23 @@ @patch('llama_recipes.finetuning.StepLR') def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): kwargs = {"run_validation": False} - + get_dataset.return_value = [[1]] - + main(**kwargs) - + assert train.call_count == 1 - + args, kwargs = train.call_args train_dataloader = args[1] eval_dataloader = args[2] - + assert isinstance(train_dataloader, DataLoader) assert eval_dataloader is None - + assert get_model.return_value.to.call_args.args[0] == "cuda" - - + + @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') @@ -44,20 +47,20 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): kwargs = {"run_validation": True} get_dataset.return_value = [[1]] - + main(**kwargs) - + assert train.call_count == 1 - + args, kwargs = train.call_args train_dataloader = args[1] eval_dataloader = args[2] assert isinstance(train_dataloader, DataLoader) assert isinstance(eval_dataloader, DataLoader) - + assert get_model.return_value.to.call_args.args[0] == "cuda" - - + + @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') @@ -68,15 +71,15 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, @patch('llama_recipes.finetuning.StepLR') def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train): kwargs = {"use_peft": True} - + get_dataset.return_value = [[1]] - + main(**kwargs) - + assert get_peft_model.return_value.to.call_args.args[0] == "cuda" assert get_peft_model.return_value.print_trainable_parameters.call_count == 1 - - + + 
@patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') @@ -85,20 +88,56 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge @patch('llama_recipes.finetuning.StepLR') def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train): kwargs = {"weight_decay": 0.01} - + get_dataset.return_value = [[1]] - + get_peft_model.return_value = Linear(1,1) get_peft_model.return_value.print_trainable_parameters=lambda:None main(**kwargs) - + assert train.call_count == 1 - + args, kwargs = train.call_args optimizer = args[4] - + print(optimizer.state_dict()) - + assert isinstance(optimizer, AdamW) assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01) - \ No newline at end of file + + +@patch('llama_recipes.finetuning.train') +@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') +@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained') +@patch('llama_recipes.finetuning.get_preprocessed_dataset') +@patch('llama_recipes.finetuning.optim.AdamW') +@patch('llama_recipes.finetuning.StepLR') +def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train): + kwargs = {"batching_strategy": "packing"} + + get_dataset.return_value = [[1]] + + main(**kwargs) + + assert train.call_count == 1 + + args, kwargs = train.call_args + train_dataloader, eval_dataloader = args[1:3] + assert isinstance(train_dataloader.batch_sampler, BatchSampler) + assert isinstance(eval_dataloader.batch_sampler, BatchSampler) + + kwargs["batching_strategy"] = "padding" + train.reset_mock() + main(**kwargs) + + assert train.call_count == 1 + + args, kwargs = train.call_args + train_dataloader, eval_dataloader = args[1:3] + assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler) + assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler) + + kwargs["batching_strategy"] = "none" + + with pytest.raises(ValueError): + main(**kwargs)
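A minimal sketch of how the new settings fit together, assuming the patch above is applied as-is and enable_fsdp is left at its default; the helper name build_train_loader and the standalone usage are illustrative and not part of the patch. It follows the order used in finetuning.main(): DataLoader kwargs are derived from the batching strategy first, then the dataset is wrapped in ConcatDataset when packing.

import torch

from llama_recipes.configs import train_config
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.utils.config_utils import get_dataloader_kwargs


def build_train_loader(dataset_train, tokenizer, cfg=None):
    # Defaults from this patch: batching_strategy="packing", context_length=4096.
    cfg = cfg or train_config()
    # "padding" -> LengthBasedBatchSampler + DataCollatorForSeq2Seq(tokenizer)
    # "packing" -> plain batch_size/drop_last + default_data_collator
    # anything else raises ValueError inside get_dataloader_kwargs().
    # The tokenizer is only consulted on the "padding" path.
    dl_kwargs = get_dataloader_kwargs(cfg, dataset_train, tokenizer, "train")
    # With packing, tokenized samples are concatenated into fixed-size chunks
    # of cfg.context_length before the DataLoader is built.
    if cfg.batching_strategy == "packing":
        dataset_train = ConcatDataset(dataset_train, chunk_size=cfg.context_length)
    return torch.utils.data.DataLoader(
        dataset_train,
        num_workers=cfg.num_workers_dataloader,
        pin_memory=True,
        **dl_kwargs,
    )

The else branch in get_dataloader_kwargs() makes any other value of batching_strategy fail fast with a ValueError, which the new test_batching_strategy test exercises with "none".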
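A toy illustration (also an assumption, not part of the patch) of the packing behaviour implemented by ConcatDataset in src/llama_recipes/data/concatenator.py: per-key token lists are concatenated across samples and re-cut into chunks of chunk_size.

from llama_recipes.data.concatenator import ConcatDataset

toy = [
    {"input_ids": [1, 1, 1],       "attention_mask": [1, 1, 1],       "labels": [1, 1, 1]},
    {"input_ids": [2, 2, 2, 2],    "attention_mask": [1, 1, 1, 1],    "labels": [2, 2, 2, 2]},
    {"input_ids": [3, 3, 3, 3, 3], "attention_mask": [1, 1, 1, 1, 1], "labels": [3, 3, 3, 3, 3]},
]

packed = ConcatDataset(toy, chunk_size=4)

print(len(packed))             # 2
print(packed[0]["input_ids"])  # [1, 1, 1, 2]
print(packed[1]["input_ids"])  # [2, 2, 2, 3]
# The trailing buffer ([3, 3, 3, 3]) is never emitted: the inner while-loop
# only fires while the buffer is strictly longer than chunk_size, so the
# final remainder is dropped once the input is exhausted.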