fix cached_partial for raw array attributes #5501
Open
discobot wants to merge 1 commit into
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
nnx.cached_partial crashed with an AttributeError when cached graph nodes contained raw jax or numpy array attributes not wrapped in nnx.Variable, because StaticCache assumed all leaves were Variables. Raw array leaves are now passed through as-is when flattening and skipped during update propagation when merging, i.e. they are treated as read-only. The read-only semantics are documented in the cached_partial docstring and covered by a regression test. Fixes google#5109
9bbea17 to 9815805
What does this PR do?
nnx.cached_partial crashes on raw jax.Array / numpy.ndarray attributes. The mechanism:
StaticCache.variables stores flatten() leaves, which can be raw arrays, but two downstream
sites assume every entry is a Variable — the flatten with_paths=False fast path calls
get_raw_value() on each cached entry (the reported crash), and the unflatten cache-update
loop calls update_from_state/set_raw_value (a latent crash on the output path).
This implements the read-only behavior discussed in the issue: raw array leaves are passed
through as-is when flattening, and the merge update loop skips non-Variable entries, so
Variable propagation is unchanged. The cached_partial docstring now documents the read-only
semantics, and the StaticCache.variables annotation is widened to match.
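
A minimal sketch of the pass-through and skip behavior described above, written against a toy model rather than Flax itself. The `Variable` class, `flatten_cached_old`, `flatten_cached`, and `propagate_updates` are hypothetical stand-ins invented for illustration (only the method names `get_raw_value` and `set_raw_value` come from the description), and a plain tuple plays the role of a raw `jax.Array` / `np.ndarray` attribute. It is not the actual `StaticCache` code.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Variable:
    """Hypothetical stand-in for nnx.Variable: a mutable box around a value."""
    value: Any

    def get_raw_value(self) -> Any:
        return self.value

    def set_raw_value(self, value: Any) -> None:
        self.value = value


def flatten_cached_old(cached: list) -> list:
    # Pre-fix assumption: every cached entry is a Variable.
    # A raw leaf has no get_raw_value(), so this raises AttributeError.
    return [entry.get_raw_value() for entry in cached]


def flatten_cached(cached: list) -> list:
    # Post-fix behavior as described: raw leaves are passed through as-is.
    return [
        entry.get_raw_value() if isinstance(entry, Variable) else entry
        for entry in cached
    ]


def propagate_updates(cached: list, new_leaves: list) -> None:
    # Variables receive updates; non-Variable entries are skipped (read-only).
    for entry, leaf in zip(cached, new_leaves):
        if isinstance(entry, Variable):
            entry.set_raw_value(leaf)


counter = Variable(0)
raw = (1.0, 2.0)  # stands in for a raw array attribute (assumption)
cached = [counter, raw]

try:
    flatten_cached_old(cached)
    old_path_error = None
except AttributeError as err:
    old_path_error = err  # the analogue of the reported crash

leaves = flatten_cached(cached)
# Pretend the cached function incremented the counter and tried to
# overwrite the raw leaf; only the Variable update should land.
propagate_updates(cached, [leaves[0] + 1, (9.0, 9.0)])
```

In this toy model the counter advances while the raw leaf keeps its original value, which mirrors the read-only semantics the docstring now documents.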
No warning is emitted when a cached function would have updated a raw array: comparing
returned leaves to cached arrays by identity always differs after jit, and by value would
force device syncs, so the constraint is documented instead.
Adds a regression test with raw jax.Array and np.ndarray attributes alongside an
nnx.Variable counter that must keep receiving updates across cached calls;
tests/nnx/transforms_test.py and tests/nnx/graph_utils_test.py pass on CPU.
Fixes #5109
nnx.cached_partial errors on jax Arrays not wrapped in nnx.Variable #5109