😎 HiP Attention can extend a model's context length without training and can serve 3 million tokens on a single L40S 48GB GPU while achieving an estimated 7.24× speedup.
| Paper (arXiv, InfiniteHiP latest) | Paper (ICLR 2025) | SGLang Integration |
Note: You can try it in our Playground at DeepAuto.ai!
Important: This is NOT yet free for commercial use. The license is FSL-1.1-MIT, which is free for non-commercial use and automatically converts to the MIT license two years after each release. Please refer to the LICENSE for more details.
- 2025.01.26: Version 1.2 is now ready! The preprint is now available on arXiv.
- 2025.01.22: HiP Attention has been accepted to ICLR 2025!
- 2025.01.03: Version 1.2 will be released soon. The new version fully supports context extension and offers better control over the pruning hierarchy. It will also have better SGLang support (with proper KV offloading!).
- 2024.10.05: Version 1.1 is now ready; check ainl-hip-offload. The KV offloading feature is in an alpha state.
- 2024.09.09: Version 1.1 will be released soon. Please refer to the ainl-hip-attention2 branch for a preview. It will further reduce latency and improve accuracy (and fix most of the internal bugs of v1.0). It offers many more experimental options for further research (e.g., key access logs and a modular design of the masking kernel). As discussed in the Appendix, this release will (hopefully) include a KV offloading feature, using either UVM or a custom cache management algorithm. SGLang will also be supported in this release. Please take a look at our company's fork for a preview.
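To see why KV offloading matters for multi-million-token contexts, here is a back-of-the-envelope estimate of a dense bfloat16 KV cache. The model dimensions (32 layers, 8 KV heads, head dimension 128) are assumptions resembling a Llama-3.1-8B-style model, not figures taken from this README, and the estimate ignores model weights and activations.

```python
def kv_cache_bytes(num_tokens: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Size of a dense KV cache: keys and values (factor 2) in every layer."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return num_tokens * per_token

# Assumed Llama-3.1-8B-like dimensions (hypothetical, for illustration only).
LAYERS, KV_HEADS, HEAD_DIM = 32, 8, 128

per_token = kv_cache_bytes(1, LAYERS, KV_HEADS, HEAD_DIM)
total = kv_cache_bytes(3_000_000, LAYERS, KV_HEADS, HEAD_DIM)
gpu_bytes = 48 * 10**9  # nominal L40S memory (48 GB)

print(f"KV cache per token: {per_token / 1024:.0f} KiB")
print(f"KV cache for 3M tokens: {total / 1e9:.1f} GB")
print(f"Ratio to 48 GB of GPU memory: {total / gpu_bytes:.1f}x")
```

Under these assumptions the script reports 128 KiB per token and about 393.2 GB for 3 million tokens, roughly 8.2× the card's memory, so most of the cache has to live off the GPU.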
The hip-attn package is available on PyPI:
pip install hip-attn
or using uv:
uv add hip-attn
After installation, you can access the hip_attn package from any project (hip is the code name of HiP Attention).
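The usage example that follows passes tensors in a (batch, sequence, heads, head_dim) layout, with fewer key/value heads than query heads (grouped-query attention). The helpers below are hypothetical and not part of hip_attn; they sketch the shape bookkeeping this implies, assuming the standard GQA convention that consecutive groups of query heads share one KV head.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]  # (batch, seq_len, num_heads, head_dim)

def expected_output_shape(q_shape: Shape, k_shape: Shape) -> Shape:
    """Hypothetical check: validate q/k layouts and return the attention output shape."""
    batch, q_len, num_heads, head_dim = q_shape
    k_batch, _kv_len, num_kv_heads, k_head_dim = k_shape
    if batch != k_batch or head_dim != k_head_dim:
        raise ValueError("q and k must agree on batch size and head_dim")
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    # One output row per query token and one slot per query head.
    return (batch, q_len, num_heads, head_dim)

def kv_head_for(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Standard GQA mapping: each KV head serves num_heads // num_kv_heads query heads."""
    return q_head // (num_heads // num_kv_heads)

print(expected_output_shape((1, 32 * 1024, 32, 128), (1, 128 * 1024, 8, 128)))
print([kv_head_for(h, 32, 8) for h in range(8)])
```

With the example's 32 query heads and 8 KV heads, each KV head serves four query heads, and the expected shape (1, 32768, 32, 128) matches the torch.Size([1, 32768, 32, 128]) printed by hip_attention_12.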
import torch
from hip_attn import hip_attention_12, HiPAttentionArgs12

device = 'cuda'

batch_size = 1
kv_len = 128 * 1024  # 128K key/value tokens
q_len = 32 * 1024    # 32K query tokens
num_heads = 32       # query heads
num_kv_heads = 8     # key/value heads (grouped-query attention: 4 query heads per KV head)
head_dims = 128
dtype = torch.bfloat16

# Tensors use the (batch, sequence, heads, head_dim) layout
q = torch.randn(
    (batch_size, q_len, num_heads, head_dims),
    dtype=dtype,
    device=device,
)
k = torch.randn(
    (batch_size, kv_len, num_kv_heads, head_dims),
    dtype=dtype,
    device=device,
)
v = k.clone()  # random values are enough for a shape check

output, metadata = hip_attention_12(q=q, k=k, v=v, args=HiPAttentionArgs12())
print(output.shape)
# > torch.Size([1, 32768, 32, 128])
We recommend uv, a very fast Python environment manager, for creating and managing Python environments. Please follow its documentation to install uv. After installing uv, you can create a new Python environment and install hip-attention with the following commands:
# Clone this repository
git clone git@github.com:DeepAuto-AI/hip-attention.git
cd hip-attention
# This installs all research and dev dependencies into .venv/
uv sync --no-dev # Install base dependencies first
uv sync # Then install all dependencies including no-build-isolation packages (e.g., flash-attn)
uv run pre-commit install
Then you can run any Python program with uv run, which automatically picks up the .venv/ virtual environment:
- Script:
  uv run src/hip_research/main/model_eval.py
- Module:
  uv run -m src.hip_research.main.model_eval
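To confirm that uv run is using the project's .venv/ and that hip-attn is installed there, a small diagnostic script can help. The file name check_env.py is hypothetical; the script uses only the standard library (sys and importlib.metadata) and detects a virtual environment the PEP 405 way, by comparing sys.prefix with sys.base_prefix.

```python
import importlib.metadata
import sys
from typing import Optional

def in_virtualenv() -> bool:
    """True inside a venv (PEP 405): sys.prefix differs from the base interpreter's prefix."""
    return sys.prefix != sys.base_prefix

def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed distribution's version, or None if it is not installed."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None

if __name__ == "__main__":
    print("interpreter prefix:", sys.prefix)
    print("inside a virtual environment:", in_virtualenv())
    print("hip-attn version:", installed_version("hip-attn") or "not installed")
```

Run from the repository root with uv run python check_env.py; the reported prefix should point into .venv/.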
Alternatively, to install with conda:
# Clone this repository
git clone git@github.com:DeepAuto-AI/hip-attention.git
cd hip-attention
# Create a new conda environment
conda create --name hip python=3.11
conda activate hip
# Default install
pip install -e "."
# (Optional) For research benchmarks and unit tests
pip install -e "hip-research"
# Optional, depends on your CUDA environment
export CUDACXX=/usr/local/cuda/bin/nvcc
# Install SGLang with support for HiP Attention
pip install -e ".[sglang]" \
"sglang[all] @ git+https://github.com/DeepAuto-AI/sglang.git@deepauto/release#subdirectory=python" \
--no-build-isolation \
--verbose \
--find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
The deepauto/hip-attention Docker images are available on Docker Hub.
Docker examples are available in the Running section.
Docker Compose examples are available in the docker-compose folder.
# First copy .env.example to .env, then edit it for your setup
cp .env.example .env
vim .env
# Start sglang server
docker compose \
--env-file .env \
-f docker-compose/sglang-server.yaml \
--project-name hip-attention-sglang-server-local \
up
# Start sglang router
docker compose \
-f docker-compose/sglang-router.yaml \
--project-name hip-attention-sglang-router-local \
up
Check the page on how to reproduce our experiments.
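Once the SGLang server started above is running, requests can be sent to SGLang's native /generate endpoint. The sketch below builds the HTTP request with the standard library only. The base URL http://localhost:30000 (SGLang's default server port) is an assumption; use whatever host and port your .env and compose files actually expose.

```python
import json
import urllib.request

# Assumption: the SGLang server is reachable here; adjust to your compose setup.
BASE_URL = "http://localhost:30000"

def build_generate_request(prompt: str, max_new_tokens: int = 64,
                           base_url: str = BASE_URL) -> urllib.request.Request:
    """Build a POST request for SGLang's native /generate endpoint."""
    payload = {
        "text": prompt,
        "sampling_params": {"max_new_tokens": max_new_tokens, "temperature": 0.0},
    }
    return urllib.request.Request(
        url=f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

def generate(prompt: str, max_new_tokens: int = 64) -> str:
    """Send the request and return the generated text (needs a running server)."""
    req = build_generate_request(prompt, max_new_tokens)
    with urllib.request.urlopen(req, timeout=600) as resp:
        return json.loads(resp.read())["text"]

if __name__ == "__main__":
    req = build_generate_request("Summarize HiP Attention in one sentence.")
    print(req.get_method(), req.full_url)
```

Separating request construction from sending keeps the payload easy to inspect before pointing it at a live server with generate().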
@misc{willette2025_delta_attention,
title={Delta Attention: Fast and Accurate Sparse Attention Inference by Delta Correction},
author={Jeffrey Willette and Heejun Lee and Sung Ju Hwang},
year={2025},
eprint={2505.11254},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2505.11254},
}
@misc{lee2025_infinite_hip,
title={InfiniteHiP: Extending Language Model Context Up to 3 Million Tokens on a Single GPU},
author={Heejun Lee and Geon Park and Jaduk Suh and Sung Ju Hwang},
year={2025},
eprint={2502.08910},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2502.08910},
}
@inproceedings{lee2025_hip_attention,
title={A Training-Free Sub-quadratic Cost Transformer Model Serving Framework with Hierarchically Pruned Attention},
author={Heejun Lee and Geon Park and Youngwan Lee and Jaduk Suh and Jina Kim and Wonyong Jeong and Bumsik Kim and Hyemin Lee and Myeongjae Jeon and Sung Ju Hwang},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=PTcMzQgKmn}
}
# This will update the git commit hash of sglang
uv lock --upgrade-package sglang
uv sync
- PyPI
rm -rf dist
uv build --no-sources
uv publish
- Docker
git clone git@github.com:DeepAuto-AI/hip-attention.git
cd hip-attention
docker login
tag_git_short=$(git rev-parse --short HEAD)-sglang
tag_hip_attention_sglang=v$(uv run python -c 'import importlib.metadata; print(importlib.metadata.version("hip-attn"))')-sglang
# Build sglang server image
docker build . \
-f Dockerfile.sglang \
-t deepauto/hip-attention:latest \
-t deepauto/hip-attention:latest-sglang \
-t deepauto/hip-attention:${tag_git_short} \
-t deepauto/hip-attention:${tag_hip_attention_sglang}
# Publish sglang server image
docker push deepauto/hip-attention:latest
docker push deepauto/hip-attention:latest-sglang
docker push deepauto/hip-attention:${tag_git_short}
docker push deepauto/hip-attention:${tag_hip_attention_sglang}
# Build sglang router image
# Assumes the sglang repository is checked out next to this one
cd ../sglang
docker build . \
-f docker/Dockerfile.router \
--no-cache \
-t deepauto/sglang-router:latest \
-t deepauto/sglang-router:latest-sglang \
-t deepauto/sglang-router:${tag_git_short} \
-t deepauto/sglang-router:${tag_hip_attention_sglang}
# Publish sglang router image
docker push deepauto/sglang-router:latest
docker push deepauto/sglang-router:latest-sglang
docker push deepauto/sglang-router:${tag_git_short}
docker push deepauto/sglang-router:${tag_hip_attention_sglang}
cd -
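The tagging scheme used by the Docker commands above can be restated as a small pure function. This is only a sketch for clarity: the function name is hypothetical, and the version and short hash passed below are placeholder values, whereas the script takes them from the installed hip-attn metadata and from git rev-parse --short HEAD.

```python
from typing import List

def docker_tags(image: str, version: str, git_short: str) -> List[str]:
    """Hypothetical helper mirroring the tags the release script applies to one image."""
    return [
        f"{image}:latest",
        f"{image}:latest-sglang",
        f"{image}:{git_short}-sglang",  # tag_git_short
        f"{image}:v{version}-sglang",   # tag_hip_attention_sglang
    ]

# Placeholder version and commit hash, for illustration only.
for image in ("deepauto/hip-attention", "deepauto/sglang-router"):
    for tag in docker_tags(image, version="1.2.0", git_short="abc1234"):
        print(tag)
```

Both images share the same four tags, so a given hip-attention commit or release can be matched to its router image.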