-
Notifications
You must be signed in to change notification settings - Fork 499
/
merge_peft_adapter.py
121 lines (109 loc) · 4.2 KB
/
merge_peft_adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description:
Usage:
python merge_peft_adapter.py \
--model_type llama \
--base_model path/to/llama/model \
--tokenizer_path path/to/llama/tokenizer \
--lora_model path/to/lora/model \
--output_dir path/to/output/dir
"""
import argparse
import torch
from peft import PeftModel, PeftConfig
from transformers import (
AutoModel,
AutoTokenizer,
BloomForCausalLM,
BloomTokenizerFast,
AutoModelForCausalLM,
LlamaForCausalLM,
AutoModelForSequenceClassification,
)
# Lookup table from the --model_type CLI value to the pair
# (model class, tokenizer class) that main() uses to load the base model
# and its tokenizer.
MODEL_CLASSES = dict(
    bloom=(BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoModel, AutoTokenizer),
    llama=(LlamaForCausalLM, AutoTokenizer),
    baichuan=(AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoModelForCausalLM, AutoTokenizer),
)
def _parse_args(argv=None):
    """Parse the command-line options of the merge script.

    Args:
        argv: Optional list of argument strings; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # Restricting --model_type to the known keys turns a typo into a clear
    # usage error instead of a KeyError raised only after the adapter config
    # has already been loaded.
    parser.add_argument('--model_type', default=None, type=str, required=True,
                        choices=list(MODEL_CLASSES),
                        help="Model family; selects the model/tokenizer classes.")
    parser.add_argument('--base_model', default=None, required=True, type=str,
                        help="Base model name or path")
    parser.add_argument('--tokenizer_path', default=None, type=str,
                        help="Please specify tokenization path.")
    parser.add_argument('--lora_model', default=None, required=True, type=str,
                        help="Please specify LoRA model to be merged.")
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    parser.add_argument('--output_dir', default='./merged', type=str)
    parser.add_argument('--hf_hub_model_id', default='', type=str)
    parser.add_argument('--hf_hub_token', default=None, type=str)
    return parser.parse_args(argv)


def _load_base_model(model_type, model_class, base_model_path, task_type):
    """Load the base model that the LoRA adapter will be merged into.

    Args:
        model_type: The --model_type key (used to reject chatglm for SEQ_CLS).
        model_class: Model class from MODEL_CLASSES, used for non-SEQ_CLS adapters.
        base_model_path: Name or path of the base model.
        task_type: ``task_type`` of the adapter's PeftConfig.

    Returns:
        The loaded base model.

    Raises:
        ValueError: if a SEQ_CLS adapter is requested for chatglm.
    """
    # NOTE: trust_remote_code=True lets the model repository execute its own
    # Python code; only point this script at repositories you trust.
    if task_type == "SEQ_CLS":
        print("Loading LoRA for sequence classification model")
        if model_type == "chatglm":
            raise ValueError("chatglm does not support sequence classification")
        # Single-output head loaded in float32. `load_in_8bit=False` was
        # dropped: it is the default, and quantization is configured through
        # `quantization_config` in current transformers.
        return AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            num_labels=1,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            device_map="auto",
        )
    print("Loading LoRA for causal language model")
    return model_class.from_pretrained(
        base_model_path,
        torch_dtype='auto',
        trust_remote_code=True,
        device_map="auto",
    )


def main():
    """Merge a LoRA adapter into its base model and save the merged model.

    Steps:
      1. Load the base model (sequence-classification or causal LM,
         depending on the adapter's task type) and the tokenizer.
      2. Optionally (--resize_emb) resize the input embeddings to the
         tokenizer's vocabulary size.
      3. Apply the adapter with ``PeftModel.from_pretrained`` and fold it
         into the base weights with ``merge_and_unload``.
      4. Save the tokenizer and merged model to --output_dir, and push both
         to the Hugging Face Hub when --hf_hub_model_id is given.
    """
    args = _parse_args()
    print(args)
    base_model_path = args.base_model
    lora_model_path = args.lora_model
    output_dir = args.output_dir
    print(f"Base model: {base_model_path}")
    print(f"LoRA model: {lora_model_path}")

    peft_config = PeftConfig.from_pretrained(lora_model_path)
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    base_model = _load_base_model(
        args.model_type, model_class, base_model_path, peft_config.task_type
    )

    # Fall back to the base model location when no tokenizer path is given.
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_path or base_model_path, trust_remote_code=True
    )

    if args.resize_emb:
        # Presumably for adapters trained with an extended vocabulary, so that
        # embedding shapes match the adapter weights — confirm per adapter.
        base_model_token_size = base_model.get_input_embeddings().weight.size(0)
        if base_model_token_size != len(tokenizer):
            base_model.resize_token_embeddings(len(tokenizer))
            print(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")

    new_model = PeftModel.from_pretrained(
        base_model,
        lora_model_path,
        device_map="auto",
        torch_dtype='auto',
    )
    new_model.eval()
    print("Merging with merge_and_unload...")
    # merge_and_unload returns the base model with the adapter weights folded in.
    base_model = new_model.merge_and_unload()

    print("Saving to Hugging Face format...")
    tokenizer.save_pretrained(output_dir)
    base_model.save_pretrained(output_dir, max_shard_size='10GB')
    print(f"Done! model saved to {output_dir}")

    if args.hf_hub_model_id:
        print("Pushing to Hugging Face Hub...")
        base_model.push_to_hub(
            args.hf_hub_model_id,
            token=args.hf_hub_token,
            max_shard_size="10GB",
        )
        tokenizer.push_to_hub(
            args.hf_hub_model_id,
            token=args.hf_hub_token,
        )
        print(f"Done! model pushed to Hugging Face Hub: {args.hf_hub_model_id}")


if __name__ == '__main__':
    main()