
platforms/neuron: introduce NKI stack with SDPA kernel#528

Open
hugomano wants to merge 1 commit into master from hugomano/neuron/nki-api

@hugomano hugomano commented May 11, 2026

Contributor

Summary

This PR introduces the Neuron NKI stack into ZML and wires it into the LLM attention path.

The main goal is to make Neuron attention run through NKI kernels.
It adds the compiler/runtime plumbing needed to embed NKI kernels as Neuron custom-native calls, then uses that path for SDPA-style prefill/decode attention. Paged Attention will come later.

On inf2.8xlarge, it reaches:

  • 25 tok/s for Llama 3.1 8B BF16
  • 63.5 tok/s for LFM2.5-1.2B-Instruct BF16
  • 19 tok/s for Qwen3.5-9B BF16

Neuron NKI compiler integration

Adds a new zml.ops.neuronNki(...) API that lowers a Zig tensor operation to a StableHLO custom call with:

  • call_target_name = "AwsNeuronCustomNativeKernel"
  • generated operand/result layouts
  • input/output tensor signatures derived from ZML shapes and dtypes
  • a compiled NKI backend_config attached to the custom call
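
As a rough illustration of the information such a custom call carries, here is a minimal Python sketch. It is not the ZML implementation: `TensorSig`, `row_major_layout`, and `build_custom_call` are hypothetical names, and the dense row-major (minor-to-major) layout is an assumption about what "generated layouts" means here.

```python
from dataclasses import dataclass


@dataclass
class TensorSig:
    """Hypothetical stand-in for a ZML shape + dtype pair."""
    shape: tuple[int, ...]
    dtype: str  # e.g. "bf16"


def row_major_layout(rank):
    # Minor-to-major order of a dense row-major tensor:
    # the last dimension is the most minor one.
    return list(range(rank - 1, -1, -1))


def build_custom_call(inputs, outputs, backend_config_b64):
    # Collect the attributes listed above into a plain dict.
    return {
        "call_target_name": "AwsNeuronCustomNativeKernel",
        "operand_layouts": [row_major_layout(len(t.shape)) for t in inputs],
        "result_layouts": [row_major_layout(len(t.shape)) for t in outputs],
        "operands": [(t.shape, t.dtype) for t in inputs],
        "results": [(t.shape, t.dtype) for t in outputs],
        "backend_config": backend_config_b64,  # produced by the NKI compiler
    }
```

The real lowering emits MLIR attributes rather than a dict; the sketch only shows which pieces are derived from shapes and which come from the compiled kernel.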

The compiler bridge lives in:

  • platforms/neuron/nki_kernel.zig
  • platforms/neuron/nki_kernel_compiler.py
  • platforms/neuron/nki_compiler_launcher.zig
  • platforms/neuron/python_launcher.zig

At graph construction time, ZML resolves the NKI source file from Bazel runfiles, writes a JSON compile request, invokes the sandboxed nki-cc launcher, and reads back the base64 backend config expected by Neuron’s custom-native kernel path.
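
The request/response round trip described above could look roughly like the following Python sketch. The JSON field names, the two-argument launcher convention, and `FAKE_LAUNCHER` are all assumptions made for illustration; the actual nki-cc launcher interface is defined in the files listed above.

```python
import base64
import json
import subprocess
import sys
import tempfile
from pathlib import Path

# Stand-in for the sandboxed launcher (hypothetical): it ignores the request
# and writes a fixed base64 payload to the output path given as argv[2].
FAKE_LAUNCHER = [
    sys.executable,
    "-c",
    "import sys, base64; "
    "open(sys.argv[2], 'w').write(base64.b64encode(b'cfg').decode())",
]


def compile_nki_kernel(launcher, source_path, entrypoint, target):
    # Field names below are illustrative, not the real request schema.
    request = {
        "source": str(source_path),
        "entrypoint": entrypoint,
        "target": target,
    }
    with tempfile.TemporaryDirectory() as tmp:
        req_path = Path(tmp) / "request.json"
        out_path = Path(tmp) / "backend_config.b64"
        req_path.write_text(json.dumps(request))
        # Invoke the launcher; a non-zero exit raises CalledProcessError.
        subprocess.run([*launcher, str(req_path), str(out_path)], check=True)
        encoded = out_path.read_text().strip()
    # Fail early if the launcher did not return valid base64.
    base64.b64decode(encoded, validate=True)
    return encoded
```

Validating the payload before attaching it to the custom call keeps a broken compile from surfacing later as an opaque runtime failure.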

The compiler target is inferred from the actual Neuron instance through nrt_get_instance_info, mapping supported families to targets.
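
A minimal sketch of that family-to-target mapping, in Python for brevity. The table contents are placeholders, not the mapping used by this PR, and the instance-type string stands in for whatever nrt_get_instance_info actually reports.

```python
# Placeholder table (assumption): the real supported families and target
# names are defined in platforms/neuron, not here.
FAMILY_TO_TARGET = {
    "inf2": "inf2",
    "trn1": "trn1",
    "trn2": "trn2",
}


def target_for_instance(instance_type):
    # "inf2.8xlarge" -> family "inf2"
    family = instance_type.split(".", 1)[0].lower()
    try:
        return FAMILY_TO_TARGET[family]
    except KeyError:
        raise ValueError(
            f"unsupported Neuron instance family: {family!r}"
        ) from None
```

Raising on unknown families, instead of falling back to a default target, avoids silently compiling a kernel for the wrong hardware.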

Neuron runtime

The Neuron platform packaging is updated.

The previous setup carried a local libneuronxla.zig shim and initialized Python directly from the PJRT proxy so that ZML could expose the compiler pieces Neuron expects. With libneuronxla v0.3, the AWS package owns more of that compiler/runtime surface and no longer depends on Python, so the sandbox is rebuilt around the packaged libneuronxla payload instead of ZML’s local replacement.

  • adds nki>=0.4.0,<0.5
  • includes aws-neuronx-tools, libgcc_s1, and libstdcpp6, which the shared library requires
  • exposes sandboxed neuronx-cc and nki-cc binaries
  • moves Python initialization out of the PJRT shim and into reusable launcher code
  • configures Neuron compiler/runtime env from the PJRT wrapper, including NEURON_CC_FLAGS, sandbox PATH etc.
  • uses Shardy instead of GSPMD
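
The environment setup mentioned in the list above can be pictured with this small Python sketch. `neuron_env` is a hypothetical helper, and the flag values passed to it are left to the caller; only the NEURON_CC_FLAGS and PATH variable names come from this PR.

```python
import os


def neuron_env(sandbox_bin, cc_flags, base_env=None):
    # Start from the given environment, or the current process environment.
    env = dict(os.environ if base_env is None else base_env)
    # Compiler flags are passed as one space-separated string.
    env["NEURON_CC_FLAGS"] = " ".join(cc_flags)
    # Put the sandbox binaries first so neuronx-cc / nki-cc resolve there.
    existing = env.get("PATH", "")
    env["PATH"] = sandbox_bin + (os.pathsep + existing if existing else "")
    return env
```

Prepending rather than replacing PATH keeps system tools reachable while ensuring the sandboxed compilers take precedence.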

NKI attention backend

Adds a new attention backend:

  • zml.attention.attention.Backend.nki
  • selected automatically for .neuron
  • exposed in examples/llm --backend=nki
  • wired into Llama metadata/session handling
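
For readers unfamiliar with SDPA, the computation the backend targets is softmax(QKᵀ/√d)·V. The pure-Python reference below is a sketch of those semantics for a single head, useful for checking a kernel's output on tiny inputs. It is not the NKI kernel; the causal alignment for decode (queries aligned with the last keys) is an assumption about how prefill and decode share one definition.

```python
import math


def sdpa(q, k, v, causal=True):
    """Single-head attention on nested lists: q [Tq][d], k [Tk][d], v [Tk][dv]."""
    scale = 1.0 / math.sqrt(len(q[0]))
    # With a KV cache, the Tq queries correspond to the last Tq key positions.
    offset = len(k) - len(q)
    out = []
    for i, qi in enumerate(q):
        last = i + offset if causal else len(k) - 1
        scores = [
            scale * sum(a * b for a, b in zip(qi, kj)) for kj in k[: last + 1]
        ]
        # Numerically stable softmax over the visible keys.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # zip stops at len(weights), so masked-out values are never read.
        out.append([
            sum(w * vj[c] for w, vj in zip(weights, v)) / total
            for c in range(len(v[0]))
        ])
    return out
```

Prefill corresponds to Tq == Tk, and decode to Tq == 1 with a longer key/value history.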

⚠️ Paged Attention currently has an open issue; its kernels will be added in another PR.

StableHLO/lowering compatibility fixes for Neuron

I added shims in a few Neuron-sensitive lowering paths:

  • keep scalar integer broadcasts as scalar broadcasts before StableHLO emission
  • preserve gather fill semantics by replacing inactive sentinel lanes before backend indirect-memory lowering
  • preserve scatter drop semantics by replacing inactive sentinel indices and masking inactive additive updates to zero
  • keep downstream resharding after the scalar gather fast path, preventing Neuron SPMD from hoisting an all-gather before a narrow dynamic slice
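
The gather and scatter shims in the list above follow a common pattern, sketched here in pure Python on 1-D data. This is an illustration of the idea, not the code in zml.ops.LoweringCompatibility: it assumes "inactive" means an out-of-range index, that index 0 is always a safe stand-in (non-empty data), and that scatter combines updates by addition.

```python
def gather_with_fill(data, indices, fill):
    n = len(data)
    active = [0 <= i < n for i in indices]
    # Replace inactive lanes with a valid index so the backend never
    # performs an out-of-bounds indirect read.
    safe = [i if ok else 0 for i, ok in zip(indices, active)]
    gathered = [data[i] for i in safe]
    # Restore fill semantics for the lanes that were inactive.
    return [g if ok else fill for g, ok in zip(gathered, active)]


def scatter_add_with_drop(data, indices, updates):
    n = len(data)
    active = [0 <= i < n for i in indices]
    # Redirect inactive indices to a valid slot...
    safe = [i if ok else 0 for i, ok in zip(indices, active)]
    # ...and zero their updates so the redirected add is a no-op.
    masked = [u if ok else 0 for u, ok in zip(updates, active)]
    out = list(data)
    for i, u in zip(safe, masked):
        out[i] += u
    return out
```

For example, `gather_with_fill([10, 20, 30], [2, -1, 0, 7], 0)` yields `[30, 0, 10, 0]`. Zero-masking only preserves drop semantics for additive updates; an overwriting scatter would need a different strategy.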

These shims are isolated in zml.ops.LoweringCompatibility so model code can continue expressing normal StableHLO semantics. The gather/scatter sentinel handling is related to upstream Neuron issue aws-neuron/aws-neuron-sdk#1335.

For topk, a dedicated branch has also been added until the Neuron team fixes aws-neuron/aws-neuron-sdk#1339. I also asked for an in-graph data movement API: aws-neuron/aws-neuron-sdk#1340

Neuron profiling workflow

Adds a repo-level Neuron profiling workflow with documentation:

  • --config=neuron-profile
  • //tools/neuron:profile
  • //tools/neuron:server
  • //tools/neuron:ingest
  • //tools/neuron:summary-json
  • //tools/neuron:summary-txt
  • //tools/neuron:summary-perfetto

Known follow-ups

  • MoE NKI kernels
  • Test kernel performance and adapt the kernels for trn2 and trn3

Related doc

Base automatically changed from raphael/fix-neuron to master May 12, 2026 07:58
@hugomano hugomano force-pushed the hugomano/neuron/nki-api branch 3 times, most recently from 39e1374 to b3e0185 Compare May 12, 2026 11:45
@hugomano hugomano changed the title platforms/neuron: introduce nki api platforms/neuron: introduce NKI stack with SDPA and paged attentions May 22, 2026
@hugomano hugomano force-pushed the hugomano/neuron/nki-api branch 2 times, most recently from fd8b023 to a2af7bf Compare May 27, 2026 15:00
@hugomano hugomano changed the base branch from master to hugomano/llama-split May 27, 2026 15:00
@hugomano hugomano force-pushed the hugomano/llama-split branch from a9f8f1f to 1969fef Compare May 28, 2026 07:52
@hugomano hugomano force-pushed the hugomano/neuron/nki-api branch from a2af7bf to c4c1b50 Compare May 28, 2026 08:02
Base automatically changed from hugomano/llama-split to master May 28, 2026 08:10
@hugomano hugomano force-pushed the hugomano/neuron/nki-api branch 4 times, most recently from f658bd1 to f1e3bcd Compare June 1, 2026 08:55
@hugomano hugomano changed the title platforms/neuron: introduce NKI stack with SDPA and paged attentions platforms/neuron: introduce NKI stack with SDPA and paged attention Jun 1, 2026
@hugomano hugomano force-pushed the hugomano/neuron/nki-api branch from 0e2a650 to 9282e56 Compare June 2, 2026 12:24
@hugomano hugomano changed the title platforms/neuron: introduce NKI stack with SDPA and paged attention platforms/neuron: introduce NKI stack with SDPA kernel Jun 11, 2026
@hugomano hugomano force-pushed the hugomano/neuron/nki-api branch from 1bebdf2 to 03f638e Compare June 11, 2026 12:44
@hugomano hugomano force-pushed the hugomano/neuron/nki-api branch from 03f638e to 530993c Compare June 11, 2026 12:46
@hugomano hugomano requested a review from Corendos June 11, 2026 16:21

Development

Successfully merging this pull request may close these issues.

Neuron PJRT StableHLO compile rewrites sort+slice sampling path to AwsNeuronTopK with full vocab size as k
