
zml/moe: Mosaic TPU MoE backend. #571

Open
loupicaaa wants to merge 8 commits into master from louis/moe-tpu-backend

@loupicaaa loupicaaa commented Jun 3, 2026

This PR adds the TPU backend for MoE. It uses the GMM (grouped matrix multiplication) expert-parallelism kernel to perform the grouped GEMMs on the gate_up and down weights. These weights must have a specific shape for the kernel to run optimally.
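Conceptually, a grouped GEMM multiplies each contiguous group of token rows by the weight matrix of the expert those tokens were routed to. The following is a naive plain-Zig reference of those semantics. It is only a sketch. It assumes tokens are already sorted by expert id and that all buffers are row-major. `groupedMatmul` and its parameters are hypothetical names, not the Mosaic kernel or the zml API.

```zig
const std = @import("std");

/// Naive reference of grouped-matmul semantics (hypothetical, not the TPU kernel).
///   lhs:         [m][k] row-major, rows already sorted by expert id
///   rhs:         [num_experts][k][n] row-major, one weight matrix per expert
///   group_sizes: number of consecutive lhs rows routed to each expert (sums to m)
///   out:         [m][n] row-major result
pub fn groupedMatmul(
    lhs: []const f32,
    rhs: []const f32,
    group_sizes: []const usize,
    k: usize,
    n: usize,
    out: []f32,
) void {
    var row: usize = 0; // first lhs row of the current expert's group
    for (group_sizes, 0..) |size, expert| {
        // Slice out this expert's [k][n] weight matrix.
        const w = rhs[expert * k * n .. (expert + 1) * k * n];
        var r: usize = row;
        while (r < row + size) : (r += 1) {
            var j: usize = 0;
            while (j < n) : (j += 1) {
                var acc: f32 = 0;
                var p: usize = 0;
                while (p < k) : (p += 1) {
                    acc += lhs[r * k + p] * w[p * n + j];
                }
                out[r * n + j] = acc;
            }
        }
        row += size;
    }
}
```

A real kernel tiles these loops and skips empty groups. The reference only fixes what the output should be, which makes it useful as a test oracle.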

This backend is currently optimized for small batches. A specialized Mosaic TPU kernel will be introduced to handle larger ones.

This implementation comes with a few fixes concerning TPU sharding and the Qwen 3.5 MoE model, which uses the backend:

  • TPU sharding: PJRT reports TPU device coordinates in a max-rank form, which breaks folding. To avoid this, we compute the rank without counting coordinate axes that are 0.
  • Linear attention: the Qwen 3.5 model's linear attention was reading padding tokens, so we now mask them out.
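The rank fix in the first bullet can be sketched as follows. This sketch reads "not counting where coords = 0" as follows: a coordinate axis contributes to the rank only if at least one device has a nonzero coordinate on it. That reading is an assumption, not confirmed by the PR. `effectiveRank` and the flat `[num_devices][max_rank]` layout are also hypothetical.

```zig
const std = @import("std");

/// Hypothetical helper: effective mesh rank from PJRT-style max-rank coordinates.
/// coords is row-major [num_devices][max_rank]. An axis counts toward the rank
/// only if some device has a nonzero coordinate on it (assumed interpretation).
pub fn effectiveRank(coords: []const u32, max_rank: usize) usize {
    if (max_rank == 0) return 0;
    const num_devices = coords.len / max_rank;
    var rank: usize = 0;
    var axis: usize = 0;
    while (axis < max_rank) : (axis += 1) {
        var d: usize = 0;
        while (d < num_devices) : (d += 1) {
            if (coords[d * max_rank + axis] != 0) {
                rank += 1; // axis varies across devices: it is a real mesh axis
                break;
            }
        }
    }
    return rank;
}
```

For example, four devices laid out 2x2 but reported with a max rank of 4 have two all-zero trailing axes, so they yield rank 2.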

N.B.: The KV-head repetition required to satisfy the sharding will be handled in a more generic way in a follow-up PR.

@loupicaaa loupicaaa requested review from Corendos and removed request for Corendos June 3, 2026 17:44
const valid_mask = zml.Tensor.arange(.{ .end = x.dim(.s) }, .i64)
.withTags(.{.s})
.cmp(.LT, token_index.convert(.i64));
recurrent_value = valid_mask.broad(value.shape()).select(recurrent_value, zml.Tensor.zeroes(recurrent_value.shape()));

You can drop the zml.Tensor prefix here to make it a bit easier to read.

.cmp(.LT, token_index.convert(.i64));
recurrent_value = valid_mask.broad(value.shape()).select(recurrent_value, zml.Tensor.zeroes(recurrent_value.shape()));
recurrent_beta = valid_mask.broad(beta.shape()).select(recurrent_beta, zml.Tensor.zeroes(recurrent_beta.shape()));
g = valid_mask.broad(g.shape()).select(g, zml.Tensor.zeroes(g.shape()));

This is a mouthful; maybe we should have a helper function that masks along a given axis up to some upper bound.
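The masking in the diff keeps positions whose index along `.s` is below `token_index` and replaces the rest with zeros. Below is a plain-Zig sketch of the axis-bounded helper suggested above. It operates on a 2-D row-major buffer rather than a zml tensor. `maskAlongAxis` and `Axis` are hypothetical names, not part of zml.

```zig
const std = @import("std");

/// Which dimension of a row-major [rows][cols] buffer to mask along (hypothetical).
pub const Axis = enum { rows, cols };

/// Hypothetical helper: keeps elements whose index along `axis` is < `bound`
/// and overwrites the rest with zero, i.e. select(index < bound, x, zeros).
pub fn maskAlongAxis(data: []f32, rows: usize, cols: usize, axis: Axis, bound: usize) void {
    std.debug.assert(data.len == rows * cols);
    var r: usize = 0;
    while (r < rows) : (r += 1) {
        var c: usize = 0;
        while (c < cols) : (c += 1) {
            const idx = switch (axis) {
                .rows => r,
                .cols => c,
            };
            if (idx >= bound) data[r * cols + c] = 0;
        }
    }
}
```

With a helper like this, each of the three select lines in the diff becomes a single call taking the sequence axis and `token_index` as the bound.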

@loupicaaa loupicaaa force-pushed the louis/moe-tpu-backend branch from e8553d7 to 3f279cc June 9, 2026 16:24
@loupicaaa loupicaaa force-pushed the louis/moe-tpu-backend branch from 43e5679 to f9fb9e8 June 10, 2026 13:37
experts: zml.Sharding,

pub fn init(platform: *zml.Platform) !Shardings {
if (platform.target == .tpu) {

Use a switch over platform.target without an else prong.
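The reviewer's suggestion relies on Zig's exhaustive `switch`. When a switch has no `else` prong, adding a new variant to the enum makes every such switch fail to compile until the new case is handled. In the sketch below, `Target` stands in for the platform target enum and may not match zml's real variants. `ExpertSharding` and the per-target choices are hypothetical.

```zig
const std = @import("std");

/// Hypothetical stand-in for the platform target enum; zml's real variants may differ.
const Target = enum { cpu, cuda, rocm, tpu };

/// Hypothetical sharding choice for the expert weights.
const ExpertSharding = enum { replicated, expert_parallel };

fn expertSharding(target: Target) ExpertSharding {
    // No `else` prong: a new Target variant is a compile error here,
    // forcing each new backend to pick its expert sharding explicitly.
    return switch (target) {
        .tpu => .expert_parallel,
        .cpu, .cuda, .rocm => .replicated,
    };
}
```

An `if (platform.target == .tpu)` check silently sends every other target, including future ones, down the same path. The exhaustive switch makes that choice visible at each call site.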

4 participants