Automatically apply sharding constraints to sharded models #4844
    ```{code-cell} ipython3
    from typing import *
    with jax._src.sharding_impls.use_mesh(auto_mesh):
Suggested change:
    - with jax._src.sharding_impls.use_mesh(auto_mesh):
    + with jax.sharding.use_mesh(auto_mesh):
This is a temporary API compatibility diff between internal and external JAX that I'll deal with.
    @partial(jax.vmap, spmd_axis_name=None)
    @nnx.set_transform_axis_name(name=None)
`set_transform_axis_name` is a bit less flexible than the current `nnx.vmap`, which has access to all `in_axes` / `out_axes`, so it can lift axes other than 0 and can avoid lifting in the `None` case.
Agreed that we should just use nnx.vmap for auto mode.
    )
    def get_abstract_sharding(init_fn, mesh):
      with jax._src.sharding_impls.use_mesh(mesh):
        abs_model = jax.eval_shape(init_fn)
In our internal APIs we should assume the user might be using reference sharing and use nnx.eval_shape and nnx.state and nnx.update.
Simplify the effort to create sharded Flax NNX models.
BREAKING CHANGE: `nnx.Variable` creation now requires a mesh context and automatically runs `jax.lax.with_sharding_constraint` on the value if a sharding annotation is provided.

Basically, to create a sharded model, the user only needs to do this:
instead of the current boilerplate combo of `nnx.jit`, `nnx.get_partition_spec`, `with_sharding_constraint` and `nnx.update`:
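For readers unfamiliar with the constraint step that this change automates, it can be sketched in plain JAX. This is only a minimal illustration, not the PR's implementation and not the Flax NNX API: the helper `constrain`, the function `init_kernel`, the mesh axis names `"data"` and `"model"`, and the array shape are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical mesh: all devices on a "data" axis, plus a size-1 "model" axis
# so the sketch also runs on a single device.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))


def constrain(value, names):
    """Apply a sharding constraint only when an annotation is given.

    Hypothetical helper: `names` plays the role of a sharding annotation,
    one mesh axis name (or None) per array dimension.
    """
    if names is None:
        return value  # no annotation, leave the value untouched
    sharding = NamedSharding(mesh, P(*names))
    return jax.lax.with_sharding_constraint(value, sharding)


@jax.jit
def init_kernel():
    # Create the value, then constrain it according to its annotation.
    w = jnp.ones((8, 4))
    return constrain(w, (None, "model"))


kernel = init_kernel()
print(kernel.shape, kernel.sharding)
```

The point of the sketch is that the constraint is applied at the moment the value is created, based on its annotation, rather than in a separate pass over the model state afterwards.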