Hello! First of all, kudos on the excellent work. However, I ran into some issues when benchmarking MoBA's time-to-first-token (TTFT) and decoding speed.
I ran a quick test of MoBA's TTFT and decoding speed on an L40S, with Llama-3.1-8B-Instruct as the base model (using the llama.py script).
Llama-3.1-8B-Instruct with an ~8k-token input:

| | original model | flash attention 2 | moba | moba naive |
|---|---|---|---|---|
| TTFT | 1.444 s | 1.406 s | 1.687 s | OOM |
| decode speed | 19.9 token/s | 31.7 token/s | 31.7 token/s | OOM |
Llama-3.1-8B-Instruct with a ~4k-token input:

| | original model | flash attention 2 | moba | moba naive |
|---|---|---|---|---|
| TTFT | 0.659 s | 0.648 s | 0.766 s | 4.289 s |
| decode speed | 28.0 token/s | 34.5 token/s | 34.5 token/s | 34.5 token/s |
From these numbers, MoBA's TTFT is actually a bit slower than flash attention 2, and the decoding speed is identical. Shouldn't MoBA be at least somewhat faster than flash attention 2?
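For scale, here is a rough back-of-the-envelope FLOP estimate (a sketch only, using Llama-3.1-8B's published dimensions and ignoring MoBA's own gating cost) of how much of the prefill compute full attention accounts for at 8k context, which bounds how much any sparse-attention method could cut TTFT at this length:

```python
# Rough FLOP estimate of one prefill pass for Llama-3.1-8B at 8k context
# (a back-of-the-envelope sketch using the published config; the 2x factor
# counts multiply+add, and causal masking halves the average attention span).
d_model, d_kv, d_ffn, n_layers, S = 4096, 1024, 14336, 32, 8192

proj = 2 * (2 * d_model * d_model + 2 * d_model * d_kv)  # Q, O, K, V projections, per token
mlp = 2 * 3 * d_model * d_ffn                            # gate/up/down projections, per token
attn = 2 * 2 * (S / 2) * d_model                         # QK^T and attn@V, per token, causal average

attn_total = n_layers * S * attn
other_total = n_layers * S * (proj + mlp)
print(f"attention share of prefill FLOPs at 8k: {attn_total / (attn_total + other_total):.1%}")
# ~13%: even if sparse attention removed these FLOPs entirely, TTFT could only
# shrink by roughly that fraction, before MoBA's block-selection overhead is paid.
```

If this estimate is in the right ballpark, the headroom for a TTFT win at 8k is small, and it should only open up at much longer contexts where the quadratic attention term dominates.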
My code:
```python
import argparse
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# import path as used in the MoBA repo examples; adjust if your layout differs
from moba import register_moba, MoBAConfig

# argument parsing reconstructed from the attributes used below
parser = argparse.ArgumentParser()
parser.add_argument("--model")
parser.add_argument("--attn")
parser.add_argument("--moba-chunk-size", type=int)
parser.add_argument("--moba-topk", type=int)
args = parser.parse_args()

register_moba(MoBAConfig(args.moba_chunk_size, args.moba_topk))
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation=args.attn,
)
tknz = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
streamer = None

with open("8k.txt", "r") as f:
    user_prompt = f.read().strip()
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": user_prompt},
]
prompt = tknz.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)

cnt = 5
TTFTs, TPTs = [], []
for i in range(cnt):
    # time to first token (TTFT): prefill plus one generated token
    start_time = time.time()
    input_tokens = tknz.encode(prompt)
    input_ids = torch.tensor([input_tokens], device=model.device)
    tokens = model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
    end_time = time.time()
    TTFT = end_time - start_time

    # decode speed: generate 200 tokens, subtract TTFT to approximate decode-only time
    start_time = time.time()
    input_tokens = tknz.encode(prompt)
    input_ids = torch.tensor([input_tokens], device=model.device)
    tokens = model.generate(input_ids, max_new_tokens=200, do_sample=False, streamer=streamer)
    end_time = time.time()
    TPT = (len(tokens.squeeze().tolist()) - input_ids.size(1) - 1) / (end_time - start_time - TTFT)

    print("TTFT: {:.3f}s".format(TTFT))
    print("decode speed: {:.1f} token/s".format(TPT))
    TTFTs.append(TTFT)
    TPTs.append(TPT)

print("average TTFT: {:.3f}s".format(sum(TTFTs) / cnt))
print("average decode speed: {:.1f} token/s".format(sum(TPTs) / cnt))
```
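One possible confounder in these numbers: the first generate call also pays one-time setup costs, and CUDA launches are asynchronous, so unsynchronized wall-clock timing can be noisy. Below is a minimal hardening of the timing loop, reusing `model` and `input_ids` from the script above (a sketch only, assuming a single CUDA device):

```python
# Add an untimed warm-up pass and torch.cuda.synchronize() around each timed
# region so one-time setup and asynchronous kernel launches don't skew TTFT.
import time
import torch

def timed_generate(model, input_ids, max_new_tokens):
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)
    torch.cuda.synchronize()
    return out, time.time() - start

# untimed warm-up so first-run setup cost is excluded from the first sample
model.generate(input_ids, max_new_tokens=1, do_sample=False)

_, ttft = timed_generate(model, input_ids, max_new_tokens=1)
out, total = timed_generate(model, input_ids, max_new_tokens=200)
decode_speed = (out.shape[1] - input_ids.shape[1] - 1) / (total - ttft)
print(f"TTFT: {ttft:.3f}s, decode speed: {decode_speed:.1f} token/s")
```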