
Conversation

zqiu24 (Contributor) commented Jun 8, 2025

Hi there,

I want to make some changes to OFT. The main problem with OFT is that it is relatively slow and not memory efficient for large models. This commit does not change the training logic but makes several improvements for speed and memory.

Core Improvements:

  • Enhanced OFT layer implementation with the option of using the Cayley-Neumann parametrization (a rough sketch of the idea follows this list)
  • Input-centric implementation to reduce memory and compute
  • Support for different quantization backends: bnb, AWQ, AQLM, HQQ, EETQ, GPTQ, etc.
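
To give a rough idea of the Cayley-Neumann parametrization mentioned above, here is a minimal sketch (an illustration only, not the PR's actual implementation; the function name and its arguments are placeholders, while use_cayley_neumann and num_cayley_neumann_terms are the config options this PR adds):

import torch

def cayley_neumann_rotation(skew_params: torch.Tensor, num_neumann_terms: int = 5) -> torch.Tensor:
    """Build an approximately orthogonal matrix from an unconstrained square matrix.

    Exact Cayley transform: R = (I + Q)(I - Q)^-1 with Q skew-symmetric.
    The Neumann variant replaces the inverse with a truncated series
    (I - Q)^-1 ~= I + Q + Q^2 + ..., avoiding an explicit matrix inverse/solve.
    """
    q = skew_params - skew_params.transpose(-1, -2)  # make Q skew-symmetric
    eye = torch.eye(q.shape[-1], dtype=q.dtype, device=q.device)

    # Truncated Neumann series approximation of (I - Q)^-1
    approx_inv = eye.clone()
    power = eye.clone()
    for _ in range(1, num_neumann_terms):
        power = power @ q
        approx_inv = approx_inv + power

    return (eye + q) @ approx_inv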

I hope you will let me add these improvements to the OFT training code.

Below are some test results from fine-tuning Qwen2.5-7B.

[Figure: oftv2 benchmark results]

Best,

zqiu24 (Contributor Author) commented Jun 11, 2025

@BenjaminBossan Hi, could you please review this? It would be great. Best,

BenjaminBossan (Member) left a comment

Thanks a lot for this big update to OFT. I didn't have time for an in-depth review yet, but I did a high-level review and have some comments; please check.

My biggest concern right now is that the PR is backwards incompatible for the following reasons:

  1. a new module, OFTRotationModule, is being used
  2. a new option, use_cayley_neumann is added and enabled by default
  3. some parameters are renamed

We should instead strive to make the change backwards compatible. This would mean:

  1. Not sure if possible, but initialize OFTRotationModule in a way that the module works the same as previously. Remap state_dict parameters if necessary (see the sketch after this list).
  2. Default to use_cayley_neumann=False.
  3. Don't rename the parameters, or ensure that state_dicts are mapped accordingly.
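
For points 1 and 3, a state_dict key remap at load time could look roughly like the sketch below (illustrative only; the old/new key fragments are assumptions and would have to match whatever names the PR finally uses):

def remap_old_oft_keys(state_dict):
    """Illustrative sketch: rename legacy OFT keys to a new layout."""
    renames = {".oft_r.": ".oft_R."}  # hypothetical old -> new key fragments
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in renames.items():
            new_key = new_key.replace(old, new)
        remapped[new_key] = value
    return remapped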

I think it's fine not to concern ourselves with forwards compatibility. That means it's not necessary for OFT checkpoints created after this PR is merged to be compatible with older PEFT versions.

As for the quantized layers, also thanks a lot for adding all of them. I haven't checked the details yet. Maybe it would be good to move those to a separate, self-contained PR, WDYT?

Apart from this, we'll also have to add some tests for the new parameters and quantized layers, but let's first deal with the points mentioned above.

PS: If the prospect of making these changes backwards compatible is too much, we can also consider creating a new PEFT method like "OFTv2" and leave the current OFT as is.

lora_bias=lora_bias,
)

breakpoint()
Member

Let's remove this.

@@ -0,0 +1,101 @@
# Copyright 2024-present the HuggingFace Inc. team.
Member

Suggested change
# Copyright 2024-present the HuggingFace Inc. team.
# Copyright 2025-present the HuggingFace Inc. team.

Here and below.

zqiu24 (Contributor Author) commented Jun 15, 2025


Dear Benjamin,

Thank you so much for this quick reply and the suggestions. I am actually the author of both OFT and this newer version of OFT (the paper will be released on arXiv once the PR is accepted).

I totally understand the concerns about backward compatibility; here are some explanations for the changes:

I added the module OFTRotationModule and renamed the parameter to oft_R to be more consistent with the naming in LoRA (lora_A and lora_B). I tested locally, and it seems this is also necessary when doing PEFT with QLoRA/QOFT (fine-tuning quantized layers) + FSDP; otherwise there are data type problems in the distributed setting. It also avoids having to define the OFT training operations separately for each kind of quantized layer.

Personally, I hope we can keep a single OFTConfig (instead of adding another OFTConfigV2) to avoid confusion for users, because the algorithm essentially does not change; it is simply much faster and also supports fine-tuning quantized layers. Would it be possible to simply add a check that detects whether a checkpoint is an old one and, if so, tells the user to pip install an older version?

Best,

BenjaminBossan (Member) left a comment

Thanks for offering that perspective; I understand your argument. After some further exploration of the code, it looks like there is no easy way to make the transition backwards compatible. I think the least we should do, however, is to check whether a user tries to load an old OFT checkpoint and raise an error with a helpful message. For this, we could, for instance, add some code after this line:

peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()}

The check could be something like this:

elif config.peft_type == PeftType.OFT:
    # new OFT adapter is backwards incompatible, see #2575
    if any(".oft_s." in key for key in peft_model_state_dict):
        raise ValueError("Trying to load old OFT checkpoint, which is no longer supported. Please install PEFT <= v0.15.2 to load it or train a new OFT adapter.")

Next week, I should hopefully find the opportunity to do some testing with the new OFT method to check if I can replicate the improvements that you mentioned.

Before we can merge, we also need to update some of the tests. For instance, this test no longer works with the new OFT parameters:

peft/tests/test_common_gpu.py

Lines 1882 to 1898 in a27406c

def test_oft_add_new_adapter_does_not_change_device(self, mlp):
    # same as first test, but using OFT
    config = OFTConfig(target_modules=["lin0"])
    model = get_peft_model(mlp, config)
    model = model.to(self.device)
    model.lin0.oft_r.cpu()
    # check that the adapter is indeed on CPU and the base model on GPU
    assert model.lin0.oft_r.default.device.type == "cpu"
    assert model.lin0.base_layer.weight.device.type == self.device
    model.add_adapter("other", config)
    # check that after adding a new adapter, the old adapter is still on CPU
    assert model.lin0.oft_r.default.device.type == "cpu"
    # the rest should be on GPU
    assert model.lin0.base_layer.weight.device.type == self.device
    assert model.lin0.oft_r.other.device.type == self.device

Furthermore, since quantized layers are now supported, let's add GPU tests for them. As an example, here is a LoRA test for GPTQ. You can essentially copy this test and switch the method from LoRA to OFT.
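
For illustration, an OFT + GPTQ test could look roughly like the following sketch (not the repository's actual test; the quantized model id and the target modules are assumptions):

import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import OFTConfig, get_peft_model


@pytest.mark.single_gpu_tests
def test_oft_gptq_quantization():
    # Placeholder GPTQ checkpoint; any small GPTQ-quantized causal LM would do.
    model_id = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    config = OFTConfig(r=0, oft_block_size=32, target_modules=["q_proj", "v_proj"])
    model = get_peft_model(model, config)

    inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
    output = model(**inputs, labels=inputs["input_ids"])
    output.loss.backward()  # the OFT adapter parameters should receive gradients
    assert torch.isfinite(output.loss)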

self.eps = {}
self.block_share = {}
# For Embedding layer
self.oft_embedding_R = nn.ParameterDict({})
Member

Is not used, right?

Contributor Author

Yes, updated

Comment on lines 327 to 330
elif isinstance(base_layer, nn.MultiheadAttention):
    if not base_layer._qkv_same_embed_dim:
        raise ValueError(f"Only same dim for query/key/value is supported as of now for {self.__class__}.")
    in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim
Member

MHA is not supported, so this can be removed, right?

Contributor Author

Yes, updated.

warnings.warn("Unscaling operation for OFT not supported! Keeping scale to 1.")

def update_layer(self, adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights):
def _check_forward_args(self, x, *args, **kwargs):
Member

This method is used in LoRA to check that we can perform mixed adapter predictions (i.e., in the same batch, having some samples with adapter X and other samples with adapter Y). See:

def _mixed_batch_forward(

As this is not supported by OFT, we can remove this method.

Contributor Author

OK, updated

Comment on lines 656 to 657
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)
Member

As mentioned above, since OFT does not support _mixed_batch_forward, these lines can be safely removed.

Contributor Author

OK, updated.

new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
else:
breakpoint()
Member

Remove

Contributor Author

OK, updated

if init_weights is False:
    nn.init.normal_(self.oft_r[adapter_name], mean=0.0, std=0.1)
    nn.init.normal_(self.oft_s[adapter_name], mean=1.0, std=0.1)
    nn.init.normal_(self.oft_R[adapter_name], mean=0.0, std=0.1)
Member

This fails for me because self.oft_R[adapter_name] is an OFTRotationModule; did you mean:

Suggested change
nn.init.normal_(self.oft_R[adapter_name], mean=0.0, std=0.1)
nn.init.normal_(self.oft_R[adapter_name].weight, mean=0.0, std=0.1)

Contributor Author

Updated

zqiu24 (Contributor Author) commented Jun 17, 2025

@BenjaminBossan Thank you so much for the detailed reply and the suggestions. I just pushed a version with the required changes and added tests.

BenjaminBossan (Member) left a comment

Thanks for the updates; there isn't much missing for this PR.

Apart from the smaller comments I left, ideally we should also have one test for each quantization method added in this PR. For this, please check test_gpu_examples.py, where you can see the test classes for each quantization method. Could you please add OFT examples?

I would also like to do some of my own testing, but this will have to wait until next week. I hope this is fine.

@@ -0,0 +1,145 @@
# Copyright 2024-present the HuggingFace Inc. team.
Member

Let's update all the years to 2025.

Contributor Author

OK, updated.

Member

There is now the option to use OFT with use_cayley_neumann=True and use_cayley_neumann=False. Let's ensure that we have at least one test case for each one of those.

Contributor Author

@BenjaminBossan Where do you suggest adding this test? Best,

Member

It should be sufficient to extend the existing OFT test cases here:

########
# OFT #
########
("Vanilla MLP 1 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": "lin0"}),
("Vanilla MLP 2 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"]}),
("Vanilla MLP 5 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
(
    "Vanilla MLP 6 OFT",
    "MLP",
    OFTConfig,
    {
        "r": 2,
        "target_modules": ["lin0"],
        "module_dropout": 0.1,
    },
),
("Vanilla MLP 7 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "coft": True}),
("Vanilla MLP 8 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "block_share": True}),
("Vanilla MLP 9 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "coft": True, "block_share": True}),
("Conv2d 1 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"]}),
("Conv2d 3 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "coft": True}),
("Conv2d 4 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "block_share": True}),
("Conv2d 5 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "coft": True, "block_share": True}),

zqiu24 (Contributor Author) commented Jun 23, 2025

@BenjaminBossan Hi Benjamin, I just updated the tests; it would be great if you could provide some feedback this week. We would really like to publish the paper on arXiv :)

BenjaminBossan (Member) commented Jun 23, 2025

> Hi Benjamin, I just updated the tests; it would be great if you could provide some feedback this week. We would really like to publish the paper on arXiv :)

@zqiu24 Okay, let's try our best to get this merged this week. I tried to run the new tests with:

pytest tests/test_custom_models.py -k "oft and not boft" -v

but I get an error:

tests/test_custom_models.py - ValueError: You can only specify either r (2) or oft_block_size (32), but not both simultaneously, because r x oft_block_size == in_features.

Could you please check if the tests pass for you and if they don't, fix the remaining issues?

I'm also currently running some experiments and I hit an issue when the base model dtype != float32. I think I know the cause, but I'll get back to you once I'm sure I've pinned it down.

PS: I think I found the cause, see my comment below.
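
For context, the error above comes from r and oft_block_size being two mutually exclusive ways of sizing the rotation blocks (r x oft_block_size must equal in_features), so a config should only set one of them. Roughly (a sketch; the target modules are placeholders and the exact validation is defined by the PR):

from peft import OFTConfig

# Specify the block size and let the number of blocks be derived from in_features ...
config_by_block_size = OFTConfig(r=0, oft_block_size=32, target_modules=["q_proj", "v_proj"])

# ... or specify the number of blocks and let the block size be derived.
config_by_r = OFTConfig(r=2, oft_block_size=0, target_modules=["q_proj", "v_proj"])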

x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
bias = self._cast_input_dtype(self.get_base_layer().bias, scaled_rotated_weight.dtype)
result = F.linear(input=x, weight=scaled_rotated_weight, bias=bias)
result = self.base_layer(x, *args, **kwargs)
Member

So this line errors for me when the base layer is bfloat16. The reason is that x has been upcast to float32 due to x = self._cast_input_dtype(x, oft_R.weight.dtype). I could fix the problem by changing this line to:

Suggested change
result = self.base_layer(x, *args, **kwargs)
result = self.base_layer(x.to(previous_dtype), *args, **kwargs)

Contributor Author

Thanks, updated.

zqiu24 (Contributor Author) commented Jun 23, 2025

@BenjaminBossan Thank you so much :) I will make the changes as soon as I get the review. Yes, the above issue is because we changed the defaults.

For some reason, I get the following error when I run the tests locally; it breaks:

INTERNALERROR> File "/home/wliu/anaconda3/envs/peft_pr/lib/python3.12/site-packages/coverage/sqlitedb.py", line 138, in _execute
INTERNALERROR> raise DataError(f"Couldn't use data file {self.filename!r}: {msg}") from exc
INTERNALERROR> coverage.exceptions.DataError: Couldn't use data file '/lustre/fast/fast/wliu/zqiu/peft_pr/peft/.coverage.i102.2091071.XYUWDkvx.c': disk I/O error

So I cannot see the error logs.

BenjaminBossan (Member)

The macOS errors are unrelated.

zqiu24 (Contributor Author) commented Jun 24, 2025

> The macOS errors are unrelated.

@BenjaminBossan So once I determine the best default hyperparameters, is it ready to be merged? Best,

BenjaminBossan (Member)

> So once I determine the best default hyperparameters, is it ready to be merged?

I guess we can merge as is, but normally we would not want to change the default parameters later. Thus, ideally, we have them set before merging. If you have enough time to test better defaults, please go ahead.

zqiu24 (Contributor Author) commented Jun 24, 2025

Yes, thanks. I am just curious about the performance gap. I will test it locally and give the go-ahead from my side. Best,

zqiu24 (Contributor Author) commented Jun 25, 2025

@BenjaminBossan Would you mind testing the current version again to see if the performance is better?

BenjaminBossan (Member) commented:

Some good news: I tested the new OFT implementation on our MetaMathQA benchmark and there are significant memory improvements, going down from 28GB to 22GB. However, the test accuracy also decreased from 48.5% to 44.5% and the train loss was worse, 0.596 vs 0.636. I assume this has to do with the new default values. Would you recommend changing the hyper-parameters to get better results?

> For some reason, I get the following error when I run the tests locally; it breaks:

Hmm, I haven't seen that error. Could you try re-installing pytest and pytest-cov? If that doesn't help, maybe just uninstall pytest-cov and locally remove this config line. We're not interested in test coverage right now, so this would be fine.

self._active_adapter = adapter_name

self.update_layer(adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights)
breakpoint()
Member

Please remove

Contributor Author

Sorry, forgot to delete this.

BenjaminBossan (Member)

> Would you mind testing the current version again to see if the performance is better?

I ran it again with the updated code and otherwise the same settings. The final results are very similar (showing the most relevant metrics):

Before:

    "test accuracy": 0.4450341167551175,
    "train loss": 0.6357718685865402,
    "cuda_memory_reserved_avg": 11804291891,
    "cuda_memory_max": 22189965312,
    "cuda_memory_reserved_99th": 17720976343,
    "train_time": 1967.2921348049422,

After:

    "test accuracy": 0.4473085670962851,
    "train loss": 0.6358395928144455,
    "cuda_memory_reserved_avg": 11783439070,
    "cuda_memory_max": 22198353920,
    "cuda_memory_reserved_99th": 17720934400,
    "train_time": 2134.552857758019,

If you want, you should be able to quickly run the benchmark yourself. You need this file: method_comparison/MetaMathQA/experiments/oft/llama-3.2-3B-oft_block_size-32/adapter_config.json

{
  "auto_mapping": null,
  "base_model_name_or_path": null,
  "bias": "none",
  "block_share": false,
  "coft": false,
  "eps": 6e-05,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": false,
  "init_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "module_dropout": 0.0,
  "modules_to_save": null,
  "num_cayley_neumann_terms": 5,
  "oft_block_size": 32,
  "peft_type": "OFT",
  "r": 0,
  "revision": null,
  "target_modules": null,
  "task_type": null,
  "use_cayley_neumann": false
}

and then run python run.py -v method_comparison/MetaMathQA/experiments/oft/llama-3.2-3B-oft_block_size-32/

(you might need some extra requirements: pip install -r method_comparison/MetaMathQA/requirements.txt)

At the end of the day, we can still merge and figure out if there are better settings later.

zqiu24 (Contributor Author) commented Jun 26, 2025

@BenjaminBossan Thank you for the help and for running the OFT tests, I really appreciate it :)
I tested locally, and I think the performance gap may come from oft_s, which no longer exists in the new version. I cannot think of a way to keep it without significantly slowing down training and increasing memory usage. For other tasks, however, OFT yields competitive performance. It could also be that for this task we need a bigger oft_block_size or a different learning rate to get better results. For now, I think we can keep the PR as it is? Best,

BenjaminBossan (Member)

Okay, then let's deal with potentially better experimental settings later. LMK if the PR is ready for a final review.

zqiu24 (Contributor Author) commented Jun 26, 2025

@BenjaminBossan Please do the final review. Best,

BenjaminBossan (Member) left a comment

Thanks a lot for this update. I found only a few minor issues; they should be quick to fix. Otherwise, this LGTM. I only skimmed the different quantization methods like AQLM, AWQ, etc., as the code is mostly the same.

I think it would be good to add tests for each supported quantization method in test_gpu_examples.py (single GPU tests are enough) to ensure that they work and keep on working, but this can be added in a later PR.

from trl import SFTTrainer
from peft import OFTConfig

if use_quantization:
Member

This snippet uses a mix of 2 and 4 spaces for indentation. Let's use 4 spaces consistently.

Contributor Author

Updated.

@@ -0,0 +1,388 @@
# Copyright 2023-present the HuggingFace Inc. team.
Member

Here and below, the date is sometimes wrong.

Suggested change
# Copyright 2023-present the HuggingFace Inc. team.
# Copyright 2025-present the HuggingFace Inc. team.

Contributor Author

Updated.

# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None

@pytest.mark.single_gpu_tests
Member

Suggested change
@pytest.mark.single_gpu_tests
@pytest.mark.multi_gpu_tests

Contributor Author

Updated.

zqiu24 (Contributor Author) commented Jun 26, 2025

@BenjaminBossan Hi, are there any other changes required from my side? Best,

BenjaminBossan (Member) left a comment

Thanks for the improvement to, and extension of, OFT. The changes LGTM.

@BenjaminBossan BenjaminBossan merged commit d936478 into huggingface:main Jun 26, 2025
10 of 14 checks passed
efraimdahl pushed a commit to efraimdahl/peft that referenced this pull request Jul 12, 2025
Make OFT faster and more memory efficient. This new version of OFT is
not backwards compatible with older checkpoints and vice versa. To load
older checkpoints, downgrade PEFT to 0.15.2 or lower.