
Conversation

@yewentao256 (Member) commented Aug 12, 2025

Purpose

Support full CUDA graph for Cutlass MLA (SM100), yielding a ~6% end-to-end throughput improvement.

Thanks for the previous work enabling Cutlass MLA on SM100!

Test

vllm serve deepseek-ai/DeepSeek-V2-Lite --port 10256 --enable-expert-parallel --data-parallel-size 2 --trust_remote_code -O '{"full_cuda_graph": true}' --cuda-graph-sizes 16 32 64 128 256 512
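For context, the -O '{"full_cuda_graph": true}' argument passes a compilation config as JSON. A minimal sketch of how such an option string could be parsed (hypothetical helper for illustration, not vLLM's actual parser):

```python
import json

def parse_compilation_config(opt: str) -> dict:
    # Hypothetical: treat a bare digit as an optimization level,
    # otherwise parse the string as a JSON dict of compilation options.
    if opt.isdigit():
        return {"level": int(opt)}
    return json.loads(opt)

cfg = parse_compilation_config('{"full_cuda_graph": true}')
print(cfg)  # {'full_cuda_graph': True}
```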

Acc

lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:10256/v1/completions,model=deepseek-ai/DeepSeek-V2-Lite,num_concurrent=256" --tasks gsm8k
# with full cuda graph
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.3753|±  |0.0133|
|     |       |strict-match    |     5|exact_match||0.3707|±  |0.0133|
# without cuda graph
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.3776|±  |0.0134|
|     |       |strict-match    |     5|exact_match||0.3745|±  |0.0133|
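As a sanity check on the numbers above, the strict-match difference between the two runs is well within one combined standard error, so graph capture does not measurably affect accuracy:

```python
import math

# strict-match scores and stderr from the two tables above
with_graph, without_graph = 0.3707, 0.3745
se = 0.0133  # same stderr in both runs

diff = abs(with_graph - without_graph)
combined_se = math.sqrt(se**2 + se**2)
print(diff < combined_se)  # True: the gap is statistical noise
```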

Perf

vllm bench serve --model deepseek-ai/DeepSeek-V2-Lite --dataset-name random --num-prompts 1000 --host 127.0.0.1 --port 10256 --endpoint-type openai --endpoint /v1/completions --max-concurrency 256 --random-input-len 32 --random-output-len 512
# full cuda graph
============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             256       
Benchmark duration (s):                  85.27     
Total input tokens:                      30947     
Total generated tokens:                  506054    
Request throughput (req/s):              11.73     
Output token throughput (tok/s):         5934.46   
Total Token throughput (tok/s):          6297.38   
---------------Time to First Token----------------
Mean TTFT (ms):                          655.21    
Median TTFT (ms):                        734.63    
P99 TTFT (ms):                           1002.44   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          40.85     
Median TPOT (ms):                        40.07     
P99 TPOT (ms):                           46.15     
---------------Inter-token Latency----------------
Mean ITL (ms):                           40.58     
Median ITL (ms):                         35.89     
P99 ITL (ms):                            53.80     
==================================================
# without full cuda graph
============ Serving Benchmark Result ============
Successful requests:                     1000      
Maximum request concurrency:             256       
Benchmark duration (s):                  90.59     
Total input tokens:                      30947     
Total generated tokens:                  505342    
Request throughput (req/s):              11.04     
Output token throughput (tok/s):         5578.42   
Total Token throughput (tok/s):          5920.04   
---------------Time to First Token----------------
Mean TTFT (ms):                          846.60    
Median TTFT (ms):                        790.83    
P99 TTFT (ms):                           1818.41   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.14     
Median TPOT (ms):                        45.87     
P99 TPOT (ms):                           50.04     
---------------Inter-token Latency----------------
Mean ITL (ms):                           42.90     
Median ITL (ms):                         37.26     
P99 ITL (ms):                            55.26     
==================================================
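The ~6% figure in the title can be recovered directly from the two totals above:

```python
# Total token throughput from the two benchmark runs above
tps_full_graph = 6297.38
tps_no_graph = 5920.04
speedup = tps_full_graph / tps_no_graph - 1
print(f"{speedup:.1%}")  # 6.4%
```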

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 changed the title to [Feature] Full Cuda Graph Support for Cutlass MLA and 6% E2E Throughput Improvement on Aug 12, 2025
@mergify mergify bot added the v1 label Aug 12, 2025
@gemini-code-assist bot (Contributor) left a comment
Code Review

This pull request enables full CUDA graph support for Cutlass MLA in decode-only scenarios. The changes are minimal and correctly implemented by introducing a CutlassMLAMetadataBuilder that signals this capability. My review includes a suggestion to improve code style for better adherence to PEP 8.
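A rough sketch of the pattern the review describes (class and attribute names here are modeled on the description and are illustrative; vLLM's actual classes and flags may differ):

```python
from dataclasses import dataclass

@dataclass
class MLACommonMetadataBuilder:
    # Conservative default: backends must opt in to full graph capture
    full_cudagraph_supported: bool = False

@dataclass
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder):
    # Cutlass MLA decode on SM100 is safe to capture in a full CUDA graph
    full_cudagraph_supported: bool = True

print(CutlassMLAMetadataBuilder().full_cudagraph_supported)  # True
```

Keeping the capability as a class-level flag lets the scheduler decide about graph capture without instantiating backend-specific logic.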


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which executes a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 self-assigned this Aug 12, 2025
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 12, 2025
@LucasWilkinson (Collaborator) left a comment

LGTM! thanks for doing this!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) August 12, 2025 20:31
@mgoin mgoin added the performance Performance-related issues label Aug 12, 2025
@mgoin (Member) left a comment

😮 that is just about as clean as you can do it

Are there any unit tests we have for full cudagraph attention backends? Just thinking how we can test this over time

@robertgshaw2-redhat (Collaborator) commented

when it works out of the box >>

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 (Member, Author) commented Aug 13, 2025

😮 that is just about as clean as you can do it

Are there any unit tests we have for full cudagraph attention backends? Just thinking how we can test this over time

Sounds good, just adding a new unit test for it

pytest vllm/tests/compile/piecewise/test_full_cudagraph.py
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================== 2 passed, 22 skipped, 3 warnings in 119.77s (0:01:59) ====================================
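The shape of such a test — checking that full-graph output matches eager output — can be sketched as follows (stand-in model function; the real test in test_full_cudagraph.py launches actual models and compares their generations):

```python
def run_model(prompt: str, full_cuda_graph: bool) -> list[float]:
    # Stand-in for a vLLM generation call; a real test would compare
    # logprobs from runs with and without full_cuda_graph enabled.
    return [0.1, 0.2, 0.3]

def test_full_cudagraph_matches_eager():
    eager = run_model("hello", full_cuda_graph=False)
    graphed = run_model("hello", full_cuda_graph=True)
    assert all(abs(a - b) < 1e-3 for a, b in zip(eager, graphed))

test_full_cudagraph_matches_eager()
print("ok")
```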

@LucasWilkinson LucasWilkinson merged commit 5c3fbfe into vllm-project:main Aug 15, 2025
41 checks passed
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
…ut Improvement (vllm-project#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
…ut Improvement (vllm-project#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
…ut Improvement (vllm-project#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…ut Improvement (vllm-project#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…ut Improvement (vllm-project#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…ut Improvement (vllm-project#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Labels
performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed v1
5 participants