Commit

Prevent merging/unmerging LoRA weights into quantized weights (#2399)
vince62s authored Jun 7, 2023
1 parent 2956074 commit c858395
Showing 1 changed file with 21 additions and 16 deletions.
onmt/modules/lora.py (37 changes: 21 additions & 16 deletions)
@@ -149,6 +149,7 @@ def __init__(self, *args, **kwargs):
             )
         else:
             super(QLoraLinear_cls, self).__init__(*args, bias=bias)
+        self.quant_type = quant_type
         LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)
         # Actual trainable parameters
         if r > 0:
@@ -171,23 +172,27 @@ def reset_parameters(self):
         nn.init.zeros_(self.lora_B)
 
     def train(self, mode: bool = True):
-        super().train(mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                # Make sure that the weights are not merged
-                if self.r > 0:
-                    self.weight.data -= (
-                        self.lora_B @ self.lora_A
-                    ) * self.scaling
-                self.merged = False
+        if self.quant_type is None:
+            super().train(mode)
+            if mode:
+                if self.merge_weights and self.merged:
+                    # Make sure that the weights are not merged
+                    if self.r > 0:
+                        self.weight.data -= (
+                            self.lora_B @ self.lora_A
+                        ) * self.scaling
+                    self.merged = False
+            else:
+                if self.merge_weights and not self.merged:
+                    # Merge the weights and mark it
+                    if self.r > 0:
+                        self.weight.data += (
+                            self.lora_B @ self.lora_A
+                        ) * self.scaling
+                    self.merged = True
         else:
-            if self.merge_weights and not self.merged:
-                # Merge the weights and mark it
-                if self.r > 0:
-                    self.weight.data += (
-                        self.lora_B @ self.lora_A
-                    ) * self.scaling
-                self.merged = True
+            # cannot merge/unmerge quantized weights with unquantized lora_X
+            pass
 
     def forward(self, x: torch.Tensor):
         result = self.maybe_ckpt(super().forward, x)
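For readers skimming the diff, here is a minimal, self-contained sketch of the behaviour this commit enforces. It is plain PyTorch, not OpenNMT-py's actual QLoraLinear; the class name, constructor signature, and the "nf4" label are illustrative assumptions. The point it demonstrates: the train()/eval() merge and unmerge of the low-rank update lora_B @ lora_A * scaling only makes sense when the base weight is a regular floating-point tensor, so it is skipped whenever quant_type is set.

import torch
import torch.nn as nn


class ToyLoRALinear(nn.Linear):
    """Illustrative only: mirrors the quant_type guard, not OpenNMT-py's class."""

    def __init__(self, in_features, out_features, r=4, lora_alpha=8,
                 quant_type=None, merge_weights=True):
        super().__init__(in_features, out_features, bias=False)
        self.r = r
        self.scaling = lora_alpha / r
        # quant_type is just a label here; a real layer would hold e.g. 4-bit
        # bitsandbytes parameters instead of an fp32 nn.Linear weight.
        self.quant_type = quant_type
        self.merge_weights = merge_weights
        self.merged = False
        # Low-rank factors stay in full precision and are the trainable parts.
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        self.weight.requires_grad = False

    def train(self, mode: bool = True):
        super().train(mode)
        if self.quant_type is not None:
            # A quantized base weight and an fp16/fp32 lora_B @ lora_A use
            # different representations, so never fold one into the other.
            return self
        if mode and self.merge_weights and self.merged:
            # Back to training: pull the LoRA update out of the base weight.
            self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
            self.merged = False
        elif not mode and self.merge_weights and not self.merged:
            # Eval: fold the LoRA update into the base weight.
            self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x)
        if self.r > 0 and not self.merged:
            out = out + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return out


# Float layer: eval() merges, train() unmerges.
layer = ToyLoRALinear(16, 16)
layer.eval()
assert layer.merged
layer.train()
assert not layer.merged

# "Quantized" layer: merging is skipped entirely, matching the new guard.
qlayer = ToyLoRALinear(16, 16, quant_type="nf4")
qlayer.eval()
assert not qlayer.merged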
