Conversation

@BenjaminBossan
Member

Some multi-GPU tests had device_map="auto", but some recent changes in accelerate resulted in parameters being moved to a single device. Now set the device map explicitly to avoid that, and add a more rigorous check to ensure that the parameters really are on multiple devices.
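
For illustration, a minimal sketch of the kind of multi-device check described above (the helper name is made up here, not the actual test code):

import torch

def assert_params_on_multiple_devices(model: torch.nn.Module) -> None:
    # Collect the devices of all parameters; with an explicit device_map that
    # spans two GPUs, this set should contain more than one entry.
    devices = {p.device for p in model.parameters()}
    assert len(devices) > 1, f"expected parameters on multiple devices, got {devices}"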

Notes:

These tests require GPUs and are thus not part of the normal CI. Also, there is now another error with pytest tests/test_gpu_examples.py::PeftTorchaoGPUTests::test_causal_lm_training_multi_gpu_torchao_1_int8_dynamic_activation_int8_weight:

tests/test_gpu_examples.py:3629: in test_causal_lm_training_multi_gpu_torchao
    model = AutoModelForCausalLM.from_pretrained(
../transformers/src/transformers/models/auto/auto_factory.py:571: in from_pretrained
    return model_class.from_pretrained(
../transformers/src/transformers/modeling_utils.py:279: in _wrapper
    return func(*args, **kwargs)
../transformers/src/transformers/modeling_utils.py:4476: in from_pretrained
    dispatch_model(model, **device_map_kwargs)
../../../anaconda3/envs/peft/lib/python3.12/site-packages/accelerate/big_modeling.py:423: in dispatch_model
    attach_align_device_hook_on_blocks(
../../../anaconda3/envs/peft/lib/python3.12/site-packages/accelerate/hooks.py:635: in attach_align_device_hook_on_blocks
    add_hook_to_module(module, hook)
../../../anaconda3/envs/peft/lib/python3.12/site-packages/accelerate/hooks.py:167: in add_hook_to_module
    module = hook.init_hook(module)
../../../anaconda3/envs/peft/lib/python3.12/site-packages/accelerate/hooks.py:289: in init_hook
    set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

module = Linear(in_features=768, out_features=768, weight=LinearActivationQuantizedTensor(activation=<function _int8_symm_per_t...ck_size=(1, 768), device=cuda:1, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=None, quant_max=None))), tensor_name = 'weight'
device = 0, value = None, dtype = None, fp16_statistics = None
tied_params_map = {132424656371712: {0: Parameter containing:
tensor([[ 0.1152, -0.1436,  0.0554,  ...,  0.2148,  0.0835,  0.0669],
    ...0.1436,  0.0576,  ...,  0.2139,  0.0830,  0.0649]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)}}

    def set_module_tensor_to_device(
        module: nn.Module,
        tensor_name: str,
        device: Union[int, str, torch.device],
        value: Optional[torch.Tensor] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
        fp16_statistics: Optional[torch.HalfTensor] = None,
        tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
    ):
        """
        A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
        `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
    
        Args:
            module (`torch.nn.Module`):
                The module in which the tensor we want to move lives.
            tensor_name (`str`):
                The full name of the parameter/buffer.
            device (`int`, `str` or `torch.device`):
                The device on which to set the tensor.
            value (`torch.Tensor`, *optional*):
                The value of the tensor (useful when going from the meta device to any other device).
            dtype (`torch.dtype`, *optional*):
                If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to
                the dtype of the existing parameter in the model.
            fp16_statistics (`torch.HalfTensor`, *optional*):
                The list of fp16 statistics to set on the module, used for 8 bit model serialization.
            tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):
                A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
                execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
                device for all others, instead of duplicating memory.
        """
        # Recurse if needed
        if "." in tensor_name:
            splits = tensor_name.split(".")
            for split in splits[:-1]:
                new_module = getattr(module, split)
                if new_module is None:
                    raise ValueError(f"{module} has no attribute {split}.")
                module = new_module
            tensor_name = splits[-1]
    
        if tensor_name not in module._parameters and tensor_name not in module._buffers:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
        is_buffer = tensor_name in module._buffers
        old_value = getattr(module, tensor_name)
    
        # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight
        # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer.
        if (
            value is not None
            and tied_params_map is not None
            and value.data_ptr() in tied_params_map
            and device in tied_params_map[value.data_ptr()]
        ):
            module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
            return
        elif (
            tied_params_map is not None
            and old_value.data_ptr() in tied_params_map
            and device in tied_params_map[old_value.data_ptr()]
        ):
            module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
            return
    
        if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
            raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
    
        param = module._parameters[tensor_name] if tensor_name in module._parameters else None
        param_cls = type(param)
    
        if value is not None:
            # We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights.
            # In other cases, we want to make sure we're not loading checkpoints that do not match the config.
            if old_value.shape != value.shape and param_cls.__name__ != "Params4bit":
                raise ValueError(
                    f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
                )
    
            if dtype is None:
                # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
                value = value.to(old_value.dtype)
            elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
                value = value.to(dtype)
    
        device_quantization = None
        with torch.no_grad():
            # leave it on cpu first before moving them to cuda
            # # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
            if (
                param is not None
                and param.device.type != "cuda"
                and torch.device(device).type == "cuda"
                and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
            ):
                device_quantization = device
                device = "cpu"
            # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
            if isinstance(device, int):
                if is_npu_available():
                    device = f"npu:{device}"
                elif is_mlu_available():
                    device = f"mlu:{device}"
                elif is_sdaa_available():
                    device = f"sdaa:{device}"
                elif is_musa_available():
                    device = f"musa:{device}"
                elif is_hpu_available():
                    device = "hpu"
            if "xpu" in str(device) and not is_xpu_available():
                raise ValueError(f'{device} is not available, you should use device="cpu" instead')
            if value is None:
                new_value = old_value.to(device)
                if dtype is not None and device in ["meta", torch.device("meta")]:
                    if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
                        new_value = new_value.to(dtype)
    
                    if not is_buffer:
                        module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
            elif isinstance(value, torch.Tensor):
                new_value = value.to(device)
            else:
                new_value = torch.tensor(value, device=device)
            if device_quantization is not None:
                device = device_quantization
            if is_buffer:
                module._buffers[tensor_name] = new_value
            elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):
                param_cls = type(module._parameters[tensor_name])
                kwargs = module._parameters[tensor_name].__dict__
                if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
                    if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
                        # downcast to fp16 if any - needed for 8bit serialization
                        new_value = new_value.to(torch.float16)
                    # quantize module that are going to stay on the cpu so that we offload quantized weights
                    if device == "cpu" and param_cls.__name__ == "Int8Params":
                        new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
                        new_value.CB = new_value.CB.to("cpu")
                        new_value.SCB = new_value.SCB.to("cpu")
                    else:
                        new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
                elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
                    new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
                elif param_cls.__name__ in ["AffineQuantizedTensor"]:
                    if importlib.util.find_spec("torchao") is not None and compare_versions("torchao", ">=", "0.7.0"):
                        # TorchAO v0.7.0 made layout_tensor an internal private variable and exposed tensor_impl
                        args = (new_value.tensor_impl,)
                    else:
                        args = (new_value.layout_tensor,)
                    args += (
                        new_value.block_size,
                        new_value.shape,
                        new_value.quant_min,
                        new_value.quant_max,
                        new_value.zero_point_domain,
                    )
                    new_value = torch.nn.Parameter(param_cls(*args), requires_grad=old_value.requires_grad).to(device)
                else:
>                   new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
E                   TypeError: LinearActivationQuantizedTensor.__new__() got an unexpected keyword argument 'requires_grad'
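
For context, the failing line is the generic fallback that rebuilds the parameter via param_cls(new_value, requires_grad=...). A minimal, self-contained illustration of why that call breaks for tensor subclasses whose __new__ does not accept a requires_grad keyword (FancyTensor is a made-up stand-in, not torchao code):

import torch

class FancyTensor(torch.Tensor):
    # Stand-in for a quantized tensor subclass whose constructor, unlike
    # torch.nn.Parameter, takes no requires_grad keyword.
    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_subclass(cls, data)

old_value = FancyTensor(torch.randn(2, 2))
param_cls = type(old_value)      # FancyTensor
new_value = old_value.to("cpu")  # stands in for the .to(device) move in the real code

# Mirrors the generic fallback in set_module_tensor_to_device:
try:
    param_cls(new_value, requires_grad=old_value.requires_grad)
except TypeError as exc:
    print(exc)  # FancyTensor.__new__() got an unexpected keyword argument 'requires_grad'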

@BenjaminBossan
Member Author

@SunMarc Could you please check whether this change makes sense, and also whether you have an idea about this new error? It feels like moving a layer to a different device should not require creating a completely new one.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@SunMarc SunMarc left a comment


Thanks, just a nit

It feels like moving the device of the layer should not require creating a completely new one.

The tensors are recreated, and this is not done in place, which is why we did that. As for why we are calling param_cls, again, I'm not so sure. This is potentially related to bnb. I can dig into it if needed.
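
That matches the note in the set_module_tensor_to_device docstring quoted in the traceback: param.to(device) returns a new tensor that is no longer linked to the module's parameter, so the new object has to be assigned back explicitly. A small sketch of that behavior (plain nn.Linear, no quantization involved; the dtype cast stands in for a device move on a multi-GPU machine):

import torch

linear = torch.nn.Linear(4, 4)

# .to() produces a new tensor; it does not modify the module in place.
moved = linear.weight.to(torch.float64)
print(moved is linear.weight)  # False: the module still holds the old parameter
print(linear.weight.dtype)     # torch.float32, unchanged

# The module only changes once the parameter is reassigned, which is roughly
# what set_module_tensor_to_device does for the non-quantized case.
linear.weight = torch.nn.Parameter(moved, requires_grad=linear.weight.requires_grad)
print(linear.weight.dtype)     # torch.float64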

Comment on lines 1652 to 1664
device_map = {
    "": 0,
    "model.decoder.layers.11": 1,
    "model.decoder.layers.11.activation_fn": 1,
    "model.decoder.layers.11.fc1": 1,
    "model.decoder.layers.11.fc2": 1,
    "model.decoder.layers.11.final_layer_norm": 1,
    "model.decoder.layers.11.self_attn": 1,
    "model.decoder.layers.11.self_attn.k_proj": 1,
    "model.decoder.layers.11.self_attn.out_proj": 1,
    "model.decoder.layers.11.self_attn.q_proj": 1,
    "model.decoder.layers.11.self_attn.v_proj": 1,
    "model.decoder.layers.11.self_attn_layer_norm": 1,
@SunMarc
Member

I think it would be better to have something like this instead:

device_map = {
    'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0,
    'model.decoder.project_out': 0, 'model.decoder.project_in': 0,
    'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0,
    'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0,
    'model.decoder.layers.6': 0, 'model.decoder.layers.7': 0, 'model.decoder.layers.8': 0,
    'model.decoder.layers.9': 0, 'model.decoder.layers.10': 0, 'model.decoder.layers.11': 1,
    'model.decoder.layers.12': 1, 'model.decoder.layers.13': 1, 'model.decoder.layers.14': 1,
    'model.decoder.layers.15': 1, 'model.decoder.layers.16': 1, 'model.decoder.layers.17': 1,
    'model.decoder.layers.18': 1, 'model.decoder.layers.19': 1, 'model.decoder.layers.20': 1,
    'model.decoder.layers.21': 1, 'model.decoder.layers.22': 1, 'model.decoder.layers.23': 1,
}

@BenjaminBossan
Member Author

I can change that, although there are only 12 layers here, so I would drop everything past 11. Just for my understanding, why would this be better?

@SunMarc
Member

Oh, my bad, I looked at the facebook/opt-350m model and it has 23 layers, but the actual quantized model is from facebook/opt-125m.
Let's use that then:
device_map = {
    'model.decoder.embed_tokens': 0, 'lm_head': 0, 'model.decoder.embed_positions': 0,
    'model.decoder.project_out': 0, 'model.decoder.project_in': 0,
    'model.decoder.layers.0': 0, 'model.decoder.layers.1': 0, 'model.decoder.layers.2': 0,
    'model.decoder.layers.3': 0, 'model.decoder.layers.4': 0, 'model.decoder.layers.5': 0,
    'model.decoder.layers.6': 1, 'model.decoder.layers.7': 1, 'model.decoder.layers.8': 1,
    'model.decoder.layers.9': 1, 'model.decoder.layers.10': 1, 'model.decoder.layers.11': 1,
}

@BenjaminBossan
Member Author

@SunMarc I took your suggested device_map, but also had to add "model.decoder.final_layer_norm": 1. Now all the tests pass, TypeError: LinearActivationQuantizedTensor.__new__() got an unexpected keyword argument 'requires_grad' is gone, it's magic!
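
For reference, a rough sketch of how the explicit map discussed above could be passed to from_pretrained; the model id is taken from the discussion, but the call is illustrative and omits the torchao quantization config used in the actual test:

from transformers import AutoModelForCausalLM

# facebook/opt-125m split across two GPUs: layers 0-5 on device 0,
# layers 6-11 plus the final layer norm on device 1.
device_map = {
    "model.decoder.embed_tokens": 0,
    "lm_head": 0,
    "model.decoder.embed_positions": 0,
    "model.decoder.project_out": 0,
    "model.decoder.project_in": 0,
    **{f"model.decoder.layers.{i}": 0 for i in range(6)},
    **{f"model.decoder.layers.{i}": 1 for i in range(6, 12)},
    "model.decoder.final_layer_norm": 1,
}

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", device_map=device_map)

# The more rigorous check from the PR description: parameters must really
# end up on more than one device.
assert len({p.device for p in model.parameters()}) > 1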

@SunMarc
Member

SunMarc commented Apr 8, 2025

I took your suggested device_map, but also had to add "model.decoder.final_layer_norm": 1. Now all the tests pass, TypeError: LinearActivationQuantizedTensor.__new__() got an unexpected keyword argument 'requires_grad' is gone, it's magic!

Wow, nice that it got fixed by magic xD

@BenjaminBossan BenjaminBossan marked this pull request as ready for review April 9, 2025 09:53
@BenjaminBossan BenjaminBossan merged commit 4c82bff into huggingface:main Apr 11, 2025
14 checks passed
@BenjaminBossan BenjaminBossan deleted the fix-multi-gpu-tests-device-map branch April 11, 2025 16:06
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
cyyever pushed a commit to cyyever/peft that referenced this pull request Sep 4, 2025
* smol course links and badges

* try without space

* revert space