Make packing/padding a training setting
mreso committed Oct 2, 2023
1 parent cc8cc0d commit a647955
Showing 7 changed files with 123 additions and 78 deletions.
6 changes: 2 additions & 4 deletions src/llama_recipes/configs/training.py
@@ -11,6 +11,8 @@ class train_config:
low_cpu_fsdp: bool=False
run_validation: bool=True
batch_size_training: int=4
batching_strategy: str="packing" #alternative: padding
context_length: int=4096
gradient_accumulation_steps: int=1
num_epochs: int=3
num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
save_optimizer: bool=False # will be used if using FSDP
use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers; makes use of Flash Attention and Xformer memory-efficient kernels
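The two new settings above can be overridden like any other train_config field. A minimal sketch (not part of this commit) of how that might look, assuming update_config from config_utils.py below copies matching keyword arguments onto the dataclass:

from llama_recipes.configs import train_config
from llama_recipes.utils.config_utils import update_config

# Hypothetical override: switch from the default "packing" strategy to "padding"
# and shorten the packed context window.
cfg = train_config()
update_config(cfg, batching_strategy="padding", context_length=2048)
print(cfg.batching_strategy, cfg.context_length)  # padding 2048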
@@ -11,7 +11,7 @@ class Concatenator(object):
def __init__(self, chunk_size=2048):
self.chunk_size=chunk_size
self.residual = {"input_ids": [], "attention_mask": []}

def __call__(self, batch):
concatenated_samples = {
k: v + list(chain(*batch[k])) for k, v in self.residual.items()
@@ -44,26 +44,24 @@ class ConcatDataset(Dataset):
def __init__(self, dataset, chunk_size=4096):
self.dataset = dataset
self.chunk_size = chunk_size

self.samples = []

buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}

for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
buffer = {k: v + sample[k] for k,v in buffer.items()}

while len(next(iter(buffer.values()))) > self.chunk_size:
self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}

def __getitem__(self, idx):
return self.samples[idx]

def __len__(self):
return len(self.samples)


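A minimal usage sketch for ConcatDataset (not part of the commit; the toy samples below are invented, and the import path is the one used in finetuning.py further down): samples are dicts of token-id lists, which the wrapper concatenates and re-slices into fixed-size chunks, dropping any tail shorter than chunk_size.

from llama_recipes.data.concatenator import ConcatDataset

tokenized = [  # invented toy "tokenized" samples
    {"input_ids": [1] * 300, "attention_mask": [1] * 300, "labels": [1] * 300},
    {"input_ids": [2] * 300, "attention_mask": [1] * 300, "labels": [2] * 300},
]

packed = ConcatDataset(tokenized, chunk_size=256)
print(len(packed))                  # 2 -- two full 256-token chunks; the 88-token remainder is dropped
print(len(packed[0]["input_ids"]))  # 256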
9 changes: 3 additions & 6 deletions src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
@@ -10,8 +10,6 @@

from torch.utils.data import Dataset

from llama_recipes.datasets.utils import ConcatDataset


class grammar(Dataset):
def __init__(
@@ -48,10 +46,10 @@ def convert_to_features(self, example_batch):

input_ = example_batch["input"]
target_ = example_batch["target"]

prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
sample = self.tokenizer(prompt)

return sample

def __getitem__(self, index):
@@ -80,6 +78,5 @@ def get_dataset(
tokenizer=tokenizer,
csv_name=csv_name,
)

return dataset
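For reference, the prompt that convert_to_features above builds for one invented input/target pair; the tokenized sample is returned as-is, and packing or padding is now applied later by the training-side batching strategy rather than inside the dataset module:

# Invented example pair, for illustration only.
input_ = "I has a apple"
target_ = "I have an apple"
prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
print(prompt)
# Correct this to standard English: I has a apple
# ---
# Corrected: I have an apple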
3 changes: 1 addition & 2 deletions src/llama_recipes/datasets/samsum_dataset.py
@@ -5,7 +5,6 @@

import datasets

from llama_recipes.datasets.utils import Concatenator

def get_preprocessed_samsum(dataset_config, tokenizer, split):
dataset = datasets.load_dataset("samsum", split=split)
@@ -24,7 +23,7 @@ def apply_prompt_template(sample):
}

dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

dataset = dataset.map(
lambda sample: tokenizer(sample["text"]),
remove_columns=list(dataset.features),
28 changes: 14 additions & 14 deletions src/llama_recipes/finetuning.py
@@ -17,18 +17,19 @@
LlamaForCausalLM,
LlamaTokenizer,
LlamaConfig,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from llama_recipes.configs import fsdp_config, train_config
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
update_config,
generate_peft_config,
generate_dataset_config,
get_sampler_kwargs,
get_dataloader_kwargs,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset

@@ -100,25 +101,19 @@ def main(**kwargs):
if train_config.enable_fsdp and train_config.use_fast_kernels:
"""
For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
use of Flash Attention or Xformer memory-efficient kernels
based on the hardware being used. This would speed up fine-tuning.
"""
try:
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' before proceeding.")

# Load the tokenizer and add special tokens
tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
tokenizer.add_special_tokens(
    {
        "pad_token": "<PAD>",
    }
)
model.resize_token_embeddings(model.config.vocab_size + 1)
tokenizer.pad_token_id = tokenizer.eos_token_id

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

# Prepare the model for int8 training if quantization is enabled
@@ -180,8 +175,11 @@ def main(**kwargs):
if not train_config.enable_fsdp or rank == 0:
print(f"--> Validation Set Length = {len(dataset_val)}")

train_dl_kwargs = get_sampler_kwargs(train_config, dataset_train, tokenizer, "train")
val_dl_kwargs = get_sampler_kwargs(train_config, dataset_val, tokenizer, "val")
train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")

if train_config.batching_strategy == "packing":
dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

# Create DataLoaders for the training and validation dataset
train_dataloader = torch.utils.data.DataLoader(
@@ -193,6 +191,8 @@

eval_dataloader = None
if train_config.run_validation:
if train_config.batching_strategy == "packing":
dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
eval_dataloader = torch.utils.data.DataLoader(
dataset_val,
num_workers=train_config.num_workers_dataloader,
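Condensing the finetuning.py hunks above, the training-side flow is now roughly the following (a sketch only, shown without the FSDP branches, using the names from the surrounding code):

train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")

if train_config.batching_strategy == "packing":
    # Chunk length now follows the new context_length setting instead of a hard-coded value.
    dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    num_workers=train_config.num_workers_dataloader,
    **train_dl_kwargs,
)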
48 changes: 31 additions & 17 deletions src/llama_recipes/utils/config_utils.py
@@ -38,49 +38,63 @@ def update_config(config, **kwargs):
print(f"Warning: {config_name} does not accept parameter: {k}")
elif isinstance(config, train_config):
print(f"Warning: unknown parameter {k}")


def generate_peft_config(train_config, kwargs):
configs = (lora_config, llama_adapter_config, prefix_config)
peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
names = tuple(c.__name__.rstrip("_config") for c in configs)

assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"

config = configs[names.index(train_config.peft_method)]()

update_config(config, **kwargs)
params = asdict(config)
peft_config = peft_configs[names.index(train_config.peft_method)](**params)

return peft_config


def generate_dataset_config(train_config, kwargs):
names = tuple(DATASET_PREPROC.keys())

assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"

dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()

update_config(dataset_config, **kwargs)

return dataset_config


def get_sampler_kwargs(train_config, dataset, tokenizer, mode):
def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs = {}
batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
if train_config.enable_fsdp:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
if train_config.batching_strategy == "padding":
if train_config.enable_fsdp:
kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
dataset,
batch_size=batch_size,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=mode=="train",
)
else:
kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
elif train_config.batching_strategy == "packing":
if train_config.enable_fsdp:
kwargs["batch_sampler"] = DistributedSampler(
dataset,
batch_size=batch_size,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),
shuffle=mode=="train",
)
kwargs["batch_size"] = batch_size
kwargs["drop_last"] = True
kwargs["collate_fn"] = default_data_collator
else:
kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)

raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")

return kwargs
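A self-contained sketch (toy token lists, not from the repo) of why the two strategies use different collators: packed chunks already share one length, so default_data_collator can simply stack them, while the padding path pairs a length-based batch sampler with DataCollatorForSeq2Seq so each batch is padded only up to its own longest sample.

from transformers import default_data_collator

# Two packed chunks of identical length, as ConcatDataset would emit them.
chunks = [
    {"input_ids": [1] * 8, "attention_mask": [1] * 8, "labels": [1] * 8},
    {"input_ids": [2] * 8, "attention_mask": [1] * 8, "labels": [2] * 8},
]
batch = default_data_collator(chunks)
print(batch["input_ids"].shape)  # torch.Size([2, 8]) -- stacked as-is, no padding needed

# With "padding", samples keep their natural lengths and DataCollatorForSeq2Seq
# (constructed with the tokenizer loaded in finetuning.py) pads each batch to its
# longest member; grouping similar lengths keeps that padding small.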