
Questions about time-to-first-token and decoding speed #9

@rainstorm12


Hi! First of all, kudos for the excellent work, but I ran into some issues when testing MoBA's time-to-first-token (TTFT) and decoding speed.

I ran a quick test of MoBA's TTFT and decode speed on an L40S GPU, with llama3.1-8b-instruct as the base model (using the llama.py script).

Llama-3.1-8B-Instruct, ~8k-token prompt:

|              | original model | flash attention 2 | moba         | moba naive |
|--------------|----------------|-------------------|--------------|------------|
| TTFT         | 1.444s         | 1.406s            | 1.687s       | OOM        |
| Decode speed | 19.9 token/s   | 31.7 token/s      | 31.7 token/s | OOM        |

Llama-3.1-8B-Instruct, ~4k-token prompt:

|              | original model | flash attention 2 | moba         | moba naive   |
|--------------|----------------|-------------------|--------------|--------------|
| TTFT         | 0.659s         | 0.648s            | 0.766s       | 4.289s       |
| Decode speed | 28.0 token/s   | 34.5 token/s      | 34.5 token/s | 34.5 token/s |

From these numbers, MoBA's TTFT is actually a bit slower than flash attention 2, and the decode speed is identical. Shouldn't MoBA be somewhat faster than flash attention 2?
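
For context on what I expected, here is a rough back-of-envelope estimate of the fraction of keys each query still attends to under MoBA; the chunk size and top-k below are illustrative guesses, not necessarily my exact settings:

    # Rough approximation: each query attends to about min(topk, num_chunks)
    # of the num_chunks key blocks, so the fraction of keys it sees is
    # min(topk, num_chunks) / num_chunks. Values below are illustrative only.
    def approx_key_fraction(seq_len, chunk_size, topk):
        num_chunks = max(seq_len // chunk_size, 1)
        return min(topk, num_chunks) / num_chunks

    for seq_len in (4096, 8192):
        frac = approx_key_fraction(seq_len, chunk_size=2048, topk=3)
        print(f"seq_len={seq_len}: ~{frac:.0%} of keys attended")

If that approximation is roughly right, the sparsity at 4k-8k may simply be too mild to show up in TTFT, but I would like to confirm whether that explains the numbers above.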

My code:

    import argparse
    import time

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from moba import register_moba, MoBAConfig

    # Argument parsing reconstructed from the usage below (args.model,
    # args.moba_chunk_size, args.moba_topk, args.attn); defaults are placeholders.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--moba-chunk-size", type=int, default=4096)
    parser.add_argument("--moba-topk", type=int, default=3)
    parser.add_argument("--attn", type=str, default="moba")
    args = parser.parse_args()

    # Register MoBA, then load the model with the requested attention backend.
    register_moba(MoBAConfig(args.moba_chunk_size, args.moba_topk))
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.float16,
        attn_implementation=args.attn,
    )
    tknz = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    streamer = None  # no streaming; generate() returns the full sequence

    # Build an ~8k-token chat prompt from a local text file.
    with open('8k.txt', 'r') as f:
        user_prompt = f.read().strip()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tknz.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(prompt)

    cnt = 5
    TTFTs, TPTs = [], []
    for i in range(cnt):
        # TTFT: prefill plus a single decoded token.
        start_time = time.time()
        input_tokens = tknz.encode(prompt)
        input_ids = torch.tensor([input_tokens], device=model.device)
        tokens = model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
        end_time = time.time()
        TTFT = end_time - start_time

        # Decode speed: generate 200 tokens, subtract the TTFT measured above
        # to remove the prefill cost, and divide by the remaining 199 tokens.
        start_time = time.time()
        input_tokens = tknz.encode(prompt)
        input_ids = torch.tensor([input_tokens], device=model.device)
        tokens = model.generate(input_ids, max_new_tokens=200, do_sample=False, streamer=streamer)
        end_time = time.time()
        TPT = (len(tokens.squeeze().tolist()) - input_ids.size(1) - 1) / (end_time - start_time - TTFT)

        print("TTFT: {:.3f}s".format(TTFT))
        print("Decode speed: {:.1f} token/s".format(TPT))
        TTFTs.append(TTFT)
        TPTs.append(TPT)

    print("Average TTFT: {:.3f}s".format(sum(TTFTs) / cnt))
    print("Average decode speed: {:.1f} token/s".format(sum(TPTs) / cnt))
