Merged
40 commits
cc1e994
update for oftv2
zqiu24 Jun 8, 2025
927d7ad
update oftv2
zqiu24 Jun 8, 2025
c93084d
oftv2 make quality
zqiu24 Jun 8, 2025
45048ae
oftv2 make quality
zqiu24 Jun 8, 2025
3084c47
oftv2 make style
zqiu24 Jun 8, 2025
76d00bc
make quality oftv2
zqiu24 Jun 8, 2025
7e1a51f
passing make style oftv2
zqiu24 Jun 8, 2025
24ebbf4
update oftv2
zqiu24 Jun 8, 2025
2335f9b
update Cayley-Neumann
zqiu24 Jun 10, 2025
298ad82
update oftv2 for pr
zqiu24 Jun 15, 2025
4c7aa93
update oftv2 for pr
zqiu24 Jun 15, 2025
24ffe28
update oftv2
zqiu24 Jun 17, 2025
3851cf0
run make style
zqiu24 Jun 17, 2025
1fdc742
update oft tests
zqiu24 Jun 17, 2025
a941228
add oft gptq tests
zqiu24 Jun 17, 2025
8036666
run make style
zqiu24 Jun 17, 2025
f4779ed
update doc
zqiu24 Jun 19, 2025
467c266
update oft doc
zqiu24 Jun 19, 2025
1667420
update use_cayley_neumann test
zqiu24 Jun 23, 2025
abb17e8
update oft to pass test
zqiu24 Jun 23, 2025
7f3b61f
pass test custom models test
zqiu24 Jun 23, 2025
2bd36a9
update make style
zqiu24 Jun 23, 2025
a23f6ce
update oftv2
zqiu24 Jun 23, 2025
15453cd
update make style
zqiu24 Jun 24, 2025
145ca0d
update with make style
zqiu24 Jun 24, 2025
6dfcb25
resolve failed test
zqiu24 Jun 24, 2025
c3fd629
update bnb oft
zqiu24 Jun 24, 2025
c235be6
update oftv2 bnb
zqiu24 Jun 24, 2025
d9e4b90
update oft config
zqiu24 Jun 25, 2025
ba287d5
fix the argument passing issues for quantized baselayer
zqiu24 Jun 25, 2025
f00afa3
update doc
zqiu24 Jun 25, 2025
7f80f74
update doc
zqiu24 Jun 25, 2025
1859c61
add oft doc example
zqiu24 Jun 25, 2025
7f2792e
update peft
zqiu24 Jun 25, 2025
4b50e0c
update oft config
zqiu24 Jun 26, 2025
14b1745
update oft config
zqiu24 Jun 26, 2025
45a0728
update oft config
zqiu24 Jun 26, 2025
c158f06
update oftv2
zqiu24 Jun 26, 2025
72fe7f6
prepare for oftv2 pr
zqiu24 Jun 26, 2025
8f1d177
update for oftv2 pr
zqiu24 Jun 26, 2025
68 changes: 63 additions & 5 deletions docs/source/conceptual_guides/oft.md
@@ -16,9 +16,9 @@ rendered properly in your Markdown viewer.

# Orthogonal Finetuning (OFT and BOFT)

This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280) and [BOFT](https://huggingface.co/papers/2311.06243), a parameter-efficient fine-tuning technique that utilizes orthogonal matrix to multiplicatively transform the pretrained weight matrices.
This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280), [OFTv2](https://www.arxiv.org/abs/2506.19847), and [BOFT](https://huggingface.co/papers/2311.06243), parameter-efficient fine-tuning techniques that use orthogonal matrices to multiplicatively transform the pretrained weight matrices.

To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix multiplied to the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesnt receive any further adjustments. To produce the final results, both the original and the adapted weights are multiplied togethor.
To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix multiplied with the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive any further adjustments. To produce the final results, both the original and the adapted weights are multiplied together.

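To make the multiplicative update concrete, here is a rough numerical sketch (not the PEFT implementation; all names are illustrative). It contrasts LoRA's additive low-rank update with OFT's multiplication by an orthogonal matrix, applied on the input side as the OFTv2 layers in this PR do (the input is rotated before the frozen layer):

```py
import torch

d, rank = 8, 2
W = torch.randn(d, d)  # frozen pretrained weight (out_features x in_features)

# LoRA-style additive update: W + B @ A with low-rank factors A, B
A, B = torch.randn(rank, d), torch.zeros(d, rank)
W_lora = W + B @ A

# OFT-style multiplicative update: W @ R with R orthogonal (rotating the input space)
skew = torch.randn(d, d)
skew = skew - skew.T                    # skew-symmetric generator
R = torch.matrix_exp(skew)              # one way to obtain an orthogonal matrix
W_oft = W @ R

print(torch.allclose(R @ R.T, torch.eye(d), atol=1e-4))  # True: R is orthogonal
```
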
Orthogonal Butterfly (BOFT) generalizes OFT with Butterfly factorization and further improves its parameter efficiency and finetuning flexibility. In short, OFT can be viewed as a special case of BOFT. Unlike LoRA, which uses additive low-rank weight updates, BOFT uses multiplicative orthogonal weight updates. The comparison is shown below.

@@ -58,13 +58,25 @@ As with other methods supported by PEFT, to fine-tune a model using OFT or BOFT,
4. Train the `PeftModel` as you normally would train the base model.


### OFT-specific parameters

`OFTConfig` allows you to control how OFT is applied to the base model through the following parameters:

- `r`: OFT rank, the number of OFT blocks per injected layer. A **bigger** `r` results in sparser update matrices with **fewer** trainable parameters. **Note**: You can only specify either `r` or `oft_block_size`, not both simultaneously, because `r` × `oft_block_size` = layer dimension. For simplicity, you specify either `r` or `oft_block_size` and the other is inferred. Defaults to `r = 0`; you are advised to set `oft_block_size` instead for better clarity.
- `oft_block_size`: OFT block size across different layers. A **bigger** `oft_block_size` results in denser update matrices with **more** trainable parameters. **Note**: Please choose `oft_block_size` so that the layer's input dimension (`in_features`) is divisible by it, e.g., 4, 8, 16. You can only specify either `r` or `oft_block_size`, not both simultaneously, because `r` × `oft_block_size` = layer dimension. For simplicity, you specify either `r` or `oft_block_size` and the other is inferred. Defaults to `oft_block_size = 32`.
- `use_cayley_neumann`: Specifies whether to use the Cayley-Neumann parameterization (efficient but approximate) or the vanilla Cayley parameterization (exact but computationally expensive because of the matrix inverse); see the sketch after this list. We recommend setting it to `True` for better efficiency, but performance may be slightly worse because of the approximation error. Please test both settings (`True` and `False`) depending on your needs. Default is `False`.
- `module_dropout`: The multiplicative dropout probability; randomly sets OFT blocks to the identity during training, similar to the dropout layer in LoRA.
- `bias`: specify if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"oft_only"`.
- `target_modules`: The modules (for example, attention blocks) to inject the OFT matrices into.
- `modules_to_save`: List of modules apart from OFT matrices to be set as trainable and saved in the final checkpoint. These typically include the model's custom head that is randomly initialized for the fine-tuning task.

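To give a sense of what `use_cayley_neumann` changes, below is a minimal sketch of the two parameterizations, assuming the orthogonal blocks are built with the Cayley transform from a skew-symmetric matrix. The function names, sign conventions, and truncation details are illustrative and may differ from the actual PEFT internals (`num_terms` mirrors the `num_cayley_neumann_terms` option):

```py
import torch

def cayley_exact(Q: torch.Tensor) -> torch.Tensor:
    # Vanilla Cayley transform: R = (I + Q)(I - Q)^-1, exactly orthogonal for skew-symmetric Q,
    # but requires a matrix inverse.
    I = torch.eye(Q.shape[0], dtype=Q.dtype, device=Q.device)
    return (I + Q) @ torch.linalg.inv(I - Q)

def cayley_neumann(Q: torch.Tensor, num_terms: int = 5) -> torch.Tensor:
    # Cayley-Neumann: approximate (I - Q)^-1 with a truncated Neumann series I + Q + Q^2 + ...
    # Cheaper (no inverse), but only approximately orthogonal.
    I = torch.eye(Q.shape[0], dtype=Q.dtype, device=Q.device)
    approx_inv = I.clone()
    Q_power = I.clone()
    for _ in range(num_terms - 1):
        Q_power = Q_power @ Q
        approx_inv = approx_inv + Q_power
    return (I + Q) @ approx_inv

A = torch.randn(8, 8) * 0.01
Q = A - A.T  # skew-symmetric, small norm so the Neumann series converges quickly
R_exact = cayley_exact(Q)
R_approx = cayley_neumann(Q, num_terms=5)
print(torch.dist(R_exact, R_approx))  # small approximation error
```
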
### BOFT-specific parameters

`BOFTConfig` allows you to control how OFT/BOFT is applied to the base model through the following parameters:
`BOFTConfig` allows you to control how BOFT is applied to the base model through the following parameters (a minimal configuration sketch appears after the list):

- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. Smaller block size results in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_size` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
- `boft_block_size`: the BOFT matrix block size across different layers, expressed as an `int`. **Bigger** `boft_block_size` results in denser update matrices with **more** trainable parameters. **Note**: please choose `boft_block_size` so that most layers' input dimensions (`in_features`) are divisible by it, e.g., 4, 8, 16. Also, please only
specify either `boft_block_size` or `boft_block_num`, not both simultaneously, and do not leave both at 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. Fewer blocks result in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_num` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed as an `int`. **Bigger** `boft_block_num` results in sparser update matrices with **fewer** trainable parameters. **Note**: please choose `boft_block_num` so that most layers' input dimensions (`in_features`) are divisible by it, e.g., 4, 8, 16. Also, please only
specify either `boft_block_size` or `boft_block_num`, not both simultaneously, and do not leave both at 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
- `boft_n_butterfly_factor`: the number of butterfly factors. **Note**: for `boft_n_butterfly_factor=1`, BOFT is the same as vanilla OFT; for `boft_n_butterfly_factor=2`, the effective block size of OFT becomes twice as big and the number of blocks becomes half.
- `bias`: specify if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"boft_only"`.
@@ -74,6 +86,52 @@


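By analogy with `OFTConfig`, a minimal `BOFTConfig` sketch based on the parameters above might look as follows. The values and target module names are illustrative and depend on your model architecture:

```py
from peft import BOFTConfig, get_peft_model

# `base_model` is assumed to be a loaded transformers model with attention
# projections named "q_proj"/"v_proj" (adjust for your architecture).
boft_config = BOFTConfig(
    boft_block_size=4,           # boft_block_size * boft_block_num must equal the layer's in_features
    boft_n_butterfly_factor=2,   # 1 recovers vanilla OFT; 2 doubles the effective block size
    target_modules=["q_proj", "v_proj"],
    bias="none",
)
model = get_peft_model(base_model, boft_config)
model.print_trainable_parameters()
```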

## OFT Example Usage

To use OFT for quantized fine-tuning with [TRL](https://github.com/huggingface/trl) for `SFT`, `PPO`, or `DPO`, follow this outline:

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer
from peft import OFTConfig

if use_quantization:
    # Review comment (Member): This snippet uses a mix of 2 and 4 spaces for indentation.
    # Let's use 4 spaces consistently.
    # Reply (Contributor Author): Updated.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.bfloat16,
    )
else:
    bnb_config = None  # no quantization: load the model in full precision

model = AutoModelForCausalLM.from_pretrained(
    "model_name",
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained("model_name")

# Configure OFT
peft_config = OFTConfig(
oft_block_size=32,
use_cayley_neumann=True,
target_modules="all-linear",
bias="none",
task_type="CAUSAL_LM"
)

trainer = SFTTrainer(
model=model,
train_dataset=ds['train'],
peft_config=peft_config,
tokenizer=tokenizer,
args=training_arguments,
data_collator=collator,
)

trainer.train()
```
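
After training, the lightweight OFT adapter can be saved and re-attached to the base model for inference. A rough sketch, reusing the objects from the snippet above with a placeholder adapter path (`"oft-adapter"`):

```py
from peft import PeftModel

trainer.save_model("oft-adapter")  # saves the OFT adapter weights rather than the full base model

base_model = AutoModelForCausalLM.from_pretrained("model_name", quantization_config=bnb_config)
model = PeftModel.from_pretrained(base_model, "oft-adapter")
model.eval()
```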


## BOFT Example Usage

For examples of applying the BOFT method to various downstream tasks, please refer to the following guides:
30 changes: 29 additions & 1 deletion src/peft/tuners/oft/__init__.py
@@ -12,13 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
from peft.utils import register_peft_method

from .config import OFTConfig
from .gptq import GPTQOFTLinear
from .layer import Conv2d, Linear, OFTLayer
from .model import OFTModel


__all__ = ["Conv2d", "Linear", "OFTConfig", "OFTLayer", "OFTModel"]
__all__ = [
"Conv2d",
"GPTQOFTLinear",
"Linear",
"OFTConfig",
"OFTLayer",
"OFTModel",
]

register_peft_method(name="oft", config_cls=OFTConfig, model_cls=OFTModel)


def __getattr__(name):
if (name == "Linear8bitLt") and is_bnb_available():
from .bnb import Linear8bitLt

return Linear8bitLt

if (name == "Linear4bit") and is_bnb_4bit_available():
from .bnb import Linear4bit

return Linear4bit

if (name == "EetqOFTLinear") and is_eetq_available():
from .eetq import EetqOFTLinear

return EetqOFTLinear

raise AttributeError(f"module {__name__} has no attribute {name}")
105 changes: 105 additions & 0 deletions src/peft/tuners/oft/aqlm.py
@@ -0,0 +1,105 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch

from peft.import_utils import is_aqlm_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_aqlm_available():
from aqlm import QuantizedLinear


class AqlmOFTLinear(torch.nn.Module, OFTLayer):
def __init__(
self,
base_layer,
adapter_name: str,
r: int = 0,
oft_block_size: int = 32,
module_dropout: float = 0.0,
init_weights: bool = True,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
):
super().__init__()
OFTLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
init_weights=init_weights,
coft=coft,
eps=eps,
block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
if self.disable_adapters:
return self.base_layer(x)

for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)

x = oft_R(x)

result = self.base_layer(x)
if requires_conversion:
result = result.to(expected_dtype)
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep


def dispatch_aqlm(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None

if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
new_module = AqlmOFTLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.codes

return new_module
119 changes: 119 additions & 0 deletions src/peft/tuners/oft/awq.py
@@ -0,0 +1,119 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata as importlib_metadata
from typing import Any, Optional

import packaging.version
import torch

from peft.import_utils import is_auto_awq_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer


class AwqOFTLinear(torch.nn.Module, OFTLayer):
def __init__(
self,
base_layer,
adapter_name,
r: int = 0,
oft_block_size: int = 32,
module_dropout: float = 0.0,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: bool = True,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
):
super().__init__()
OFTLayer.__init__(self, base_layer)

# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer

self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)

def forward(self, x: torch.Tensor):
if self.disable_adapters:
result = self.quant_linear_module(x)
return result

for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)

x = oft_R(x)
if requires_conversion:
x = x.to(expected_dtype)

result = self.quant_linear_module(x)
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep


def dispatch_awq(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None

if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if is_auto_awq_available():
from awq.modules.linear import WQLinear_GEMM

if isinstance(target_base_layer, WQLinear_GEMM):
# Raise the error only at the dispatch level
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))

if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
raise ImportError(
f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
)

new_module = AwqOFTLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight

return new_module