
fix cached_partial for raw array attributes #5501

Open

discobot wants to merge 1 commit into google:main from discobot:fix/5109-cached-partial-raw-arrays


@discobot


What does this PR do?

nnx.cached_partial crashes when a cached graph node has raw jax.Array or numpy.ndarray
attributes, i.e. arrays not wrapped in an nnx.Variable. The mechanism: StaticCache.variables
stores the leaves returned by flatten(), which can be raw arrays, but two downstream sites
assume every entry is a Variable. The flatten fast path (with_paths=False) calls
get_raw_value() on each cached entry, which is the reported crash, and the unflatten
cache-update loop calls update_from_state/set_raw_value, a latent crash on the output path.
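To make the failure concrete, here is a minimal pure-Python model of the pre-fix fast path. `Variable`, `flatten_fast_path`, and the use of `array.array` as a stand-in for a raw `jax.Array` / `numpy.ndarray` leaf are hypothetical simplifications, not Flax's actual internals; only the method name `get_raw_value` comes from the description above.

```python
from array import array


class Variable:
    """Hypothetical stand-in for nnx.Variable: a box around a raw value."""

    def __init__(self, value):
        self._value = value

    def get_raw_value(self):
        return self._value


def flatten_fast_path(cached_leaves):
    # Pre-fix assumption: every cached leaf is a Variable and can be unwrapped.
    return [leaf.get_raw_value() for leaf in cached_leaves]


counter = Variable(0)
raw = array("f", [1.0, 2.0])  # stands in for a raw jax.Array / np.ndarray attribute
cached_leaves = [counter, raw]  # what a StaticCache-like structure would hold

error = None
try:
    flatten_fast_path(cached_leaves)
except AttributeError as exc:
    error = exc  # the raw array has no get_raw_value method

print(type(error).__name__, error)
```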

This implements the read-only behavior discussed in the issue: raw array leaves are passed
through as-is when flattening, and the merge update loop skips non-Variable entries, so
Variable propagation is unchanged. The cached_partial docstring now documents the read-only
semantics, and the StaticCache.variables annotation is widened to match.
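A hedged sketch of the read-only behavior, again in a toy model rather than Flax code: `Variable`, `flatten_fast_path`, `merge_updates`, and `step` are hypothetical names, and `array.array` stands in for a raw array leaf. Flattening unwraps Variables and passes raw leaves through unchanged; merging writes back into Variables only, so whatever the function returns for a raw leaf is dropped.

```python
from array import array


class Variable:
    """Hypothetical stand-in for nnx.Variable."""

    def __init__(self, value):
        self._value = value

    def get_raw_value(self):
        return self._value

    def set_raw_value(self, value):
        self._value = value


def flatten_fast_path(cached_leaves):
    # Unwrap Variables; pass raw array leaves through as-is.
    return [
        leaf.get_raw_value() if isinstance(leaf, Variable) else leaf
        for leaf in cached_leaves
    ]


def merge_updates(cached_leaves, new_values):
    # Write outputs back into cached Variables only; raw arrays are read-only.
    for leaf, value in zip(cached_leaves, new_values):
        if isinstance(leaf, Variable):
            leaf.set_raw_value(value)


def step(values):
    count, weights = values
    # Returns a new array for the raw leaf, but merge_updates discards it.
    return [count + 1, array("f", [w * 2 for w in weights])]


counter = Variable(0)
raw = array("f", [1.0, 2.0])  # stands in for a raw jax.Array / np.ndarray
cached = [counter, raw]

for _ in range(3):
    merge_updates(cached, step(flatten_fast_path(cached)))

print(counter.get_raw_value(), list(raw))  # 3 [1.0, 2.0]
```

Silently skipping the raw leaf keeps Variable propagation unchanged, at the cost of the read-only constraint that the docstring now spells out.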

No warning is emitted when a cached function would have updated a raw array: comparing the
returned leaves to the cached arrays by identity would report a change on every call after
jit, and comparing by value would force device syncs, so the constraint is documented instead.
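The identity-versus-value tradeoff can be illustrated with a toy model in which a hypothetical `fake_jit_boundary` copies every output, mimicking the statement above that returned leaves never share identity with the cached arrays after jit. This models the reasoning only; it does not reproduce JAX behavior, and in JAX the value comparison would additionally require reading the array contents back from the device.

```python
from array import array


def fake_jit_boundary(values):
    # Hypothetical: model jit outputs as fresh buffers, even for unchanged data.
    return [array(v.typecode, v) for v in values]


def unchanged(values):
    # A cached function that does not modify its raw array input at all.
    return values


cached_raw = [array("f", [1.0, 2.0])]
returned = fake_jit_boundary(unchanged(cached_raw))

# Identity check: flags a "change" on every call, so a warning would always fire.
identity_says_changed = [r is not c for r, c in zip(returned, cached_raw)]
# Value check: correct here, but in JAX it needs the array contents,
# which is the device sync the PR wants to avoid.
value_says_changed = [r != c for r, c in zip(returned, cached_raw)]

print(identity_says_changed, value_says_changed)  # [True] [False]
```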

Adds a regression test with raw jax.Array and np.ndarray attributes alongside an
nnx.Variable counter that must keep receiving updates across cached calls;
tests/nnx/transforms_test.py and tests/nnx/graph_utils_test.py pass on CPU.
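The PR's actual test is not reproduced here. As a hedged sketch of the same shape, the `unittest` case below exercises a hypothetical `toy_cached_partial` (a pure-Python model that caches the leaf list once and reuses it on every call), with a `Variable` counter that must advance on each call and two raw-array stand-ins that must stay untouched.

```python
import unittest
from array import array


class Variable:
    """Hypothetical stand-in for nnx.Variable."""

    def __init__(self, value):
        self.value = value


def toy_cached_partial(f, leaves):
    """Hypothetical sketch: cache the leaf list once, reuse it on every call."""
    cache = list(leaves)  # loosely analogous to StaticCache.variables

    def call(*args):
        # Flatten: unwrap Variables, pass raw arrays through unchanged.
        values = [leaf.value if isinstance(leaf, Variable) else leaf for leaf in cache]
        out_values, result = f(values, *args)
        # Merge: write back to Variables only; raw arrays are read-only.
        for leaf, new in zip(cache, out_values):
            if isinstance(leaf, Variable):
                leaf.value = new
        return result

    return call


class CachedPartialRawArrayTest(unittest.TestCase):
    def test_raw_arrays_read_only_and_variable_updates(self):
        raw_a = array("d", [1.0, 2.0, 3.0])  # stands in for a raw jax.Array
        raw_b = [0.5, 0.5]  # stands in for a raw np.ndarray
        count = Variable(0)

        def f(values, x):
            c, a, b = values
            # Increment the counter; use the raw arrays without modifying them.
            return [c + 1, a, b], sum(a) * x + sum(b)

        step = toy_cached_partial(f, [count, raw_a, raw_b])
        for _ in range(3):
            result = step(2.0)

        self.assertEqual(count.value, 3)  # updates keep propagating
        self.assertEqual(result, 13.0)
        self.assertEqual(list(raw_a), [1.0, 2.0, 3.0])
        self.assertEqual(raw_b, [0.5, 0.5])
```

Run it with `python -m unittest` against the file that contains it.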

Fixes #5109

Checklist

@google-cla

google-cla Bot commented Jun 13, 2026


Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

nnx.cached_partial crashed with an AttributeError when cached graph nodes contained raw jax or numpy array attributes not wrapped in nnx.Variable, because StaticCache assumed all leaves were Variables. Raw array leaves are now passed through as-is when flattening and skipped during update propagation when merging, i.e. they are treated as read-only. The read-only semantics are documented in the cached_partial docstring and covered by a regression test.

Fixes google#5109
discobot force-pushed the fix/5109-cached-partial-raw-arrays branch from 9bbea17 to 9815805 on June 13, 2026 11:46


Development

Successfully merging this pull request may close these issues.

nnx.cached_partial errors on jax Arrays not wrapped in nnx.Variable
