
[Feature] Integrate KVPress in order to try more KV cache optimization heuristics easily #10585

@vincentzed

Description


What is this?

In long-context serving, the KV cache grows linearly with sequence length (O(n)), and good compression can bring memory growth closer to sub-linear. The KVPress library bundles many strategies, making it easy to try different heuristics for KV cache compression during prefill.

Background

KVPress provides a set of "presses" that compress the KV cache during the prefilling phase. Each press exposes a compression_ratio attribute that measures how much the cache is compressed. The easiest way to use a press is through its custom KVPressTextGenerationPipeline, which is automatically registered as a transformers pipeline under the name "kv-press-text-generation" when kvpress is imported and handles chat templates and tokenization for you.

For example:

from transformers import pipeline
from kvpress import ExpectedAttentionPress

device = "cuda:0"
model = "meta-llama/Llama-3.1-8B-Instruct"
model_kwargs = {"attn_implementation": "flash_attention_2"}
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

context = "A very long text you want to compress once and for all"
question = "\nA question about the compressed context"  # optional

press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]

This compression focuses on the prefill phase, for example for prompt caching.

Design

The generic BasePress uses compress to return compressed key and value tensors with a reduced sequence-length dimension. Score-based presses use heuristics to drop the KV pairs with the lowest importance (there are 11 such presses so far). There are also other, less straightforward methods of KV cache compression (i.e., not the score-based approach above), which is part of the extensibility we want to promote.
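
As an illustration of the score-based extension point, here is a minimal sketch of a custom press. It assumes ScorerPress is exported from the package top level and that the score signature matches the built-in presses (e.g. KnormPress); both should be verified against the kvpress source before relying on this.

from dataclasses import dataclass

from kvpress import ScorerPress


@dataclass
class LargestKeyNormPress(ScorerPress):
    """Hypothetical press: keep the KV pairs whose keys have the largest L2 norm."""

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        # Expected output shape: (batch, num_kv_heads, seq_len). The base class
        # prunes the lowest-scoring `compression_ratio` fraction of tokens.
        return keys.norm(dim=-1)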

Quantization is not applied by default, but it can be enabled at a configurable precision. For now, Llama, Mistral, Phi, Qwen 2 and 3, and Gemma 3 are all tested.
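
Independently of any press, the KV cache can be quantized on the transformers side; a minimal sketch with the quanto backend at 4 bits is below (requires optimum-quanto). How a quantized cache composes with a kvpress press is an assumption to verify.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda:0")

inputs = tokenizer("A very long context ...", return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=32,
    cache_implementation="quantized",               # quantize the KV cache during decoding
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.decode(out[0], skip_special_tokens=True))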

It works with multiple GPUs through the Hugging Face Accelerate library for end-to-end inference (see the sketch below).

https://huggingface.co/docs/accelerate/en/index
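
A minimal multi-GPU sketch: passing device_map="auto" lets Accelerate shard the model weights across the visible GPUs. It assumes the kv-press-text-generation pipeline forwards device_map to from_pretrained like a regular transformers pipeline.

from transformers import pipeline
from kvpress import ExpectedAttentionPress

pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Llama-3.1-8B-Instruct",
    device_map="auto",  # let Accelerate place layers across the available GPUs
    model_kwargs={"attn_implementation": "flash_attention_2"},
)

press = ExpectedAttentionPress(compression_ratio=0.5)
answer = pipe("A very long context ...", question="A question about it", press=press)["answer"]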

In theory, it should be possible to use this in tandem with a KV transfer backend.

As a brief refresher: under tensor parallelism, the KV cache is sharded across attention heads, so each TP rank handles a subset of heads; in the disaggregated setup, the corresponding KV cache slice is transferred between prefill and decode (a shape sketch follows the file references below).

See: python/sglang/srt/disaggregation/mooncake/conn.py

python/sglang/srt/layers/linear.py
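
A back-of-the-envelope shape sketch of the tensor-parallel case (illustrative tensors only, not SGLang's actual cache layout): each TP rank holds num_kv_heads // tp_size heads, so a press applied per rank only sees its local slice and must prune the same token positions on every rank to keep sequence lengths consistent.

import torch

tp_size = 4
num_kv_heads, seq_len, head_dim = 8, 4096, 128

heads_per_rank = num_kv_heads // tp_size  # 2 KV heads per TP rank

# Local KV cache slice held by one TP rank: (heads_per_rank, seq_len, head_dim)
local_keys = torch.randn(heads_per_rank, seq_len, head_dim)
local_values = torch.randn(heads_per_rank, seq_len, head_dim)
print(local_keys.shape)  # torch.Size([2, 4096, 128])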

In pipeline parallelism, the KV cache is distributed across layers. Each pipeline-parallel stage handles a specific layer range, with its own start and end layer indices that determine which layers each PP rank processes (sketched below). Right now, we don't support PP on the decode side of the PD Mooncake backend.

Mooncake engine ref: python/sglang/srt/disaggregation/mooncake/conn.py
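
A small sketch of the start/end-layer bookkeeping under pipeline parallelism; the even split below is illustrative, not SGLang's exact partitioning logic.

def layer_range(pp_rank: int, pp_size: int, num_layers: int) -> tuple[int, int]:
    # Each PP rank owns a contiguous layer range [start, end) and therefore
    # only the KV cache for those layers.
    per_stage = num_layers // pp_size
    start = pp_rank * per_stage
    end = num_layers if pp_rank == pp_size - 1 else start + per_stage
    return start, end

# 32 layers over 4 PP stages -> [(0, 8), (8, 16), (16, 24), (24, 32)]
print([layer_range(r, 4, 32) for r in range(4)])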

In data parallelism, each group has an independent KV cache replica. In that case, we can treat each KV cache separately and apply a press to each replica independently. This would be a good place to start after the initial support.
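
A minimal sketch of the data-parallel starting point: every DP group owns an independent KV cache replica, so compression can run per replica with no cross-group coordination. The tensors and the norm-based pruning below stand in for a real press.

import torch

dp_size = 2
num_kv_heads, seq_len, head_dim = 8, 4096, 128
compression_ratio = 0.5

# One independent (keys, values) replica per DP group.
replicas = [
    (torch.randn(num_kv_heads, seq_len, head_dim), torch.randn(num_kv_heads, seq_len, head_dim))
    for _ in range(dp_size)
]

compressed = []
for keys, values in replicas:
    # Stand-in for a press: keep the tokens with the largest mean key norm.
    n_keep = int(seq_len * (1 - compression_ratio))
    idx = keys.norm(dim=-1).mean(dim=0).topk(n_keep).indices.sort().values
    compressed.append((keys[:, idx], values[:, idx]))

print(compressed[0][0].shape)  # torch.Size([8, 2048, 128])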

https://github.com/NVIDIA/kvpress

#1347

Success Criteria

One or more of the following should be possible:

  • Run inference with KVPress on a single node.
  • Repeat the above with any of the three parallelism strategies, possibly more than one at a time.
  • Make it work with the PD and KV transfer backend(s), though this is largely covered by the previous point.
