93 changes: 89 additions & 4 deletions tests/test_gpu_examples.py
@@ -1649,16 +1649,37 @@ def test_causal_lm_training_multi_gpu(self):
Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
correctly.
"""
device_map = {
"model.decoder.embed_tokens": 0,
"lm_head": 0,
"model.decoder.embed_positions": 0,
"model.decoder.project_out": 0,
"model.decoder.project_in": 0,
"model.decoder.layers.0": 0,
"model.decoder.layers.1": 0,
"model.decoder.layers.2": 0,
"model.decoder.layers.3": 0,
"model.decoder.layers.4": 0,
"model.decoder.layers.5": 0,
"model.decoder.layers.6": 1,
"model.decoder.layers.7": 1,
"model.decoder.layers.8": 1,
"model.decoder.layers.9": 1,
"model.decoder.layers.10": 1,
"model.decoder.layers.11": 1,
"model.decoder.final_layer_norm": 1,
}

with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
torch_dtype=torch.float16,
device_map="auto",
device_map=device_map,
quantization_config=self.quantization_config,
)

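# Check that both the planned device map and the actual parameter placement cover every GPU.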
assert set(model.hf_device_map.values()) == set(range(device_count))
assert {p.device.index for p in model.parameters()} == set(range(device_count))

model = prepare_model_for_kbit_training(model)

@@ -3182,14 +3203,35 @@ def test_causal_lm_training_multi_gpu(self):
Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
correctly.
"""
device_map = {
"model.decoder.embed_tokens": 0,
"lm_head": 0,
"model.decoder.embed_positions": 0,
"model.decoder.project_out": 0,
"model.decoder.project_in": 0,
"model.decoder.layers.0": 0,
"model.decoder.layers.1": 0,
"model.decoder.layers.2": 0,
"model.decoder.layers.3": 0,
"model.decoder.layers.4": 0,
"model.decoder.layers.5": 0,
"model.decoder.layers.6": 1,
"model.decoder.layers.7": 1,
"model.decoder.layers.8": 1,
"model.decoder.layers.9": 1,
"model.decoder.layers.10": 1,
"model.decoder.layers.11": 1,
"model.decoder.final_layer_norm": 1,
}

with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
device_map="auto",
device_map=device_map,
)

assert set(model.hf_device_map.values()) == set(range(device_count))
assert {p.device.index for p in model.parameters()} == set(range(device_count))

model = prepare_model_for_kbit_training(model)

@@ -3579,16 +3621,38 @@ def test_causal_lm_training_single_gpu_torchao_int4_raises(self):
def test_causal_lm_training_multi_gpu_torchao(self, quant_type):
from transformers import TorchAoConfig

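# Explicit device map so the torchao-quantized model is split across both GPUs.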
device_map = {
"model.decoder.embed_tokens": 0,
"lm_head": 0,
"model.decoder.embed_positions": 0,
"model.decoder.project_out": 0,
"model.decoder.project_in": 0,
"model.decoder.layers.0": 0,
"model.decoder.layers.1": 0,
"model.decoder.layers.2": 0,
"model.decoder.layers.3": 0,
"model.decoder.layers.4": 0,
"model.decoder.layers.5": 0,
"model.decoder.layers.6": 1,
"model.decoder.layers.7": 1,
"model.decoder.layers.8": 1,
"model.decoder.layers.9": 1,
"model.decoder.layers.10": 1,
"model.decoder.layers.11": 1,
"model.decoder.final_layer_norm": 1,
}

with tempfile.TemporaryDirectory() as tmp_dir:
quantization_config = TorchAoConfig(quant_type=quant_type)
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
device_map="auto",
device_map=device_map,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)

assert set(model.hf_device_map.values()) == set(range(device_count))
assert {p.device.index for p in model.parameters()} == set(range(device_count))

model = prepare_model_for_kbit_training(model)
model.model_parallel = True
@@ -3640,15 +3704,36 @@ def test_causal_lm_training_multi_gpu_torchao_int4_raises(self):
# TODO: Once proper torchao support for int4 is added, remove this test and add int4 to supported_quant_types
from transformers import TorchAoConfig

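# Explicit device map so the model is split across both GPUs for the int4 check.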
device_map = {
"model.decoder.embed_tokens": 0,
"lm_head": 0,
"model.decoder.embed_positions": 0,
"model.decoder.project_out": 0,
"model.decoder.project_in": 0,
"model.decoder.layers.0": 0,
"model.decoder.layers.1": 0,
"model.decoder.layers.2": 0,
"model.decoder.layers.3": 0,
"model.decoder.layers.4": 0,
"model.decoder.layers.5": 0,
"model.decoder.layers.6": 1,
"model.decoder.layers.7": 1,
"model.decoder.layers.8": 1,
"model.decoder.layers.9": 1,
"model.decoder.layers.10": 1,
"model.decoder.layers.11": 1,
"model.decoder.final_layer_norm": 1,
}
quantization_config = TorchAoConfig(quant_type="int4_weight_only")
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
device_map="auto",
device_map=device_map,
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)

assert set(model.hf_device_map.values()) == set(range(device_count))
assert {p.device.index for p in model.parameters()} == set(range(device_count))

model = prepare_model_for_kbit_training(model)
model.model_parallel = True