OFT: several improvements to make OFT faster and more memory efficient #2575
Merged

Commits
All 40 commits are by zqiu24:

- cc1e994 update for oftv2
- 927d7ad update oftv2
- c93084d oftv2 make quality
- 45048ae oftv2 make quality
- 3084c47 oftv2 make style
- 76d00bc make quality oftv2
- 7e1a51f passing make style oftv2
- 24ebbf4 update oftv2
- 2335f9b update Cayley-Neumann
- 298ad82 update oftv2 for pr
- 4c7aa93 update oftv2 for pr
- 24ffe28 update oftv2
- 3851cf0 run make style
- 1fdc742 update oft tests
- a941228 add oft gptq tests
- 8036666 run make style
- f4779ed update doc
- 467c266 update oft doc
- 1667420 update use_cayley_neumann test
- abb17e8 update oft to pass test
- 7f3b61f pass test custom models test
- 2bd36a9 update make style
- a23f6ce update oftv2
- 15453cd update make style
- 145ca0d update with make style
- 6dfcb25 resolve failed test
- c3fd629 update bnb oft
- c235be6 update oftv2 bnb
- d9e4b90 update oft config
- ba287d5 fix the argument passing issues for quantized baselayer
- f00afa3 update doc
- 7f80f74 update doc
- 1859c61 add oft doc example
- 7f2792e update peft
- 4b50e0c update oft config
- 14b1745 update oft config
- 45a0728 update oft config
- c158f06 update oftv2
- 72fe7f6 prepare for oftv2 pr
- 8f1d177 update for oftv2 pr
Files changed

New file (+105 lines): OFT support for AQLM-quantized linear layers (AqlmOFTLinear, dispatch_aqlm).
```python
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import torch

from peft.import_utils import is_aqlm_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_aqlm_available():
    from aqlm import QuantizedLinear


class AqlmOFTLinear(torch.nn.Module, OFTLayer):
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        oft_block_size: int = 32,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        coft: bool = False,
        eps: float = 6e-5,
        block_share: bool = False,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        use_cayley_neumann: bool = False,
        num_cayley_neumann_terms: int = 5,
        **kwargs,
    ):
        super().__init__()
        OFTLayer.__init__(self, base_layer)

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            oft_block_size=oft_block_size,
            module_dropout=module_dropout,
            init_weights=init_weights,
            coft=coft,
            eps=eps,
            block_share=block_share,
            use_cayley_neumann=use_cayley_neumann,
            num_cayley_neumann_terms=num_cayley_neumann_terms,
        )

    def forward(self, x: torch.Tensor):
        # note: logic differs from default Linear because merging is not supported
        if self.disable_adapters:
            return self.base_layer(x)

        for active_adapter in self.active_adapters:
            if active_adapter not in self.oft_R.keys():
                continue
            oft_R = self.oft_R[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = x.dtype
                x = self._cast_input_dtype(x, oft_R.weight.dtype)

            x = oft_R(x)

        result = self.base_layer(x)
        if requires_conversion:
            result = result.to(expected_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "oft." + rep


def dispatch_aqlm(
    target: torch.nn.Module,
    adapter_name: str,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
        new_module = AqlmOFTLinear(target, adapter_name, **kwargs)
        target.qweight = target_base_layer.codes

    return new_module
```
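The `use_cayley_neumann` and `num_cayley_neumann_terms` arguments forwarded to `update_layer` above control how the orthogonal rotation `oft_R` is parameterized, and they appear to be central to the speed and memory gains named in the PR title. OFT builds each rotation block from a skew-symmetric matrix Q via the Cayley transform R = (I + Q)(I - Q)^-1; with `use_cayley_neumann=True`, the matrix inverse is replaced by a truncated Neumann series, trading a small orthogonality error for a much cheaper computation. Below is a minimal, self-contained illustration of that approximation; the function names are invented for this sketch and it is not PEFT's exact implementation.

```python
import torch


def cayley_exact(Q: torch.Tensor) -> torch.Tensor:
    # Exact Cayley transform: R = (I + Q)(I - Q)^-1.
    # R is orthogonal whenever Q is skew-symmetric, but computing the
    # inverse costs O(n^3) and needs extra memory for the factorization.
    eye = torch.eye(Q.shape[0], dtype=Q.dtype, device=Q.device)
    return (eye + Q) @ torch.linalg.inv(eye - Q)


def cayley_neumann(Q: torch.Tensor, num_terms: int = 5) -> torch.Tensor:
    # Truncated Neumann series: (I - Q)^-1 ~= I + Q + Q^2 + ... + Q^(k-1),
    # valid when ||Q|| < 1. Only matrix multiplies, no inverse.
    eye = torch.eye(Q.shape[0], dtype=Q.dtype, device=Q.device)
    approx_inv = eye.clone()
    power = eye.clone()
    for _ in range(num_terms - 1):
        power = power @ Q
        approx_inv = approx_inv + power
    return (eye + Q) @ approx_inv


# Small skew-symmetric Q, as after a near-identity adapter initialization.
skew = torch.randn(8, 8) * 0.05
Q = skew - skew.T

R = cayley_neumann(Q, num_terms=5)
print(torch.dist(R, cayley_exact(Q)))      # small approximation error
print(torch.dist(R.T @ R, torch.eye(8)))   # R is close to orthogonal
```

More `num_cayley_neumann_terms` gives a closer-to-orthogonal R at slightly higher cost; the approximation is accurate as long as Q stays small, which holds when the adapter is initialized near the identity.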
New file (+119 lines): OFT support for AWQ-quantized linear layers (AwqOFTLinear, dispatch_awq).
```python
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata as importlib_metadata
from typing import Any, Optional

import packaging.version
import torch

from peft.import_utils import is_auto_awq_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer


class AwqOFTLinear(torch.nn.Module, OFTLayer):
    def __init__(
        self,
        base_layer,
        adapter_name,
        r: int = 0,
        oft_block_size: int = 32,
        module_dropout: float = 0.0,
        coft: bool = False,
        eps: float = 6e-5,
        block_share: bool = False,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        init_weights: bool = True,
        use_cayley_neumann: bool = False,
        num_cayley_neumann_terms: int = 5,
        **kwargs,
    ):
        super().__init__()
        OFTLayer.__init__(self, base_layer)

        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
        # for backwards compatibility
        self.quant_linear_module = base_layer

        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            oft_block_size=oft_block_size,
            module_dropout=module_dropout,
            coft=coft,
            eps=eps,
            block_share=block_share,
            init_weights=init_weights,
            use_cayley_neumann=use_cayley_neumann,
            num_cayley_neumann_terms=num_cayley_neumann_terms,
        )

    def forward(self, x: torch.Tensor):
        if self.disable_adapters:
            result = self.quant_linear_module(x)
            return result

        for active_adapter in self.active_adapters:
            if active_adapter not in self.oft_R.keys():
                continue
            oft_R = self.oft_R[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = x.dtype
                x = self._cast_input_dtype(x, oft_R.weight.dtype)

            x = oft_R(x)
            if requires_conversion:
                x = x.to(expected_dtype)

        result = self.quant_linear_module(x)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "oft." + rep


def dispatch_awq(
    target: torch.nn.Module,
    adapter_name: str,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_auto_awq_available():
        from awq.modules.linear import WQLinear_GEMM

        if isinstance(target_base_layer, WQLinear_GEMM):
            # Raise the error only at the dispatch level
            AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
            version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))

            if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
                raise ImportError(
                    f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
                    f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
                )

            new_module = AwqOFTLinear(target, adapter_name, **kwargs)
            target.qweight = target_base_layer.qweight

    return new_module
```
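Neither dispatcher is called by user code directly: PEFT's tuner machinery tries each dispatch function when replacing a target module, and the first one that returns a non-None module wins. Because the quantized weights cannot be merged, both wrappers rotate the input with `oft_R` before the frozen quantized layer rather than modifying the weights. A minimal usage sketch follows, assuming a checkpoint whose linear layers are already AQLM- or AWQ-quantized; the model id and `target_modules` values are placeholders, not taken from this PR.

```python
from peft import OFTConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder checkpoint: any model whose linear layers are AQLM- or AWQ-quantized.
model = AutoModelForCausalLM.from_pretrained("some-org/some-aqlm-quantized-model")

config = OFTConfig(
    oft_block_size=32,                    # size of each block-diagonal rotation (set this or r, not both)
    use_cayley_neumann=True,              # Neumann approximation of the Cayley transform
    num_cayley_neumann_terms=5,
    target_modules=["q_proj", "v_proj"],  # assumption: typical attention projections
)

# get_peft_model walks the target modules; for AQLM/AWQ base layers,
# dispatch_aqlm / dispatch_awq return the matching OFT wrapper.
model = get_peft_model(model, config)
model.print_trainable_parameters()
```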
Review comments

Reviewer: This snippet uses a mix of 2 and 4 spaces for indentation. Let's use 4 spaces consistently.

Reply: Updated.