Hypergraph attention: 3-way attention between token triplets, with O(N) memory scaling.
Clone this repository, then:

```bash
# Determine the CUDA toolkit version:
nvcc --version | grep -oP 'release \K[\d.]+'
# Install CUDA-enabled torch; run only ONE of these, matching the nvcc output (cu130 or cu126):
pip install torch --extra-index-url https://download.pytorch.org/whl/cu130
pip install torch --extra-index-url https://download.pytorch.org/whl/cu126
# If something goes wrong (e.g. a CPU-only or mismatched torch is already installed), force a reinstall:
pip install --force-reinstall torch --extra-index-url https://download.pytorch.org/whl/cu130
pip install --force-reinstall torch --extra-index-url https://download.pytorch.org/whl/cu126
# Install
cd att3ntion
pip install -r requirements.txt
# This will launch the CUDA compilation.
# The --no-build-isolation flag is mandatory (and considered best practice):
# it lets the build use the CUDA-enabled torch installed above.
pip install -e . --no-build-isolation
# Run demo
python demo.py            # Basic usage
python demo.py --memory   # Memory scaling benchmark
python demo.py --train    # Train on arithmetic task
python demo.py --all      # Run everything
```

Alternatively, install with uv:

```bash
uv --version  # must be >= 0.11
# To install or update uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
hash -r  # make the shell pick up the newly installed uv
# Clone
git clone git@github.com:tlh24/att3ntion
cd att3ntion
# Determine the CUDA toolkit version:
nvcc --version | grep -oP 'release \K[\d.]+'
# Install dependencies without compiling att3ntion yet;
# run only ONE of these, choosing cu130 or cu126 based on the nvcc output:
uv sync --extra cu130 --no-install-package att3ntion
uv sync --extra cu126 --no-install-package att3ntion
# Launch the compilation (use the same extra as above):
uv sync --extra cu130
uv sync --extra cu126
source .venv/bin/activate
```

Usage:

```python
import torch
from att3ntion import HypergraphAttention

batch_size, seq_len, d_model = 2, 128, 64  # example sizes

# Create layer (drop-in replacement for attention)
layer = HypergraphAttention(d_model=d_model, n_heads=4).cuda()
# Forward pass
x = torch.randn(batch_size, seq_len, d_model, device='cuda')
y = layer(x)  # same shape as input: (batch_size, seq_len, d_model)
```
- 3-way token interactions: models relationships between triplets $(Q, R, S)$ instead of the pairwise relationships typical of standard self-attention. For more information and motivation, see Theory, below.
- O(N) memory: the naive implementation would require $O(N^3)$ memory for the attention tensor. This library uses FlashAttention-style tiling and online streaming techniques to reduce memory complexity from $O(N^3)$ to $O(N)$, making 3-way attention feasible. However, the kernels are not yet fully optimized for speed; this is a work in progress!
| Seq Length | att3ntion | Naive O(N³) | Savings |
|---|---|---|---|
| 64 | ~50 MB | ~67 MB | 1.3x |
| 256 | ~200 MB | ~4.3 GB | 21x |
| 1024 | ~800 MB | ~275 GB | 344x |
Run `python demo.py --memory` to benchmark on your hardware.
- Full autograd support: custom CUDA kernels with a hand-written backward pass that produces the same gradients as torch.autograd.
- Developed and tested on an NVIDIA RTX 4080. Should work on any CUDA-capable GPU with compute capability 7.0+ (Volta and newer).
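The memory table above follows the expected scaling. Quadrupling the sequence length multiplies the naive $O(N^3)$ footprint by $4^3 = 64$, but multiplies the tiled $O(N)$ footprint only by 4. The short sketch below extrapolates both columns from the 64-token row to check this.

The helper names are hypothetical and not part of att3ntion. The raw byte formula assumes one fp32 score tensor per batch element and head. The table's batch size, head count and dtype are not stated, so that formula is only for scale.

```python
def naive_score_bytes(n: int, batch: int = 1, heads: int = 1, bytes_per_el: int = 4) -> int:
    """Bytes for a fully materialized (batch, heads, N, N, N) score tensor (fp32 by default)."""
    return batch * heads * n**3 * bytes_per_el


def scale_from_row(n: int, n0: int, mb0: float, power: int) -> float:
    """Extrapolate a memory figure measured at length n0 to length n, assuming O(N^power)."""
    return mb0 * (n / n0) ** power


for n in (64, 256, 1024):
    naive_mb = scale_from_row(n, 64, 67.0, power=3)  # naive column, O(N^3)
    tiled_mb = scale_from_row(n, 64, 50.0, power=1)  # att3ntion column, O(N)
    print(f"N={n:5d}  att3ntion ~{tiled_mb:7.0f} MB  naive ~{naive_mb:9.0f} MB  "
          f"savings ~{naive_mb / tiled_mb:6.1f}x")

# A single fp32 N^3 tensor for one batch element and one head at N = 1024:
print(naive_score_bytes(1024) / 2**30, "GiB")  # 4.0 GiB
```

The extrapolated naive values (about 4.3 GB at 256 tokens and 274 GB at 1024) and savings (about 21x and 343x) match the table up to rounding.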
- Python 3.10+
- PyTorch 2.0+ with CUDA support
- CUDA Toolkit 11.8+
- NVIDIA GPU (tested on RTX 4080)
- You may need to downgrade GCC to a version supported by your NVCC release.
Normal attention measures the dot-product similarity between two projected versions of the tokens,
Above,
L1 attention experimented with using the L1 norm to measure similarity between
Conditional scatter and gather operations are core elements of computer science algorithms, and given enough memory and time¹, they should be able to approximate any function. Standard attention is one such primitive: a conditional gather in which each query reads a softmax-weighted mix of the other tokens' values. Yet some seemingly fundamental operations are poorly expressed with these conditional message-passing primitives.

One example is model inference: based on structure between a pair of tokens (detectable, presumably, in the latent space), you ought to modify a third token, not either token of the pair. Another is that normal attention cannot 'create' a linkage or edge between tokens by fiat. Attention is ephemeral and depends on structure within the latent dimensions, so it cannot easily be modified, only instantiated. Creating linkages requires being able to 'write' to more than one token at the same time.
An experimental solution to this problem is to allow higher-order operations on the graph of relations between tokens. The first step is to increase the arity of the attention operation from two tokens to three.
Assume that you measure some similarity between three tokens,
That is, the attention score gains one more term in the summation (so it is no longer a dot product). The resulting attention tensor also has one more dimension: $N \times N \times N$ for $N$ tokens, instead of $N \times N$. To propagate information within the triplets, there are several possible operations:
1. Reduce along one dimension: for 'conventional' Q-R interactions, sum over all S, then apply a softmax over the R dimension, selecting one $V_r$ to write to Q. Via the softmax, this reduces to conventional attention.
2. Reduce and softmax along one dimension, but write twice: for Q-R and Q-S interactions, reduce over S and R respectively, then softmax over R and S respectively, selecting a pair $V_r, V_s$ for writing. This reduces to conventional attention with two writes per head.
3. Softmax serially along two dimensions, write twice: for Q-R and Q-S interactions, softmax over S and R respectively, then over R and S respectively, selecting a pair $V_r, V_s$ for writing. This is more interesting, because the (decomposed) 2D softmax supports interaction terms.
4. Reduce and softmax along two dimensions: for Q-R-S interactions, softmax jointly over the R-S dimensions, selecting one pair $V_r, V_s$ for writing to Q. R can modulate which S is gathered, and S can modulate which R is gathered, which is promising.
5. Reduce and softmax along two dimensions: for Q-R-S interactions, softmax jointly over the R-S dimensions, selecting $V_q$ for writing to the pair R, S. This is the scatter complement of option 4.
Options 1 and 2 reduce to conventional attention. Option 3 is more interesting, but options 4 and 5 support the desired 3-way interactions.
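To make options 4 and 5 concrete, here is a naive NumPy sketch. It materializes the full $N \times N \times N$ score tensor (exactly what the library avoids) and applies the joint R-S softmax.

Several parts of it are assumptions and are not att3ntion's API:

- A random score tensor stands in for the actual 3-way similarity.
- The selected pair is combined by summing $V_r + V_s$.
- The function names are hypothetical.

```python
import numpy as np


def joint_softmax_rs(scores: np.ndarray) -> np.ndarray:
    """Softmax over the joint (r, s) axes for every q; scores has shape (N, N, N)."""
    n = scores.shape[0]
    flat = scores.reshape(n, n * n)
    flat = np.exp(flat - flat.max(axis=1, keepdims=True))  # subtract max for stability
    return (flat / flat.sum(axis=1, keepdims=True)).reshape(n, n, n)


def gather_option4(scores: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Option 4: each q gathers a softmax-weighted pair (V_r, V_s).
    Combining the pair by summation is an assumption of this sketch."""
    p = joint_softmax_rs(scores)
    return np.einsum("qrs,rd->qd", p, v) + np.einsum("qrs,sd->qd", p, v)


def scatter_option5(scores: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Option 5: each q scatters V_q to both members of the (r, s) pairs it selects."""
    p = joint_softmax_rs(scores)
    return np.einsum("qrs,qd->rd", p, v) + np.einsum("qrs,qd->sd", p, v)


rng = np.random.default_rng(0)
n, d = 6, 4
scores = rng.standard_normal((n, n, n))  # stand-in for the 3-way similarity A[q, r, s]
v = rng.standard_normal((n, d))
y_gather = gather_option4(scores, v)     # (n, d): one output row per q
y_scatter = scatter_option5(scores, v)   # (n, d): accumulated writes into each r and s
```

The gather normalizes per query, so every q receives a convex mix of value pairs. The scatter is not normalized per receiving token, so a token selected by many queries accumulates many writes.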
Gather operations:
Where
Scatter operations:
As mentioned above, the scatter
The obvious problem with calculating the above directly is that you don't want to instantiate the full $N \times N \times N$ attention tensor.
The above summations would suggest that a full
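The general tiling trick can be sketched in NumPy for an option-4-style gather. For each query, the code streams over tiles of the (R, S) grid and keeps only three running quantities:

- the maximum score seen so far;
- the softmax denominator;
- a weighted sum of values.

These are rescaled whenever the maximum grows, which is the online softmax used by FlashAttention. Only a tile × tile block of scores exists at any time, and per-query state is $O(d)$.

This is a generic illustration, not the repository's CUDA kernel. The trilinear score $\sum_d Q_{qd} R_{rd} S_{sd}$ is a hypothetical stand-in for the actual 3-way similarity. Summing $V_r + V_s$ is an assumption, and all names are made up.

```python
import numpy as np


def streaming_gather(q, r, s, v, tile=4):
    """Option-4-style gather via an online softmax over (r, s) tiles.
    Never builds the full N x N x N score tensor."""
    n, dv = q.shape[0], v.shape[1]
    out = np.empty((n, dv))
    for i in range(n):
        m = -np.inf          # running max of the scores seen so far
        l = 0.0              # running sum of exp(score - m)
        acc = np.zeros(dv)   # running sum of exp(score - m) * (v_r + v_s)
        for j0 in range(0, n, tile):
            for k0 in range(0, n, tile):
                rj, sk = r[j0:j0 + tile], s[k0:k0 + tile]
                # Hypothetical trilinear score for this tile only: (tile_r, tile_s)
                a = np.einsum("d,jd,kd->jk", q[i], rj, sk)
                m_new = max(m, a.max())
                scale = np.exp(m - m_new)  # rescale old statistics to the new max
                p = np.exp(a - m_new)
                pair = v[j0:j0 + tile, None, :] + v[None, k0:k0 + tile, :]
                l = l * scale + p.sum()
                acc = acc * scale + np.einsum("jk,jkd->d", p, pair)
                m = m_new
        out[i] = acc / l
    return out


def naive_gather(q, r, s, v):
    """Reference: materialize all N^3 scores, then softmax jointly over (r, s)."""
    a = np.einsum("id,jd,kd->ijk", q, r, s)
    p = np.exp(a - a.max(axis=(1, 2), keepdims=True))
    p /= p.sum(axis=(1, 2), keepdims=True)
    return np.einsum("ijk,jd->id", p, v) + np.einsum("ijk,kd->id", p, v)


rng = np.random.default_rng(0)
n, d = 10, 8
q, r, s, v = (rng.standard_normal((n, d)) * 0.5 for _ in range(4))
y_stream = streaming_gather(q, r, s, v, tile=4)  # n = 10 is not a multiple of 4
y_naive = naive_gather(q, r, s, v)
```

The streamed and naive results agree to floating-point tolerance for any tile size. In a GPU kernel, the loop over queries would run in parallel and each score tile would live in shared memory.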
```
att3ntion/ # repo root
├── att3ntion/ # Python package
│ ├── __init__.py # public API: HypergraphAttention
│ ├── _autograd.py # CUDA-backed layer + autograd bridge
│ └── _naive.py # naive O(N³) reference (testing only)
├── cpp/ # pybind11 C++ bindings
│ ├── cuda_bindings.cpp/.h # Python ↔ CUDA glue
│ └── torch_reference.cpp # torch-based reference kernel
├── cuda/ # hand-written CUDA kernels
│ ├── forward.cu # forward pass
│ ├── backward.cu # backward pass
│ └── common.cuh # shared constants & device utilities
├── tests/ # benchmarks, equivalence tests, Makefile
├── demo.py # interactive demo
├── setup.py # build config (C extensions)
└── pyproject.toml # package metadata
```
Footnotes
1. Normal transformers can have limitless computation via autoregression, but are highly limited in internal memory by their fixed latent space. Sure, you can have an expanding list of tokens in the (also limited) context window, but each head is limited in the number of latent-space "named" global variables. This is a deep problem that can be partly addressed by hypergraph attention.