Status: Proposal / RFC for the Attention op (opset-24). Seeking SIG feedback on whether to adopt Option A (new additive causal_alignment attribute) or Option B (redefine the existing is_causal + nonpad_kv_seqlen + no-past_key combination).
Summary
The opset-24 Attention operator's is_causal=1 masking is index-anchored: it masks each query as if its position equals its index within the current query block, ignoring any offset introduced by an existing key/value cache. That is correct only when the query block begins at absolute position 0 — i.e. when the number of new queries equals the number of valid keys (S_q == valid_kv_len).
For autoregressive decode, or mid-cache / chunked prefill into an external (static) KV buffer — where the query block starts partway through the cache (S_q < valid_kv_len) — each query's true position is t + offset (where offset is the number of cached keys already present before this block). The causal mask must be computed from that true position, not from the in-block index. The standard op cannot express this, so exporters fall back to building an explicit additive attn_mask, which forfeits the maskless Flash kernel and forces a per-layer branch in the graph.
This proposal lets the standard Attention op express offset-aware (position-correct) maskless causal decode directly. It is an op-expressiveness / graph-cleanliness improvement, not a performance change (see Non-goals). Notably, it largely codifies behavior ORT already ships: the past_key internal-cache decode path already applies the offset-aware rule at runtime today, even though the spec text does not describe it.
Background — how we got here (for readers new to this path)
This section motivates the problem from first principles for SIG members who know attention and ONNX generally but may not know this specific ONNX Runtime (ORT) CUDA path.
(a) What we were doing. We export LLMs to ONNX for autoregressive decode using an external/static KV cache: a preallocated key/value buffer that the model writes into one step at a time. The goal is to use the standard opset-24 Attention op with its nonpad_kv_seqlen input (which tells the kernel how many keys in the buffer are valid) so that decode runs on the maskless Flash-attention fast path — no dense attention mask materialized, minimal memory traffic.
(b) The problem we hit. Decode and mid-cache prefill attend a query block of length S_q against a buffer that already holds valid_kv_len valid keys, with S_q < valid_kv_len. The query block therefore starts at a non-zero offset into the cache. But the spec's is_causal rule is index-anchored: it masks query index t to keys 0..t, as if the block began at position 0. With a non-zero cache offset, that masks each query to the wrong key range. It ignores the cached prefix and lets each query attend far fewer keys than it should. ORT correctly fails closed here: the is_causal=1 + nonpad_kv_seqlen + no-past_key combination is rejected with a NOT_IMPLEMENTED guard rather than silently emitting an autoregressively wrong mask.
(c) What we had to do instead (and why each step follows). Because the standard op can't express the offset-aware rule, the exporter must hand the kernel an explicit positional attn_mask for the prefill step that encodes the correct per-position validity. That single decision cascades:
- Explicit attn_mask ⇒ no Flash. The maskless Flash fast path takes its causal frontier from a sequence-length input, not from a dense additive mask. Supplying an explicit mask therefore drops prefill onto the memory-efficient attention (MEA) kernel instead of Flash.
- Two different kernels per step ⇒ a per-layer branch. Decode (S_q == 1) can still run maskless/Flash, but prefill (S_q > 1) now needs the masked/MEA path. To route between them, the graph needs a phase-split If(Greater(S_q, 1)) subgraph per layer: a prefill branch (masked / MEA) and a decode branch (maskless / Flash).
So a single missing capability — expressing a cache offset in the causal rule — inflates every exported transformer layer with an If subgraph and an MEA fallback, purely as a workaround.
(d) What we want to solve. Let the standard Attention op express offset-aware maskless causal decode directly, so exporters no longer need the explicit attn_mask, the MEA fallback, or the per-layer phase-split If.
(e) What the open questions are. Two shapes for the fix (Option A: a new additive attribute, recommended; Option B: redefine the existing input combination); a precision note about a count-vs-index convention that must not be conflated; a latent inconsistency where ORT's past_key path already does the offset-aware thing while the spec text doesn't describe it; and a scope decision for the SIG (fix only the external-cache path, or reconcile all paths and the spec text). All are detailed below.
The causal rule, stated positionally
Define, per batch b:
offset[b] = nonpad_kv_seqlen[b] - S_q == write_indices[b] # cached keys present before this query block
# An in-window query at index t (0-based within this step) has true absolute position:
# p = t + offset[b]
# Causal validity: key j (0-based index into the valid prefix) is valid for query t iff:
# j <= t + offset[b] ( i.e. j <= p )
# The current spec is the special case offset[b] == 0 (the query block starts at position 0):
# j <= t
The offset-aware rule and the current index-anchored rule coincide iff offset == 0 (write_indices == 0, i.e. S_q == nonpad_kv_seqlen). They differ for offset > 0 (decode and mid-cache prefill), where only the offset-aware rule is autoregressively correct. (Readers familiar with FlashAttention will recognize this offset as its top-left → bottom-right alignment switch; everywhere below we use only position/offset semantics.)
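The per-batch rule can be sketched in plain Python. The sketch assumes a single head, builds masks only over each batch's valid key prefix, and requires S_q <= nonpad_kv_seqlen[b]. The function name causal_masks is hypothetical and is not part of ONNX or ORT.

```python
from typing import List, Sequence

def causal_masks(s_q: int, nonpad_kv_seqlen: Sequence[int],
                 position_anchored: bool) -> List[List[List[bool]]]:
    """Per-batch [s_q][valid_kv_len] boolean masks (True = key may be attended).

    position_anchored=True applies offset[b] = nonpad_kv_seqlen[b] - s_q;
    position_anchored=False is the index-anchored rule (offset fixed at 0).
    """
    masks = []
    for valid in nonpad_kv_seqlen:
        if s_q > valid:
            raise ValueError("sketch assumes s_q <= nonpad_kv_seqlen[b]")
        offset = valid - s_q if position_anchored else 0
        # Key j is valid for in-block query t iff j <= t + offset (the true position p).
        masks.append([[j <= t + offset for j in range(valid)] for t in range(s_q)])
    return masks

# Fresh prefill (offset 0): both rules agree.
assert causal_masks(3, [3], True) == causal_masks(3, [3], False)
# Single-token decode against 5 valid keys (offset 4): the rules diverge.
print(causal_masks(1, [5], True)[0][0])   # [True, True, True, True, True]
print(causal_masks(1, [5], False)[0][0])  # [True, False, False, False, False]
```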
Index-anchored vs position-anchored (divergence)
Example: S_q = 2 new query positions, nonpad_kv_seqlen = 4 valid keys, so offset = 2. ✓ = attend, · = masked.
| query index t \ key j | 0 | 1 | 2 | 3 | keys attended |
|---|---|---|---|---|---|
| index-anchored (today's spec text), t=0 | ✓ | · | · | · | {0} |
| index-anchored, t=1 | ✓ | ✓ | · | · | {0,1} |
| position-anchored (correct under offset=2), t=0 | ✓ | ✓ | ✓ | · | {0,1,2} |
| position-anchored, t=1 | ✓ | ✓ | ✓ | ✓ | {0,1,2,3} |
Index-anchored masking lets the two new queries see only the first 1–2 keys (ignoring the cached prefix); position-anchored masking correctly lets them see the full causal history up to their true position. The two agree only when offset == 0.
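The table rows can be regenerated with a short Python sketch for a single batch. attended_keys is a hypothetical helper that applies j <= t + offset, with offset forced to 0 for the index-anchored rule.

```python
def attended_keys(s_q: int, nonpad_kv_seqlen: int, position_anchored: bool):
    """Set of key indices each in-block query t may attend (single batch)."""
    offset = nonpad_kv_seqlen - s_q if position_anchored else 0
    return [{j for j in range(nonpad_kv_seqlen) if j <= t + offset} for t in range(s_q)]

index_rows = attended_keys(2, 4, position_anchored=False)
position_rows = attended_keys(2, 4, position_anchored=True)
print(index_rows)     # [{0}, {0, 1}]
print(position_rows)  # [{0, 1, 2}, {0, 1, 2, 3}]
```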
Code & spec citations
Verified against ORT 1.27 (@291311e7d8) and onnx 1.22. The rule above is not hypothetical — it is exactly what the Flash kernel computes, and the standard Attention op already feeds it the count needed to compute the offset.
1. ORT fail-closed guard (causal_cross_no_past) — onnxruntime/core/providers/cuda/llm/attention.cc:
- Definition (:1315-1317): causal_cross_no_past = is_causal && (q_sequence_length != total_sequence_length) && (past_sequence_length == 0).
- NOT_IMPLEMENTED rejection of is_causal=1 + nonpad_kv_seqlen + no past_key: :1332-1340 (if (causal_cross_no_past && nonpad_kv_seqlen != nullptr) return ORT_MAKE_STATUS(..., NOT_IMPLEMENTED, ...)).
- Flash eligibility excludes it: flash_eligible requires !causal_cross_no_past && attn_mask == nullptr (:1345-1356).
- Rationale comment (index-anchored masking required when no past): :1309-1313.
2. ONNX spec defines index-anchored ("upper left") masking — onnx 1.22.0, Attention opset-24:
- is_causal attribute text (verbatim): "If set to 1, the attention masking is a lower triangular matrix when the mask is a square matrix. The attention masking has the form of the upper left causal bias due to the alignment."
- Main doc body: "If is_causal is set to 1, attention scores above the diagonal are masked out, regardless of the attn_mask input."
(We quote this only to critique the existing text: "upper left … due to the alignment" is the index-anchored, offset == 0 rule, which gives the wrong key range when offset > 0.)
3. Contrib GroupQueryAttention (maskless by construction → Flash path, per-batch anchor) — onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc:
- is_unidirectional_ = true: :99.
- seqlens_k input = total_sequence_length - 1 (per-batch valid index): :138, :155; total_sequence_length CPU scalar input: :61.
- Derived to per-batch length and fed as the Flash seqlens_k → seqused_k: group_query_attention_impl.cu:730, :889, :1156 (fast-decode path group_query_attention.cc:451). No attn_mask input exists on the op, so it is maskless by construction and takes the Flash path (subject to the usual Flash eligibility conditions: head dim, dtype, arch, etc.).
4. FlashAttention v2.1 — correct handling of S_q < S_k — upstream Dao-AILab/flash-attention, v2.1.0 release notes / CHANGELOG: causal masking was changed so that when the query block is shorter than the cached key length (S_q < S_k), each query is masked at its true position within the full key sequence rather than at its index within the query block — exactly the offset-aware rule used here. Upstream provenance only: documented in the flash-attention changelog, not in the ORT vendored copy.
5. Offset == nonpad - S_q (the proposal's formula) — FlashAttention kernel as vendored in ORT:
- Causal valid-key bound: contrib_ops/cuda/bert/flash_attention/mask.h:169 computes col_idx_limit_right = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right). The literal source tokens are max_seqlen_k / max_seqlen_q, bound to actual_seqlen_k / actual_seqlen_q at the Mask constructor (flash_fwd_kernel.h:262). Causal ⇒ window_size_right=0 ⇒ key j is valid for in-window query position t iff j <= t + (actual_seqlen_k - actual_seqlen_q).
- Per-batch actual_seqlen_k = seqused_k[bidb] - leftpad_k: block_info.h:32-34; Mask built with (actual_seqlen_k, actual_seqlen_q): flash_fwd_kernel.h:262 (split-KV :784).
- Standard Attention sets seqused_k = nonpad_kv_seqlen as a COUNT: attention.cc:307 (LaunchConvertNonpadKvSeqlenToFlashSeqlensK), conversion attention_mask_impl.cu:57-69 (// count, not index). The exporter emits the valid-key count as nonpad_kv_seqlen, so the kernel's implied offset is nonpad_kv_seqlen - S_q. This yields the same causal-validity pattern (which keys each query attends) as an explicit MEA prefill mask. The masking geometry is identical, but the numerical output tensor is not necessarily bit-for-bit identical.
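To check the algebra in item 5, here is a Python transcription of the cited mask.h:169 expression. It substitutes actual per-batch lengths for max_seqlen_k / max_seqlen_q and assumes, as in upstream FlashAttention, that a column is masked when col_idx >= col_idx_limit_right. The sketch exhaustively compares the kernel bound with the proposal's j <= t + (nonpad_kv_seqlen - S_q) for small sizes. All function names are hypothetical.

```python
def col_idx_limit_right(row_idx: int, seqlen_k: int, seqlen_q: int,
                        window_size_right: int = 0) -> int:
    # Transcription of the cited mask.h:169 expression, with actual lengths
    # substituted for max_seqlen_k / max_seqlen_q.
    return min(seqlen_k, row_idx + 1 + seqlen_k - seqlen_q + window_size_right)

def flash_valid(j: int, t: int, seqlen_k: int, seqlen_q: int) -> bool:
    # Column j is masked when j >= col_idx_limit_right (causal: window_size_right = 0).
    return j < col_idx_limit_right(t, seqlen_k, seqlen_q)

def proposal_valid(j: int, t: int, nonpad_kv_seqlen: int, s_q: int) -> bool:
    # The proposal's rule with offset = nonpad_kv_seqlen - s_q.
    return j <= t + (nonpad_kv_seqlen - s_q)

def check_equivalence(max_len: int = 8) -> int:
    """Exhaustively compare both rules for S_q <= S_k <= max_len; returns #cases."""
    checked = 0
    for seqlen_k in range(1, max_len + 1):
        for seqlen_q in range(1, seqlen_k + 1):
            for t in range(seqlen_q):
                for j in range(seqlen_k):
                    assert flash_valid(j, t, seqlen_k, seqlen_q) == \
                        proposal_valid(j, t, seqlen_k, seqlen_q)
                    checked += 1
    return checked

check_equivalence()  # raises AssertionError on any mismatch
```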
CRITICAL PRECISION (do not conflate the two conventions). Contrib GQA's seqlens_k is an index (total_len - 1); the standard Attention nonpad_kv_seqlen is a count (total_len). They drive the same Flash actual_seqlen_k path but use different input conventions — the spec must state which one it adopts (this proposal uses the count convention, matching the existing nonpad_kv_seqlen input).
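The two conventions can be illustrated with a small Python sketch. The helper names are hypothetical and only model how each input maps to the Flash valid-key length (seqused_k). Both paths reach the same length and offset when used correctly. Feeding a count where an index is expected moves the causal frontier by one key.

```python
def seqused_k_from_gqa_seqlens_k(seqlens_k: int) -> int:
    # Contrib GQA convention: seqlens_k is an INDEX (total_len - 1).
    return seqlens_k + 1

def seqused_k_from_nonpad(nonpad_kv_seqlen: int) -> int:
    # Standard Attention convention: nonpad_kv_seqlen is already a COUNT.
    return nonpad_kv_seqlen

total_len, s_q = 7, 2
via_gqa = seqused_k_from_gqa_seqlens_k(total_len - 1)
via_nonpad = seqused_k_from_nonpad(total_len)
assert via_gqa == via_nonpad == 7
print(via_nonpad - s_q)  # offset: 5

# Conflating the conventions (a count fed where an index is expected)
# shifts the causal frontier by one key:
print(seqused_k_from_gqa_seqlens_k(total_len) - s_q)  # 6
```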
Precedent
Offset-aware causal masking is not new math — it is the established autoregressive convention for S_q < S_k, already battle-tested in the ecosystem:
- FlashAttention v2.1 changed its causal masking to correctly handle S_q < S_k (a query block shorter than the cached key length) by masking each query at its true position in the full key sequence. This is exactly the offset > 0 case here.
- The ONNX Runtime contrib GroupQueryAttention op already implements exactly this: is_unidirectional=true with the causal frontier anchored per batch via seqlens_k / total_sequence_length. It is maskless by construction and therefore takes the Flash path (subject to the usual Flash eligibility conditions).
In other words, offset-aware causal decode on a cache is already standard practice — the standard Attention op simply cannot express what the contrib op and FlashAttention already do.
Proposed change
Option A (RECOMMENDED): add an additive causal_alignment attribute
Add a new attribute to Attention, e.g.:
causal_alignment : enum { index_anchored (default), position_anchored }
- When is_causal=1 and causal_alignment=position_anchored, the causal frontier is anchored per batch using offset[b] = nonpad_kv_seqlen[b] - S_q (so key j is valid for query t iff j <= t + offset[b]). This mode requires nonpad_kv_seqlen to be present as the offset source, because offset is undefined without it.
- When unset, behavior is identical to today (index_anchored).
Backward compatibility: strictly additive and opt-in. Models that do not set the attribute get bit-identical current (index-anchored) behavior, with no change in results or semantics. The attribute is also self-documenting at the graph level. Crucially, it lets the prefill step run on the maskless, offset-aware Flash path, which removes the need for the per-layer phase-split If and the MEA fallback.
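Here is a sketch of the reference semantics Option A would add, in Python. The attribute name and its two values come from the proposal. The function causal_offsets is hypothetical, assumes is_causal=1, and returns per-batch offsets for the rule j <= t + offset[b].

```python
from typing import List, Optional, Sequence

def causal_offsets(batch: int, s_q: int,
                   nonpad_kv_seqlen: Optional[Sequence[int]] = None,
                   causal_alignment: str = "index_anchored") -> List[int]:
    """Per-batch causal offsets under the proposed attribute (is_causal=1 assumed)."""
    if causal_alignment == "index_anchored":
        return [0] * batch  # default: today's behavior, unchanged
    if causal_alignment == "position_anchored":
        if nonpad_kv_seqlen is None:
            raise ValueError("position_anchored requires nonpad_kv_seqlen; "
                             "offset is undefined without it")
        if len(nonpad_kv_seqlen) != batch:
            raise ValueError("nonpad_kv_seqlen must have one entry per batch")
        return [n - s_q for n in nonpad_kv_seqlen]
    raise ValueError(f"unknown causal_alignment: {causal_alignment!r}")

def key_is_valid(j: int, t: int, offset: int) -> bool:
    # Shared causal test; only the offset differs between the two modes.
    return j <= t + offset

print(causal_offsets(2, 1, [5, 3]))                       # [0, 0]
print(causal_offsets(2, 1, [5, 3], "position_anchored"))  # [4, 2]
```

Because the default path ignores nonpad_kv_seqlen entirely, a model that never sets the attribute cannot observe a different offset.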
Option B (alternative): redefine the existing combination
Redefine is_causal=1 + nonpad_kv_seqlen present + no past_key to mean the offset-aware (position-anchored) rule.
- Pro: no new attribute; the cleanest possible graph.
- Con: it overloads is_causal. The masking rule becomes context-dependent on whether nonpad_kv_seqlen / past_key are present, which is harder to document and reason about. It is also not a pure clarification at the conformance level: for that exact input combination, the reference-defined result changes from the index-anchored result to the position-anchored (offset-aware) result.
Honest framing: no real-world consumer breaks today, because (1) ORT already rejects that combination (so no model relies on its result), and (2) the index-anchored result for that combination is autoregressively useless anyway. But this is a change to the spec-defined result for a defined input combination, and the proposal should state that plainly rather than calling it a pure clarification.
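The context dependence can be made concrete with an illustrative Python model of how the spec text would select a rule under Option B. The function is hypothetical and models only the spec text, not ORT's runtime behavior on the past_key path. Under this model the same is_causal=1 means different things depending on which optional inputs are wired.

```python
def option_b_rule(is_causal: bool, has_nonpad_kv_seqlen: bool, has_past_key: bool) -> str:
    """Which causal rule the spec text would define under Option B (illustrative)."""
    if not is_causal:
        return "none"
    if has_nonpad_kv_seqlen and not has_past_key:
        return "position_anchored"  # the one combination Option B redefines
    return "index_anchored"         # every other is_causal=1 combination keeps the current text

print(option_b_rule(True, True, False))   # position_anchored
print(option_b_rule(True, False, False))  # index_anchored
```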
Recommendation: Option A. The zero backward-compat surface and self-documenting nature outweigh the single extra attribute; Option B's is_causal overloading reintroduces exactly the kind of context-dependent masking ambiguity this proposal aims to remove.
Scope & boundaries
- nonpad_kv_seqlen expresses only contiguous, right-aligned validity. It cannot express general or non-contiguous masking; those cases still require an explicit attn_mask. This proposal does not change that.
- Interaction of nonpad_kv_seqlen with attn_mask must be specified. The spec must state precisely how an offset-aware causal frontier composes with a user-supplied attn_mask (e.g. intersection of validities) so the combination is not left under-defined.
- The per-batch offset formula must be pinned in the spec (offset[b] = nonpad_kv_seqlen[b] - S_q). Leaving it implicit would reintroduce the same under-specification that this proposal is trying to eliminate.
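One candidate composition rule, intersection of validities, is sketched below for a single batch. It has a boolean form and an equivalent additive form where invalid keys get -inf. This is only an illustration of the "e.g." above, not a decided semantic. The function names are hypothetical.

```python
import math
from typing import List

def compose_bool(s_q: int, nonpad_kv_seqlen: int,
                 user_mask: List[List[bool]]) -> List[List[bool]]:
    # Key j is kept only if it is causally valid AND allowed by the user mask.
    offset = nonpad_kv_seqlen - s_q
    return [[(j <= t + offset) and user_mask[t][j] for j in range(nonpad_kv_seqlen)]
            for t in range(s_q)]

def compose_additive(s_q: int, nonpad_kv_seqlen: int,
                     user_bias: List[List[float]]) -> List[List[float]]:
    # Additive form: causally invalid keys get -inf; valid keys keep the user bias.
    offset = nonpad_kv_seqlen - s_q
    return [[user_bias[t][j] if j <= t + offset else -math.inf
             for j in range(nonpad_kv_seqlen)]
            for t in range(s_q)]

# User mask hides key 0 for both queries; the causal frontier (offset 2) hides key 3 for t=0.
print(compose_bool(2, 4, [[False, True, True, True]] * 2))
# [[False, True, True, False], [False, True, True, True]]
```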
Reconciling the existing past_key + is_causal internal-cache path
Confirmed: the standard Attention past_key + is_causal=1 internal-cache decode (Case 1) already applies the offset-aware (position-correct) rule at runtime today. On Flash Path 2 (attention.cc:339-456): seqlens_k = past + kv = total (attention.cc:422-425) and is_causal is passed through to mha_fwd_kvcache (attention.cc:449), so actual_seqlen_k = total and offset = total - S_q = past. The decode step therefore attends keys 0..past correctly — i.e. position-anchored — even though the Attention-24 spec text defines the masking as the "upper left causal bias" unconditionally (its index-anchored, offset == 0 definition). This is a latent spec-vs-implementation inconsistency that ships today.
This is not merely a code-reading: it is runtime-verified on ORT 1.27 / CUDAExecutionProvider with the standard opset-24 Attention op (past_key/past_value + is_causal=1). A 2-query decode against a 3-key cache produces output that uniquely matches the offset-aware (position-anchored) reference and excludes both the index-anchored and the unmasked references, identically in fp16 and fp32. So the runtime behavior is unambiguously offset-aware, while the spec text says otherwise.
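The three-reference discrimination idea can be sketched in pure Python. This is a toy reconstruction, not the actual test harness. It assumes S_q = 2 queries against 3 total keys (offset 1), with arbitrary float64 inputs and a single head. The point is that the three references are pairwise distinguishable, so an observed output can match at most one of them. The observed stand-in below is simply the position-anchored reference; in a real check it would come from the runtime.

```python
import math
from typing import Callable, Dict, List

Matrix = List[List[float]]

def attention(q: Matrix, k: Matrix, v: Matrix, valid: Callable[[int, int], bool]) -> Matrix:
    """Single-head scaled dot-product attention with a boolean validity predicate."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for t, qt in enumerate(q):
        scores = [sum(a * b for a, b in zip(qt, kj)) * scale if valid(t, j) else -math.inf
                  for j, kj in enumerate(k)]
        m = max(scores)  # finite: key 0 is valid for every query under all three rules
        w = [math.exp(s - m) for s in scores]  # masked keys get exp(-inf) == 0.0
        z = sum(w)
        out.append([sum(wj * vj[d] for wj, vj in zip(w, v)) / z for d in range(len(v[0]))])
    return out

def matching_references(observed: Matrix, refs: Dict[str, Matrix], tol: float = 1e-3) -> List[str]:
    """Names of the references within tol (max abs diff) of the observed output."""
    def max_abs_diff(a: Matrix, b: Matrix) -> float:
        return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))
    return [name for name, ref in refs.items() if max_abs_diff(observed, ref) <= tol]

s_q, s_k = 2, 3
offset = s_k - s_q  # 1
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.5], [0.2, 1.0], [-0.5, 0.3]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
refs = {
    "position_anchored": attention(q, k, v, lambda t, j: j <= t + offset),
    "index_anchored": attention(q, k, v, lambda t, j: j <= t),
    "unmasked": attention(q, k, v, lambda t, j: True),
}
observed = refs["position_anchored"]  # stand-in for a kernel output
print(matching_references(observed, refs))  # ['position_anchored']
```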
ORT's own guard comment (attention.cc:1309-1313) is worth a careful read here. It notes that the index-vs-position masking difference "only manifests when S_q != S_kv with NO past." That is defensible as a statement about where the problem surfaces as an unhandled case — the guard rejects exactly the no-past combination — but it is easy to read as implying the divergence is unique to the no-past case. Mathematically the index-vs-position divergence is also present in the with-past decode path (S_q = 1 < total); there it is simply benign and already handled, because Path 2 supplies actual_seqlen_k = total and the kernel computes the offset-aware result. So the comment is best read as precise about handling but imprecise about existence: the divergence is general to S_q < valid_kv_len, and the no-past case is the one that is currently rejected rather than silently mis-masked.
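This distinction can be seen in a simplified Python model of the cited guard predicates. The model transcribes causal_cross_no_past (:1315-1317), the NOT_IMPLEMENTED rejection (:1332-1340), and only the two Flash-eligibility terms quoted above (:1345-1356). It omits every other Flash condition, so "flash_candidate" means "not excluded by these terms". The function names are hypothetical. The with-past decode has S_q < total but is never flagged, while the no-past external-cache decode with the same shapes is rejected.

```python
def causal_cross_no_past(is_causal: bool, q_len: int, total_len: int, past_len: int) -> bool:
    # Transcription of the predicate cited from attention.cc:1315-1317.
    return is_causal and q_len != total_len and past_len == 0

def route(is_causal: bool, q_len: int, total_len: int, past_len: int,
          has_nonpad: bool, has_attn_mask: bool) -> str:
    ccnp = causal_cross_no_past(is_causal, q_len, total_len, past_len)
    if ccnp and has_nonpad:
        return "NOT_IMPLEMENTED"   # :1332-1340
    if not ccnp and not has_attn_mask:
        return "flash_candidate"   # :1345-1356 terms only; other Flash conditions omitted
    return "non_flash"

# Internal-cache decode: S_q = 1 < total = 5, but past > 0, so it is not flagged.
print(route(True, 1, 5, 4, has_nonpad=False, has_attn_mask=False))  # flash_candidate
# External-cache decode: same shapes, no past_key, nonpad_kv_seqlen present.
print(route(True, 1, 5, 0, has_nonpad=True, has_attn_mask=False))   # NOT_IMPLEMENTED
# Exporter workaround prefill: explicit positional attn_mask instead of is_causal.
print(route(False, 4, 16, 0, has_nonpad=False, has_attn_mask=True)) # non_flash
```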
This reframes the proposal's key selling point. The spec change largely codifies behavior ORT already ships for the internal-cache path, and extends the same offset-aware semantics to the external-cache (nonpad) path. Accordingly, the proposal explicitly offers to unify all three cache-bearing causal cases under one rule:
Whenever a cache makes total_kv > S_q, is_causal is anchored to true positions using offset = total_kv - S_q (per batch).
This covers Case 1 (past_key internal cache — already offset-correct in code), Case 2 (nonpad_kv_seqlen external cache — the case this proposal unblocks), and contrib GQA (already offset-correct via seqlens_k).
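The unified rule can be sketched as three hypothetical helpers, one per case. Each maps its own input convention to offset = total_kv - S_q. Case 1 assumes total = past + new kv with new kv length equal to S_q (self-attention). Contrib GQA uses the index convention, so one is added back before subtracting S_q. For the same logical cache state, all three agree.

```python
def offset_past_key(past_len: int, new_kv_len: int, s_q: int) -> int:
    return (past_len + new_kv_len) - s_q   # Case 1: total = past + kv

def offset_nonpad(nonpad_kv_seqlen: int, s_q: int) -> int:
    return nonpad_kv_seqlen - s_q          # Case 2: count convention

def offset_gqa(seqlens_k: int, s_q: int) -> int:
    return (seqlens_k + 1) - s_q           # contrib GQA: index convention

def all_offsets(past: int, s_q: int):
    total_kv = past + s_q
    return (offset_past_key(past, s_q, s_q),
            offset_nonpad(total_kv, s_q),
            offset_gqa(total_kv - 1, s_q))

print(all_offsets(4, 1))  # decode after 4 cached tokens: (4, 4, 4)
print(all_offsets(6, 3))  # 3-token chunked prefill after 6 cached tokens: (6, 6, 6)
```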
Scope is a SIG decision. Minimal: external-cache (nonpad) only — unblock Case 2, leave Case 1's latent inconsistency untouched. Broader: reconcile all three paths and fix the Attention-24 spec text so it no longer describes the masking as unconditionally index-anchored. We recommend the broader reconciliation (the inconsistency already exists; codifying it removes a real footgun), but defer the final scope to the SIG.
Non-goals / honest caveat
- This is not a performance proposal. The per-layer phase-split If it removes has been measured at ~0 GPU cost; eliminating it does not change runtime latency. Anyone reading this expecting a speedup should stop here.
- The value is expressiveness and graph cleanliness: letting the standard Attention op express maskless offset-aware Flash decode on external KV caches without reaching for the contrib GroupQueryAttention op or emitting a per-layer phase-split If / MEA workaround.
- This proposal does not add general masking capability, does not change the masking of existing models, and does not (under Option A) alter any result for any model that doesn't opt in.