
Conversation

@jiqing-feng
Contributor

Hi @BenjaminBossan. This PR fixes the GPU tests. To reproduce: pytest tests/test_common_gpu.py::PeftGPUCommonTests::test_8bit_dora_inference

The LoRA 8-bit model q_proj:
[screenshot: q_proj module of the LoRA 8-bit model]

The DoRA 8-bit model q_proj:
[screenshot: q_proj module of the DoRA 8-bit model]

You can see that the DoRA model has an extra linear layer: lora.dora.DoraLinearLayer.
The outputs cannot be the same unless the weights of lora.dora.DoraLinearLayer are all zero.
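For reference, the module structure can be printed directly instead of relying on the screenshots. A minimal sketch, assuming model and dora_model are the 8-bit LoRA and DoRA PEFT models (as in the reproduction script later in this thread); the attribute path below is specific to OPT-125m and may differ for other architectures:

# Print the q_proj module of the first decoder layer for both adapted models.
print(model.base_model.model.model.decoder.layers[0].self_attn.q_proj)
# -> bnb Linear8bitLt wrapped with lora_A / lora_B sub-modules

print(dora_model.base_model.model.model.decoder.layers[0].self_attn.q_proj)
# -> same, plus a lora_magnitude_vector entry holding a lora.dora.DoraLinearLayer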

@BenjaminBossan
Member

Hey, I don't think these changes are correct. As the comment suggests, the purpose of this test is:

check for same result with and without DoRA when initializing with init_lora_weights=False

By setting init_lora_weights=False, we ensure that the LoRA weights are initialized such that LoRA is not a no-op (with init_lora_weights=True, the default, the lora_B weight is all zeros). Setting init_lora_weights="eva" doesn't make sense in this context. Also, the tests pass both on CI and locally, so it's unclear to me what you're trying to fix.
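To illustrate the point about the default initialization, here is a minimal sketch (not part of the test itself; the attribute path assumes OPT-125m): with init_lora_weights=True, lora_B starts at zero, so the adapter contributes nothing until it is trained.

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
peft_model = get_peft_model(base, LoraConfig(r=8))  # default: init_lora_weights=True

# lora_B is initialized to all zeros, so W + B @ A == W and LoRA is a no-op at init;
# init_lora_weights=False instead randomizes the weights so the adapter changes the output.
layer = peft_model.base_model.model.model.decoder.layers[0].self_attn.q_proj
print(layer.lora_B["default"].weight.abs().max().item())  # 0.0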

@jiqing-feng
Contributor Author

That's weird. I got totally different results on an A100; I will check it again.

@jiqing-feng
Contributor Author

Hi @BenjaminBossan. Could you please run the following code and paste your outputs?

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, set_seed
from peft import LoraConfig, get_peft_model

set_seed(0)

# Base model loaded in 8-bit via bitsandbytes
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float32,
).eval()

# Plain LoRA adapter; init_lora_weights=False makes LoRA a non-no-op
torch.manual_seed(0)
config_lora = LoraConfig(r=8, init_lora_weights=False, use_dora=False)
model = get_peft_model(model, config_lora).eval()

random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
logits_lora = model(random_input).logits

# Second copy of the base model, this time with a DoRA adapter
dora_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float32,
)
torch.manual_seed(0)
config_dora = LoraConfig(r=8, init_lora_weights=False, use_dora=True)
dora_model = get_peft_model(dora_model, config_dora).eval()

logits_dora = dora_model(random_input).logits

print(logits_lora)
print(logits_dora)

I got the following outputs:
[screenshot: logits_lora and logits_dora printed on the A100]

with
bitsandbytes 0.45.4
transformers 4.51.0.dev0
peft 0.15.2.dev0
torch 2.8.0.dev20250325+cu128

on a single A100 card.

@BenjaminBossan
Member

I get:

tensor([[[-1.3665, -1.3774,  2.7763,  ..., -1.3983, -1.3762, -1.4330],
         [-2.7419, -2.7511,  0.5346,  ..., -2.8071, -2.7081, -2.8664],
         [-3.5760, -3.5855,  0.5904,  ..., -3.6521, -3.5360, -3.6858],
         [-3.8018, -3.8119,  0.6282,  ..., -3.8884, -3.7535, -3.9122],
         [-3.9198, -3.9294,  0.6355,  ..., -4.0114, -3.8708, -4.0364],
         [-3.8760, -3.8866,  0.6693,  ..., -3.9737, -3.8277, -3.9972]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)
tensor([[[-1.3665, -1.3774,  2.7763,  ..., -1.3983, -1.3762, -1.4330],
         [-2.7419, -2.7511,  0.5346,  ..., -2.8071, -2.7081, -2.8664],
         [-3.5760, -3.5855,  0.5904,  ..., -3.6521, -3.5360, -3.6858],
         [-3.8018, -3.8119,  0.6282,  ..., -3.8884, -3.7535, -3.9122],
         [-3.9198, -3.9294,  0.6355,  ..., -4.0114, -3.8708, -4.0364],
         [-3.8760, -3.8866,  0.6693,  ..., -3.9737, -3.8277, -3.9972]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

On a 4090, both with the latest main of PEFT and with v0.15.1. I tried transformers main and v4.50.3, bnb 0.45.3, and torch 2.6.0. assert torch.allclose(logits_lora, logits_dora) passes for me too.

@jiqing-feng
Contributor Author

That's weird. I get the same results on a 4090 but different results on an A100. And I don't know why these two results can be the same at all, because the computation is different here.

@jiqing-feng
Contributor Author

I see, it's because of the numerical error here:
[screenshot: the code location where the numerical difference appears]

There is some numerical error on the A100, so the results are slightly different. I don't know why there is no numerical error on the 4090...

@jiqing-feng
Contributor Author

Hi @BenjaminBossan. Do you have any idea how to deal with this case?

@BenjaminBossan
Member

I don't know why these 2 results can be the same because the computation is different here

The DoRA part itself is a no-op without any updates to the DoRA params. The LoRA part is not a no-op since init_lora_weights=False, but the results should still be the same because of the seed.
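To spell out why that holds, here is a toy numeric sketch (not the actual PEFT code): DoRA's magnitude vector is initialized to the per-output-channel norm of W + scaling * B @ A, so the rescaling factor starts at exactly 1 and the DoRA forward coincides with the plain LoRA forward.

import torch

torch.manual_seed(0)

out_f, in_f, r = 4, 3, 2
W = torch.randn(out_f, in_f)       # frozen base weight
A = torch.randn(r, in_f)           # lora_A
B = torch.randn(out_f, r)          # lora_B (non-zero, as with init_lora_weights=False)
scaling = 1.0
x = torch.randn(5, in_f)

W_combined = W + scaling * (B @ A)
y_lora = x @ W_combined.T                               # plain LoRA forward

# DoRA rescales each output channel by magnitude / norm; the magnitude is
# initialized to the norm of the combined weight, so the factor starts at 1.
weight_norm = W_combined.norm(p=2, dim=1, keepdim=True)
magnitude = weight_norm.clone()                         # DoRA init
y_dora = x @ (magnitude / weight_norm * W_combined).T

print(torch.allclose(y_lora, y_dora))                   # True (up to float rounding)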

There is some numerical loss in A100 so the results are slightly different. I don't know why no numerical loss in 4090....

I'm also not sure. Did you check that the dtype is the same between the A100 and the 4090? Also, are the torch versions identical?

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 2, 2025

Both the 4090 and the A100 use the float32 dtype. The torch version is 2.8.0.dev20250401+cu128 for both.

It is exactly the same script as the one I posted above.

@BenjaminBossan
Member

BenjaminBossan commented Apr 2, 2025

Could you try out how much you need to increase tolerance for the script to pass on A100?

@jiqing-feng
Contributor Author

Could you try out how much you need to increase tolerance for the script to pass on A100?

With atol=0.2, rtol=0.1, the assert passes.

I also checked the max and mean of abs(logits_lora - logits_dora):
max: 0.2207
mean: 0.0235
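For reference, these numbers come from a small addition to the end of the reproduction script (a sketch of the check, not the test code itself):

diff = (logits_lora - logits_dora).abs()
print(diff.max().item(), diff.mean().item())  # ~0.2207 and ~0.0235 on the A100

# On the A100 this only passes with the loosened tolerances:
assert torch.allclose(logits_lora, logits_dora, atol=0.2, rtol=0.1)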

@BenjaminBossan
Member

Hmm, this is quite high. I don't think it makes sense to set these tolerances, as that would make the test almost meaningless. Instead, could we check the hardware and skip on the A100? Maybe it's the architecture (Ampere vs. Ada)?

@jiqing-feng
Contributor Author

Hi @BenjaminBossan. Since I only tested on the A100 and the 4090, I don't know which other GPUs can pass the test. But I know XPU cannot pass it, so I disabled the test for XPU only. That is enough for me. Please review the new changes. Thanks!
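A hypothetical sketch of what such a device-based skip could look like (the actual change in this PR may differ; is_xpu_available here is just an illustrative helper, real test code might rely on accelerate's utilities instead):

import pytest
import torch

def is_xpu_available() -> bool:
    # Illustrative helper, not the helper used in the PEFT test suite.
    return hasattr(torch, "xpu") and torch.xpu.is_available()

@pytest.mark.skipif(
    is_xpu_available(),
    reason="LoRA and DoRA logits differ numerically on XPU",
)
def test_8bit_dora_inference():
    ...  # existing test body unchanged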


@BenjaminBossan BenjaminBossan left a comment (Member)

Thanks for working on this.

@BenjaminBossan BenjaminBossan merged commit 82a2a0b into huggingface:main Apr 3, 2025
14 checks passed
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
Avoid issue with numerical instability.
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
Avoid issue with numerical instability.
@jiqing-feng jiqing-feng deleted the test branch October 9, 2025 01:42