
Added kernel_metadata/bias_metadata args to nnx layers #5074

Merged: copybara-service[bot] merged 2 commits into main from params_metadata_args on Nov 12, 2025


vfdev-5 (Collaborator) commented Nov 6, 2025

Description:

  • Added kernel_metadata/bias_metadata args to nnx layers:
      • linear layers: nnx.Linear, nnx.LinearGeneral, nnx.Conv, nnx.ConvTransposed, nnx.Einsum, nnx.Embedding
      • multi-head attention (MHA)
      • normalization: batch, layer, instance, and group normalization (BN, LN, IN, GN)
      • recurrent: LSTM, GRU
      • LoRA
  • Added tests
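A minimal plain-Python sketch of the argument pattern described above, assuming each layer takes one metadata mapping per parameter and attaches it to that parameter. Param and ToyLinear are hypothetical stand-ins, not the Flax classes, and the "sharding" key with axis names ("din", "dout") is only an illustrative choice of metadata.

```python
from types import MappingProxyType
from typing import Any, Mapping


class Param:
    """Hypothetical stand-in for a parameter container: a value plus metadata."""

    def __init__(self, value: Any, **metadata: Any):
        self.value = value
        self.metadata = dict(metadata)


class ToyLinear:
    """Hypothetical dense layer with per-parameter metadata arguments."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        # Read-only empty defaults avoid the shared-mutable-default pitfall
        # of writing `kernel_metadata: dict = {}`.
        kernel_metadata: Mapping[str, Any] = MappingProxyType({}),
        bias_metadata: Mapping[str, Any] = MappingProxyType({}),
    ):
        kernel = [[0.0] * out_features for _ in range(in_features)]
        # Each parameter receives only its own metadata mapping.
        self.kernel = Param(kernel, **kernel_metadata)
        self.bias = Param([0.0] * out_features, **bias_metadata) if use_bias else None


layer = ToyLinear(
    4,
    8,
    kernel_metadata={"sharding": ("din", "dout")},
    bias_metadata={"sharding": ("dout",)},
)
print(layer.kernel.metadata)  # {'sharding': ('din', 'dout')}
print(layer.bias.metadata)    # {'sharding': ('dout',)}
```

Keeping the two mappings separate lets the kernel and bias carry different annotations (a 2D kernel and a 1D bias need different axis names), and callers who pass nothing get empty metadata.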

vfdev-5 (Collaborator, Author) commented Nov 6, 2025

Let's pass the *_metadata arguments directly into nnx.Param.
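One reason forwarding the mapping straight into the constructor can matter: code that runs inside the parameter's __init__ sees the metadata. The traceback later in this thread shows such code, since nnx.Param.__init__ (flax/nnx/variablelib.py) calls core_spmd.shard_value. The sketch below models this with a hypothetical construction-time hook; ToyParam and hook_ran are illustrative names, not Flax APIs, and the comparison assumes the alternative would be attaching metadata after the parameter already exists.

```python
from typing import Any


class ToyParam:
    """Hypothetical parameter whose constructor reacts to metadata (not the Flax API)."""

    def __init__(self, value: Any, **metadata: Any):
        self.value = value
        self.metadata = dict(metadata)
        # Construction-time hook: stands in for work such as applying a
        # sharding annotation. It only sees metadata present at __init__ time.
        self.hook_ran = "sharding" in self.metadata


# Metadata forwarded directly into the constructor: the hook sees it.
direct = ToyParam([1.0, 2.0], **{"sharding": ("dout",)})

# Metadata attached after construction: it is stored, but the hook never ran.
late = ToyParam([1.0, 2.0])
late.metadata["sharding"] = ("dout",)

print(direct.hook_ran, late.hook_ran)  # True False
```

Both objects end up with identical metadata, which is why a difference like this is easy to miss until some constructor-side behavior depends on it.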

Review comment thread on flax/nnx/nn/attention.py (outdated)
vfdev-5 force-pushed the params_metadata_args branch from 1af973b to f385e25 on November 7, 2025 16:15
vfdev-5 marked this pull request as ready for review on November 7, 2025 16:26
Review comment thread on flax/nnx/nn/linear.py (outdated)
Review comment thread on tests/nnx/nn/linear_test.py
vfdev-5 force-pushed the params_metadata_args branch 3 times, most recently from 6786464 to 485424c on November 7, 2025 22:37
IvyZX (Collaborator) commented Nov 11, 2025

This PR failed tests internally, where JAX is used at head, with errors like the one below. If the failure doesn't show up in OSS, it is perhaps only reproducible with the latest JAX from head.

Traceback (most recent call last):
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/absl/testing/parameterized.py", line 321, in bound_param_test
    return test_method(self, **testcase_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/flax/tests/nnx/nn/linear_test.py", line 401, in test
    module = module_cls(*args, **metadata_kwargs, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/flax/nnx/pytreelib.py", line 333, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/flax/nnx/pytreelib.py", line 345, in _graph_node_meta_call
    cls._pytree_meta_construct(node, *args, **kwargs)
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/flax/nnx/pytreelib.py", line 336, in _pytree_meta_construct
    self.__init__(*args, **kwargs)
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/flax/nnx/nn/linear.py", line 227, in __init__
    self.kernel = nnx.Param(
                  ^^^^^^^^^^
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/flax/nnx/variablelib.py", line 296, in __init__
    value = core_spmd.shard_value(
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/flax/core/spmd.py", line 47, in shard_value
    return jax.lax.with_sharding_constraint(value, pspec)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/jax/_src/pjit.py", line 2662, in with_sharding_constraint
    assert_shardings_equal(x_aval, s)
  File "/build/work/b0c0fd3dff06e484b52ca92fc70ac53f0a39/google3/runfiles/google3/third_party/py/jax/_src/pjit.py", line 2582, in assert_shardings_equal
    raise AssertionError(
AssertionError: `with_sharding_constraint` acts as an assert when all axes of mesh are of type `Explicit`. The array sharding: PartitionSpec(None, None, None) did not match the sharding provided: PartitionSpec('din', 'dout', None). Please use `jax.sharding.reshard` to shard your input to the sharding you want.
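The assertion message distinguishes two behaviors. When all mesh axes are Explicit, with_sharding_constraint does not move data; it only checks that the array already has the requested PartitionSpec. jax.sharding.reshard is the call that actually changes it. The plain-Python model below imitates only that check-versus-convert distinction. ToyArray, toy_with_sharding_constraint, and toy_reshard are hypothetical names, not JAX APIs, and specs are plain tuples.

```python
from typing import Optional, Tuple

Spec = Tuple[Optional[str], ...]


class ToyArray:
    """Hypothetical array that only tracks its partition spec."""

    def __init__(self, spec: Spec):
        self.spec = spec


def toy_with_sharding_constraint(x: ToyArray, spec: Spec, explicit_axes: bool) -> ToyArray:
    if explicit_axes:
        # Explicit axes: the sharding is part of the array's type, so the
        # constraint only asserts that it already matches.
        if x.spec != spec:
            raise AssertionError(f"array sharding {x.spec} did not match {spec}")
        return x
    # Auto axes: the constraint acts as a hint that yields the requested sharding.
    return ToyArray(spec)


def toy_reshard(x: ToyArray, spec: Spec) -> ToyArray:
    # Explicitly converts the array to the requested sharding.
    return ToyArray(spec)


unsharded = ToyArray((None, None, None))  # e.g. a freshly initialized kernel
target: Spec = ("din", "dout", None)

try:
    toy_with_sharding_constraint(unsharded, target, explicit_axes=True)
except AssertionError as err:
    print("assert:", err)

resharded = toy_reshard(unsharded, target)
print(toy_with_sharding_constraint(resharded, target, explicit_axes=True).spec)
```

In this model, the failing path matches the traceback: a kernel created with PartitionSpec(None, None, None) is passed to a constraint asking for ('din', 'dout', None) under Explicit axes, so the call asserts instead of resharding.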

Description:

Added kernel_metadata/bias_metadata args to nnx layers:
- [x] nnx.Linear
- [x] nnx.LinearGeneral
- [ ] nnx.Conv
- [ ] ...
vfdev-5 force-pushed the params_metadata_args branch from 485424c to f320a5c on November 12, 2025 11:31
vfdev-5 requested a review from cgarciae on November 12, 2025 11:32
copybara-service[bot] merged commit b08cb20 into main on Nov 12, 2025 (21 checks passed)
copybara-service[bot] deleted the params_metadata_args branch on November 12, 2025 19:07