Automatically apply sharding constraints to sharded models #4844

Merged: copybara-service[bot] merged 1 commit into google:main from IvyZX:shard on Sep 16, 2025.

IvyZX (Collaborator) commented on Jul 23, 2025:

Simplify the creation of sharded Flax NNX models.

BREAKING CHANGE: Creating an `nnx.Variable` now requires a mesh context and, if a sharding annotation is provided, automatically runs `jax.lax.with_sharding_constraint` on its value.
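
To make the breaking change concrete, here is a minimal pure-JAX sketch of the per-variable step that is now automatic. `constrain_if_annotated` and `init_kernel` are hypothetical names, not Flax APIs; the sketch assumes an annotation is a tuple of mesh axis names (or `None`) and builds an explicit `NamedSharding` instead of relying on a context mesh. It uses a single-device 1x1 mesh so it runs anywhere.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Single-device 1x1 mesh so the sketch runs anywhere; on real hardware you
# would use something like jax.make_mesh((2, 4), ("data", "model")).
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))


def constrain_if_annotated(value, sharding_names, mesh):
  """Hypothetical helper: the constraint step applied to an annotated value."""
  if sharding_names is None:
    return value  # No annotation: the value is returned unchanged.
  sharding = NamedSharding(mesh, P(*sharding_names))
  return jax.lax.with_sharding_constraint(value, sharding)


@jax.jit
def init_kernel(key):
  # Hypothetical initializer for a (4, 8) kernel annotated as (None, "model").
  kernel = jax.random.normal(key, (4, 8))
  return constrain_if_annotated(kernel, (None, "model"), mesh)


kernel = init_kernel(jax.random.key(0))
# An unannotated value passes through untouched.
bias = constrain_if_annotated(jnp.zeros((8,)), None, mesh)
print(kernel.shape, kernel.sharding)
```

On a real multi-device mesh, the same `P(None, "model")` spec would split the kernel's second dimension across the devices of the `"model"` axis.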

To create a sharded model, users now only need to do this:

```python
import jax

mesh = jax.make_mesh((2, 4), ("data", "model"))  # A 2x4 mesh needs 8 devices.
with jax.set_mesh(mesh):
  model = YourModelWithShardingAnnotations()  # Placeholder for your own nnx.Module.
```

instead of the current boilerplate combination of `nnx.jit`, `nnx.get_partition_spec`, `jax.lax.with_sharding_constraint`, and `nnx.update`:

```python
import jax
from flax import nnx

@nnx.jit
def create_sharded_model():
  model = YourModelWithShardingAnnotations()  # Unsharded at this moment.
  state = nnx.state(model)                    # The model's state, a pure pytree.
  pspecs = nnx.get_partition_spec(state)      # Extract the PartitionSpecs from the annotations.
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)            # The model is sharded now!
  return model

with mesh:  # The same `mesh` as above.
  sharded_model = create_sharded_model()
```
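
The boilerplate above boils down to one pytree-wide `jax.lax.with_sharding_constraint` call inside a jitted initializer. Below is a pure-JAX sketch of that pattern, assuming a plain dict stands in for `nnx.state(model)` and a hand-written dict of `PartitionSpec`s stands in for `nnx.get_partition_spec(state)`. `init_state` and `create_sharded_state` are hypothetical names, and the mesh is again a single-device 1x1 mesh so the sketch runs anywhere.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))

# Stand-in for nnx.get_partition_spec(state): one PartitionSpec per leaf.
pspecs = {"kernel": P(None, "model"), "bias": P("model")}
# Bind each spec to the mesh explicitly instead of relying on `with mesh:`.
shardings = {name: NamedSharding(mesh, spec) for name, spec in pspecs.items()}


def init_state(key):
  """Hypothetical stand-in for nnx.state(YourModel()): unsharded parameters."""
  return {"kernel": jax.random.normal(key, (4, 8)), "bias": jnp.zeros((8,))}


@jax.jit
def create_sharded_state(key):
  state = init_state(key)
  # One call constrains every leaf; the shardings dict mirrors the state dict.
  return jax.lax.with_sharding_constraint(state, shardings)


state = create_sharded_state(jax.random.key(0))
print({name: leaf.shape for name, leaf in state.items()})
```

The PR's change effectively moves this constraint into `nnx.Variable` creation, so the separate `state` / `get_partition_spec` / `update` round trip is no longer needed.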

IvyZX requested a review from cgarciae on July 23, 2025 22:40
IvyZX force-pushed the shard branch 2 times, most recently from 5aa49dd to ddd1560, on July 31, 2025 00:59
IvyZX changed the title from "[WIP] Auto mode spmd sharding utilities and docs" to "Automatically apply sharding constraints to sharded models" on Jul 31, 2025

IvyZX force-pushed the shard branch 3 times, most recently from f57c05d to fe54e64, on August 5, 2025 23:10
IvyZX marked this pull request as ready for review on August 5, 2025 23:21
IvyZX force-pushed the shard branch 2 times, most recently from 4260556 to 805ccab, on August 6, 2025 00:45

Review thread on `docs_nnx/guides/flax_gspmd.md` (outdated):

````
```{code-cell} ipython3
from typing import *
with jax._src.sharding_impls.use_mesh(auto_mesh):
````

Collaborator, suggested change:

```diff
- with jax._src.sharding_impls.use_mesh(auto_mesh):
+ with jax.sharding.use_mesh(auto_mesh):
```

IvyZX (Collaborator, Author), Aug 7, 2025:

This is a temporary API compatibility diff between internal and external JAX that I'll deal with.

Review thread on `docs_nnx/guides/flax_gspmd.md` (outdated), comment on lines +142 to +143:

```python
@partial(jax.vmap, spmd_axis_name=None)
@nnx.set_transform_axis_name(name=None)
```

cgarciae (Collaborator), Aug 6, 2025:

Why are the axis names set to `None` here?

Collaborator:

`set_transform_axis_name` is a bit less flexible than the current `nnx.vmap`, which has access to all `in_axes`/`out_axes`, so it can lift axes other than 0 and can avoid lifting in the `None` case.

IvyZX (Collaborator, Author):

Agreed that we should just use nnx.vmap for auto mode.
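
As we read the review comment above, "lifting" means inserting the vmapped axis name into a variable's sharding annotation at the position of the mapped axis. Below is a hypothetical pure-Python sketch of that rule; `lift_sharding_names` and `init_kernel` are not Flax APIs and only approximate the idea behind `nnx.vmap`'s metadata handling. It honors any axis position (including negative ones) and leaves the annotation alone when the axis is `None`.

```python
import jax


def lift_sharding_names(names, axis, axis_name):
  """Hypothetical: add `axis_name` to a sharding annotation at position `axis`.

  `axis` plays the role of one variable's in_axes/out_axes entry in vmap.
  """
  if axis is None:
    # The variable is not mapped (broadcast), so there is no new axis to name.
    return tuple(names)
  names = list(names)
  if axis < 0:
    axis += len(names) + 1  # Normalize against the rank *after* stacking.
  names.insert(axis, axis_name)
  return tuple(names)


def init_kernel(key):
  return jax.random.normal(key, (4, 8))


# Stack three kernels along axis 1, giving shape (4, 3, 8).
keys = jax.random.split(jax.random.key(0), 3)
stacked = jax.vmap(init_kernel, out_axes=1)(keys)
names = lift_sharding_names((None, "model"), 1, "layers")
print(stacked.shape, names)  # The lifted annotation matches the new rank.
```

A helper fixed to axis 0 could not express the `out_axes=1` case above, which is the flexibility the reviewer attributes to `nnx.vmap`.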

Review thread on `flax/nnx/spmd.py` (outdated):

```python
)
def get_abstract_sharding(init_fn, mesh):
  with jax._src.sharding_impls.use_mesh(mesh):
    abs_model = jax.eval_shape(init_fn)
```

Collaborator:

In our internal APIs, we should assume the user might be using reference sharing, and use `nnx.eval_shape`, `nnx.state`, and `nnx.update`.

IvyZX (Collaborator, Author):

Fixed.

IvyZX force-pushed the shard branch 4 times, most recently from eeec522 to 2f4b706, on August 21, 2025 22:08
IvyZX force-pushed the shard branch 4 times, most recently from 851c6c7 to 5c62419, on September 15, 2025 19:24
copybara-service[bot] merged commit 509da37 into google:main on Sep 16, 2025 (15 checks passed)