
Sharding API improvements (non-breaking) #4893

Merged
copybara-service[bot] merged 1 commit into google:main from IvyZX:mesh-ctx
Aug 19, 2025

Conversation

IvyZX commented Aug 14, 2025

Collaborator

A split-off version of #4844 that contains all the sharding API benefits without the breaking changes:

  • Allow Flax to recognize the new-style JAX mesh context (which is based on AbstractMesh).

  • Renamed the annotation from sharding to sharding_names, so that variable.sharding can simply point to the JAX sharding of the actual value.

  • Changed CI to run everything on 4 fake CPU devices, so that we no longer need to distinguish between single-device and multi-device tests.

  • Removed a few redundant functions like nnx.with_sharding_constraint.
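
A minimal sketch of the setup these bullets describe, written in plain JAX and not taken from the PR's CI config. It assumes the usual XLA flag --xla_force_host_platform_device_count to get 4 fake CPU devices, picks whichever mesh-context entry point the installed JAX provides (the legacy `with mesh:` fallback is our addition for very old versions), and uses jax.lax.with_sharding_constraint directly. We are assuming that the plain JAX call is what makes nnx.with_sharding_constraint redundant; the PR only says the wrapper was removed. The mesh axis names and the scale function are illustrative.

```python
import os

# Must run before JAX initializes: request 4 fake host (CPU) devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Pick whichever mesh-context entry point this JAX version provides.
if hasattr(jax, "set_mesh"):
  set_mesh = jax.set_mesh
elif hasattr(jax.sharding, "use_mesh"):
  set_mesh = jax.sharding.use_mesh
else:
  set_mesh = lambda m: m  # very old JAX: a Mesh is itself a (legacy) context manager

# Illustrative 2x2 mesh over the 4 fake devices.
mesh = Mesh(np.array(jax.devices()).reshape(2, 2), ("data", "model"))


@jax.jit
def scale(x):
  # Plain JAX sharding constraint, no Flax wrapper involved.
  return jax.lax.with_sharding_constraint(
      x * 2.0, NamedSharding(mesh, P("data", "model"))
  )


with set_mesh(mesh):
  x = jax.device_put(jnp.ones((4, 4)), NamedSharding(mesh, P("data", None)))
  y = scale(x)

print(jax.device_count())  # number of (fake) devices visible to JAX
print(y.sharding)          # the concrete JAX sharding attached to the value
```

Because every test sees the same 4-device mesh, a single code path covers what used to be separate single-device and multi-device cases, and y.sharding is the kind of concrete JAX sharding that variable.sharding is meant to expose after the rename.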

Comment thread examples/gemma/utils.py
Comment on lines +38 to +41
if hasattr(jax.sharding, 'use_mesh'):
  set_mesh = jax.sharding.use_mesh
else:
  set_mesh = jax.set_mesh

Collaborator

Shouldn't it be the other way around?

Suggested change
- if hasattr(jax.sharding, 'use_mesh'):
-   set_mesh = jax.sharding.use_mesh
- else:
-   set_mesh = jax.set_mesh
+ if hasattr(jax, 'set_mesh'):
+   set_mesh = jax.set_mesh
+ else:
+   set_mesh = jax.sharding.use_mesh

IvyZX Aug 18, 2025

Collaborator Author

Well, this is effectively equivalent: a released JAX will have either jax.set_mesh or jax.sharding.use_mesh. If JAX releases today or tomorrow, I will roll forward and remove this.
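
The "effectively equivalent" argument can be checked with a small pure-Python sketch. The module objects below are hypothetical SimpleNamespace stand-ins for three possible jax layouts (not real releases), and use_mesh/set_mesh are dummy functions, not the JAX ones. When exactly one of the two names exists, both orderings return the same function; they diverge only if a version exposes both, which is the case the reviewer's next question is about.

```python
from types import SimpleNamespace


def use_mesh(mesh):
  """Stand-in for jax.sharding.use_mesh (not the real function)."""


def set_mesh(mesh):
  """Stand-in for jax.set_mesh (not the real function)."""


def pick_current(jax_mod):
  # Order used in the PR: try jax.sharding.use_mesh first.
  if hasattr(jax_mod.sharding, "use_mesh"):
    return jax_mod.sharding.use_mesh
  return jax_mod.set_mesh


def pick_suggested(jax_mod):
  # Order from the suggested change: try jax.set_mesh first.
  if hasattr(jax_mod, "set_mesh"):
    return jax_mod.set_mesh
  return jax_mod.sharding.use_mesh


# Hypothetical module layouts, not real JAX releases.
only_use_mesh = SimpleNamespace(sharding=SimpleNamespace(use_mesh=use_mesh))
only_set_mesh = SimpleNamespace(set_mesh=set_mesh, sharding=SimpleNamespace())
both = SimpleNamespace(set_mesh=set_mesh, sharding=SimpleNamespace(use_mesh=use_mesh))

# With exactly one API present, both orderings pick the same function.
for mod in (only_use_mesh, only_set_mesh):
  assert pick_current(mod) is pick_suggested(mod)

# They only disagree when both APIs exist at once.
print(pick_current(both).__name__, pick_suggested(both).__name__)  # use_mesh set_mesh
```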

Collaborator

Isn't jax.sharding.use_mesh still going to be there?

Collaborator Author

No, it's deleted at the current head.

Collaborator

Oh, I see.

Comment thread flax/nnx/spmd.py
mesh=mesh,
**metadata,
)
def get_abstract_model(init_fn, mesh):

cgarciae Aug 19, 2025

Collaborator

Wondering if we could extract the ShapeDtypeStruct mapping procedure into its own get_shape_dtypes function, so that get_abstract_model is not needed and the user could express the same thing as:

with jax.set_mesh(mesh):
  gdef, abs_state = nnx.split(jax.eval_shape(init_fn))
  abs_state = nnx.get_shape_dtypes(abs_state)

Currently get_abstract_model feels a bit magical, e.g. it's hard to know what it's doing.
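
One way to picture the proposed split is the sketch below. It is plain JAX, not Flax: get_shape_dtypes here is a hypothetical helper (the comment only proposes the name, and any real version may have a different signature), the parameters are a plain dict rather than an nnx.State, and the per-parameter mesh axis names are written by hand in a hypothetical SHARDING_NAMES dict instead of being read from Flax's sharding_names metadata. It runs jax.eval_shape to get shapes and dtypes without allocating, attaches a NamedSharding to each abstract leaf, and then reuses those shardings to create the real parameters already sharded.

```python
import os

# Must run before JAX initializes: request 4 fake host (CPU) devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_fn():
  # Hypothetical initializer returning a plain dict of parameters.
  key = jax.random.PRNGKey(0)
  return {"w": jax.random.normal(key, (8, 4)), "b": jnp.zeros((4,))}


# Hypothetical mesh axis names per parameter, one entry per array dimension.
SHARDING_NAMES = {"w": ("data", "model"), "b": ("model",)}


def get_shape_dtypes(abs_tree, names_tree, mesh):
  """Hypothetical helper: attach a NamedSharding to every abstract leaf."""
  def to_sds(leaf, names):
    sharding = NamedSharding(mesh, P(*names))
    return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=sharding)
  # names_tree nests deeper (tuples); abs_tree's structure acts as the prefix.
  return jax.tree.map(to_sds, abs_tree, names_tree)


mesh = Mesh(np.array(jax.devices()).reshape(2, 2), ("data", "model"))

abs_state = jax.eval_shape(init_fn)  # shapes and dtypes only, nothing allocated
abs_state = get_shape_dtypes(abs_state, SHARDING_NAMES, mesh)

# Reuse the attached shardings to create the real parameters already sharded.
shardings = jax.tree.map(lambda s: s.sharding, abs_state)
state = jax.jit(init_fn, out_shardings=shardings)()
print(abs_state["w"])
```

Keeping the shape-inference step and the sharding-attachment step separate makes each stage inspectable on its own, which addresses the concern that a single get_abstract_model call hides what it does.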

Collaborator Author

This is a good idea. I'll try to use this in the gspmd guide.

copybara-service bot merged commit fffc4db into google:main on Aug 19, 2025
15 checks passed