Sharding API improvements (non breaking) #4893
if hasattr(jax.sharding, 'use_mesh'):
  set_mesh = jax.sharding.use_mesh
else:
  set_mesh = jax.set_mesh
shouldn't it be backwards?
- if hasattr(jax.sharding, 'use_mesh'):
-   set_mesh = jax.sharding.use_mesh
- else:
-   set_mesh = jax.set_mesh
+ if hasattr(jax, 'set_mesh'):
+   set_mesh = jax.set_mesh
+ else:
+   set_mesh = jax.sharding.use_mesh
Well, this is effectively equivalent: a released JAX will have either jax.set_mesh or jax.sharding.use_mesh. If JAX releases today or tomorrow, I will roll forward and remove this.
Isn't jax.sharding.use_mesh still going to be there?
No, it's deleted at current head.
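The feature-detection pattern discussed in this thread can be sketched without importing JAX at all. The helper name `resolve_set_mesh` and the `SimpleNamespace` stand-ins below are hypothetical and only model "a JAX build that exposes one of the two names"; they are not part of JAX or Flax. Under the assumption stated above (a released JAX has exactly one of the two attributes), either check order picks the same function.

```python
from types import SimpleNamespace


def resolve_set_mesh(jax_module):
    """Return whichever mesh context manager the given module exposes.

    Hypothetical helper: prefers the newer top-level `set_mesh` and falls
    back to the older `sharding.use_mesh` when the new name is missing.
    """
    set_mesh = getattr(jax_module, "set_mesh", None)
    if set_mesh is None:
        set_mesh = jax_module.sharding.use_mesh
    return set_mesh


# Stand-ins for two JAX builds (assumption: exactly one name exists in each).
old_jax = SimpleNamespace(sharding=SimpleNamespace(use_mesh="use_mesh"))
new_jax = SimpleNamespace(set_mesh="set_mesh", sharding=SimpleNamespace())

picked_old = resolve_set_mesh(old_jax)  # falls back to sharding.use_mesh
picked_new = resolve_set_mesh(new_jax)  # uses the top-level set_mesh
```

With a real installation this would be called as `set_mesh = resolve_set_mesh(jax)` after `import jax`. Checking for the newer name first only matters if a build ever exposes both.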
  mesh=mesh,
  **metadata,
)
def get_abstract_model(init_fn, mesh):
Wondering if we could extract the ShapeDtypeStruct mapping procedure into its own get_shape_dtypes function, so that get_abstract_model is not needed and the user could express the same thing as:
with jax.set_mesh(mesh):
  gdef, abs_state = nnx.split(jax.eval_shape(init_fn))
  abs_state = nnx.get_shape_dtypes(abs_state)
Currently get_abstract_model feels a bit magical, e.g. it's hard to know what it's doing.
This is a good idea. I'll try to use this in the gspmd guide.
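To make the proposed split concrete, here is a minimal plain-JAX sketch of what a "ShapeDtypeStruct mapping" step could look like. It is not the Flax implementation: `get_shape_dtypes` here is a hypothetical free function, the `sharding_names` dict is an illustrative stand-in for the per-variable annotations, and the mesh is passed explicitly instead of being read from a mesh context. The sketch assumes the mapping's job is to attach a `NamedSharding` to each abstract leaf.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def init_fn():
    # Hypothetical initializer returning a pytree of parameters.
    return {
        "bias": jnp.zeros((4,), jnp.float32),
        "kernel": jnp.zeros((8, 4), jnp.float32),
    }


# Hypothetical per-leaf axis names, mirroring the structure of the state.
sharding_names = {"bias": P("model"), "kernel": P(None, "model")}


def get_shape_dtypes(abs_tree, names, mesh):
    """Map each abstract leaf to a ShapeDtypeStruct carrying a NamedSharding."""
    def to_struct(leaf, spec):
        return jax.ShapeDtypeStruct(
            leaf.shape, leaf.dtype, sharding=NamedSharding(mesh, spec)
        )
    # `names` is matched against the structure of `abs_tree`, so each
    # PartitionSpec is passed whole to `to_struct`.
    return jax.tree_util.tree_map(to_struct, abs_tree, names)


# A one-device mesh keeps the sketch runnable on any machine.
mesh = Mesh(np.array(jax.devices()[:1]), ("model",))

abs_tree = jax.eval_shape(init_fn)  # shapes and dtypes only, no allocation
abs_state = get_shape_dtypes(abs_tree, sharding_names, mesh)
```

Keeping `jax.eval_shape` and the mapping as two visible steps is what makes the flow less "magical": the reader can see where shapes come from and where shardings are attached.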
A split version of #4844 that contains all the sharding API benefits but is not breaking:
- Allow Flax to recognize the new-style JAX mesh context (which is based on `AbstractMesh`).
- Renamed the annotation from `sharding` to `sharding_names`, so that `variable.sharding` can just point to the JAX sharding of the actual value.
- Changed CI to run everything on 4 fake CPU devices, so that we no longer need to distinguish between single-device and multi-device tests.
- Removed a few redundant functions like `nnx.with_sharding_constraint`.
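For readers who want to reproduce the "4 fake CPU devices" setup locally, one common approach is the XLA host-platform flag shown below. How this PR's CI actually configures the devices is not shown here, so treat this as an assumption; the flag has to be set before JAX initializes its backends, and the 2x2 mesh with axis names `data` and `model` is purely illustrative.

```python
import os

# Assumption: request 4 virtual CPU devices via the XLA host-platform flag.
# This must happen before JAX is imported and initialized.
flag = "--xla_force_host_platform_device_count=4"
os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "") + " " + flag).strip()

import jax
import numpy as np
from jax.sharding import Mesh

cpu_devices = jax.devices("cpu")

# Illustrative 2x2 mesh over the four virtual devices.
mesh = Mesh(np.array(cpu_devices).reshape(2, 2), ("data", "model"))
mesh_shape = dict(mesh.shape)
```

With every test running against the same four-device layout, a multi-device mesh is always constructible, which is what removes the need for separate single-device and multi-device test paths.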