feat(attention): add fused SDPA graph.Node (T3.1a) #838

Merged: dndungu merged 1 commit into main from wave-22-task-T3.1a on Apr 29, 2026

@dndungu commented Apr 29, 2026

Summary

Wraps the existing SDPA in layers/attention/scaled_dot_product_attention.go in a graph.Node[T] (FusedSDPA) so that consumers (Wolf cross-attention, T3.1b) can compose it via graph.Builder without duplicating the attention math.

  • New layers/attention/fused_sdpa_node.go: FusedSDPA[T tensor.Numeric] implementing graph.Node[T]. OpType="FusedSDPA", attributes head_dim + causal, no parameters, cached OutputShape().
  • Forward accepts (Q,K,V) or (Q,K,V,mask); delegates to inner SDPA.
  • Backward delegates to inner SDPA and appends a nil grad slot for the mask input when present, so input/grad indexing stays aligned.
  • Options: WithFusedSDPABidirectional, WithFusedSDPAHeadCounts. Causal default mirrors existing SDPA convention (causal-on unless bidirectional).
  • New layers/attention/fused_sdpa_node_test.go covers fp32/fp64 x {causal, bidirectional, masked} forward+backward equivalence vs the unfused ScaledDotProductAttention chain. Tolerances: forward 1e-6 (fp32) / 1e-12 (fp64); backward 1e-5 (fp32) / 1e-10 (fp64). Uses only the stdlib testing package.
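
The delegation and grad-slot alignment described above can be sketched roughly as follows. This is an illustration only, not the code in fused_sdpa_node.go: `Tensor`, `innerSDPA`, `fusedNode`, `newFusedNode`, and `withBidirectional` are hypothetical stand-ins, the real node is generic over tensor.Numeric, and the placeholder Forward does no attention math.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for the project's tensor type.
type Tensor struct{ Data []float64 }

// innerSDPA is a hypothetical stand-in for the existing unfused SDPA layer.
type innerSDPA struct{ causal bool }

// Forward is a placeholder: it copies V instead of computing attention.
func (s *innerSDPA) Forward(q, k, v, mask *Tensor) (*Tensor, error) {
	out := make([]float64, len(v.Data))
	copy(out, v.Data)
	return &Tensor{Data: out}, nil
}

// Backward returns placeholder gradients for Q, K, and V only (no mask grad).
func (s *innerSDPA) Backward(grad *Tensor) ([]*Tensor, error) {
	return []*Tensor{grad, grad, grad}, nil
}

// fusedNode wraps innerSDPA and remembers whether a mask was supplied.
type fusedNode struct {
	inner   *innerSDPA
	hasMask bool
}

// option follows the functional-options pattern.
type option func(*fusedNode)

// withBidirectional (hypothetical name) turns off the causal default.
func withBidirectional() option {
	return func(n *fusedNode) { n.inner.causal = false }
}

// newFusedNode builds a node that is causal unless an option says otherwise.
func newFusedNode(opts ...option) *fusedNode {
	n := &fusedNode{inner: &innerSDPA{causal: true}}
	for _, opt := range opts {
		opt(n)
	}
	return n
}

// Forward accepts (Q, K, V) or (Q, K, V, mask) and delegates to the inner layer.
func (n *fusedNode) Forward(inputs ...*Tensor) (*Tensor, error) {
	if len(inputs) != 3 && len(inputs) != 4 {
		return nil, fmt.Errorf("fusedNode: want 3 or 4 inputs, got %d", len(inputs))
	}
	var mask *Tensor
	n.hasMask = len(inputs) == 4
	if n.hasMask {
		mask = inputs[3]
	}
	return n.inner.Forward(inputs[0], inputs[1], inputs[2], mask)
}

// Backward delegates, then appends a nil slot for the mask so that
// the number of gradients matches the number of forward inputs.
func (n *fusedNode) Backward(grad *Tensor) ([]*Tensor, error) {
	grads, err := n.inner.Backward(grad)
	if err != nil {
		return nil, err
	}
	if n.hasMask {
		grads = append(grads, nil)
	}
	return grads, nil
}

func main() {
	x := &Tensor{Data: []float64{1, 2}}
	n := newFusedNode(withBidirectional())
	if _, err := n.Forward(x, x, x, x); err != nil {
		fmt.Println("forward error:", err)
		return
	}
	grads, _ := n.Backward(x)
	fmt.Println("grad slots:", len(grads), "mask grad is nil:", grads[3] == nil)
}
```

Appending the nil slot lets a caller index gradients by input position without special-casing the mask, which is the alignment property the Backward bullet describes.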

Downstream: T3.1b (Wolf swap) consumes this node.

Test plan

  • go build ./...
  • go vet ./...
  • go test ./layers/attention/... -race -count=1 (1.29s, all green incl. new FusedSDPA cases)
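
The equivalence cases run by the go test command above compare fused and unfused results within a per-precision tolerance. A minimal sketch of that kind of check, assuming plain float64 slices rather than the project's tensors; `maxAbsDiff` and `withinTol` are hypothetical helper names, not functions from the repository.

```go
package main

import (
	"fmt"
	"math"
)

// maxAbsDiff returns the largest element-wise absolute difference.
// It reports an error on length mismatch and +Inf if any difference is NaN.
func maxAbsDiff(got, want []float64) (float64, error) {
	if len(got) != len(want) {
		return 0, fmt.Errorf("length mismatch: got %d, want %d", len(got), len(want))
	}
	worst := 0.0
	for i := range got {
		d := math.Abs(got[i] - want[i])
		if math.IsNaN(d) {
			return math.Inf(1), nil
		}
		worst = math.Max(worst, d)
	}
	return worst, nil
}

// withinTol reports whether got matches want to within tol everywhere.
func withinTol(got, want []float64, tol float64) bool {
	d, err := maxAbsDiff(got, want)
	return err == nil && d <= tol
}

func main() {
	// Illustrative values only; 1e-12 is the fp64 forward tolerance from the PR.
	fused := []float64{0.25, 0.75}
	unfused := []float64{0.25, 0.75 + 1e-13}
	fmt.Println("fp64 forward within tolerance:", withinTol(fused, unfused, 1e-12))
}
```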

@dndungu merged commit bfcd7a3 into main on Apr 29, 2026
7 checks passed
@dndungu deleted the wave-22-task-T3.1a branch on April 29, 2026 at 03:58