Adding Lora implementation for nn.Conv1d #2333
Conversation
BenjaminBossan left a comment
Thanks for adding a Conv1d implementation for LoRA. In general, this looks good; I have a few small comments, please check. Please also run `make style` to satisfy the linter.
Before merging, however, let's ensure that the code works correctly by adding some tests. We already have a "test factory" for the different LoRA layer types, so this is just a matter of adding an entry for Conv1d. To do this, look at this code:
peft/tests/test_custom_models.py
Lines 877 to 931 in aa3f41f
```python
class ModelMha(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(10, 2)
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float()
        X, _ = self.mha(X, X, X)
        X = self.lin0(X)
        X = self.sm(X)
        return X


class MockTransformerWrapper:
    """Mock class to behave like a transformers model.

    This is needed because the tests initialize the model by calling transformers_class.from_pretrained.
    """

    @classmethod
    def from_pretrained(cls, model_id, torch_dtype=None):
        # set the seed so that from_pretrained always returns the same model
        torch.manual_seed(0)

        if torch_dtype is None:
            torch_dtype = torch.float32

        if model_id == "MLP":
            return MLP().to(torch_dtype)

        if model_id == "EmbConv1D":
            return ModelEmbConv1D().to(torch_dtype)

        if model_id == "Conv2d":
            return ModelConv2D().to(torch_dtype)

        if model_id == "Conv3d":
            return ModelConv3D().to(torch_dtype)

        if model_id == "MLP_LayerNorm":
            return MLP_LayerNorm().to(torch_dtype)

        if model_id == "MLP2":
            return MLP2().to(torch_dtype)

        if model_id == "Conv2d2":
            return ModelConv2D2().to(torch_dtype)

        if model_id == "MHA":
            return ModelMha().to(torch_dtype)

        raise ValueError(f"model_id {model_id} not implemented")
```
What we need is to add a model similar to ModelMha but using Conv1d instead; the input feature size should be 10. The from_pretrained method should also be updated to dispatch to this new model, as sketched below.
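Something along these lines could work (a minimal sketch; the ModelConv1D name, the "Conv1d" model_id, and the exact layer sizes are just for illustration, not existing code):

```python
class ModelConv1D(nn.Module):
    def __init__(self):
        super().__init__()
        # single-channel Conv1d; with input length 10 and kernel_size=2,
        # the output length is 9
        self.conv1d = nn.Conv1d(1, 1, 2)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(9, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float().reshape(-1, 1, 10)  # (batch, channels, length)
        X = self.conv1d(X)
        X = self.relu(X)
        X = self.flat(X)
        X = self.lin0(X)
        X = self.sm(X)
        return X
```

In MockTransformerWrapper.from_pretrained, the dispatch would then gain one more branch:

```python
        if model_id == "Conv1d":
            return ModelConv1D().to(torch_dtype)
```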
After this, it's only a matter of adding a row to the test cases, following this format:
peft/tests/test_custom_models.py
Lines 106 to 113 in aa3f41f
| ("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}), | |
| ("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}), | |
| ("Conv2d 1 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "use_dora": True}), | |
| ("Conv2d 2 LoRA with DoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"], "use_dora": True}), | |
| ("Conv3d 1 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"]}), | |
| ("Conv3d 2 LoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d", "lin0"]}), | |
| ("Conv3d 1 LoRA with DoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "use_dora": True}), | |
| ("Conv3d 2 LoRA with DoRA", "Conv3d", LoraConfig, {"target_modules": ["conv3d", "lin0"], "use_dora": True}), |
I hope this makes sense. LMK if you have questions.
I think the most likely explanation is that you were using a different ruff version from the one used on CI, which would explain why CI still fails. Could you please ensure that the same version is used: ruff-0.6.9?
@BenjaminBossan Yep, 0.6.9 works much better |
BenjaminBossan left a comment
Thanks a lot for the updates, the PR LGTM. I also tested it with Hubert to have a more realistic test, and it worked there too (with the exception of the groups argument, but support for that has yet to be added for all conv layers).
Before merging, however, I just noticed one small change that is still needed, namely this error message, which lists all supported layer types for LoRA:
peft/src/peft/tuners/lora/model.py
Lines 347 to 351 in bbb1128
```python
raise ValueError(
    f"Target module {target} is not supported. Currently, only the following modules are supported: "
    "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
    "`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention.`."
)
```
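The fix would simply add `torch.nn.Conv1d` to the list, along the lines of this sketch (exact wording is up to you; this also drops the stray period after MultiheadAttention):

```python
raise ValueError(
    f"Target module {target} is not supported. Currently, only the following modules are supported: "
    "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv1d`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
    "`transformers.pytorch_utils.Conv1D`, `torch.nn.MultiheadAttention`."
)
```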
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
BenjaminBossan left a comment
Thanks for adding LoRA support for Conv1d, LGTM.
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Resolves #2241
My comment shows that the shapes match in the Enformer model: #2241 (comment)
I'm unsure how to test it further other than to run it in some training loop.
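For what it's worth, a quick shape-preservation smoke test could look like the sketch below (assuming a PEFT build that includes this PR; TinyConvModel is a made-up name for illustration):

```python
import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class TinyConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(4, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv1d(x)

model = get_peft_model(TinyConvModel(), LoraConfig(target_modules=["conv1d"], r=8))
x = torch.randn(2, 4, 16)  # (batch, channels, length)
assert model(x).shape == (2, 8, 16)  # LoRA wrapping must not change the output shape
```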