Update optimizer.py to support masked variable from optax. #4904

Merged
copybara-service[bot] merged 2 commits into google:main from ywrt:patch-1 on Sep 2, 2025

ywrt commented Aug 22, 2025

Contributor

Per #4901, the current code doesn't work with the masked variables introduced by optax.partition(), because it assumes that the value is an array.

This change just uses x.value to get the value directly, instead of using x[...] to get it implicitly.
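
To illustrate the difference, here is a minimal, dependency-free sketch. Every class and function name in it (FakeMaskedNode, FakeArray, FakeVariable, unwrap_by_indexing, unwrap_by_attribute) is a hypothetical stand-in, not the real flax or optax code. It assumes only that the variable wrapper forwards indexing to its wrapped value, and that the masked placeholder is an empty tuple-like object that cannot be indexed with `...`.

```python
from typing import Any, NamedTuple


class FakeMaskedNode(NamedTuple):
  """Hypothetical stand-in for the placeholder left at masked-out leaves."""


class FakeArray:
  """Hypothetical stand-in for an array that supports `arr[...]`."""

  def __init__(self, data):
    self.data = list(data)

  def __getitem__(self, key):
    # `arr[...]` returns the whole array, as array libraries usually do.
    if key is Ellipsis:
      return self
    return self.data[key]


class FakeVariable:
  """Hypothetical stand-in for a variable wrapping an arbitrary value."""

  def __init__(self, value: Any):
    self.value = value

  def __getitem__(self, key):
    # Indexing is forwarded to the wrapped value, so it only works when
    # that value itself accepts `key`.
    return self.value[key]


def unwrap_by_indexing(x: FakeVariable) -> Any:
  """Old approach: implicitly assumes the wrapped value is array-like."""
  return x[...]


def unwrap_by_attribute(x: FakeVariable) -> Any:
  """New approach: reads the wrapped value without indexing it."""
  return x.value


array_var = FakeVariable(FakeArray([1.0, 2.0]))
masked_var = FakeVariable(FakeMaskedNode())

# Both approaches agree when the wrapped value is array-like.
assert unwrap_by_indexing(array_var) is unwrap_by_attribute(array_var)

# Only the attribute read works for the masked placeholder.
try:
  unwrap_by_indexing(masked_var)
except TypeError as err:
  print("indexing failed:", type(err).__name__)
print("attribute read returned:", unwrap_by_attribute(masked_var))
```

In this sketch the attribute read makes no assumption about the wrapped type, which is why it also handles the masked placeholder.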

What does this PR do?

Fixes #4901

Checklist

  • [x] This PR fixes a minor issue (e.g., a typo or small bug) or improves the docs (you can dismiss the other
    checks if that's the case).

@8bitmp3 8bitmp3 requested review from IvyZX and cgarciae August 22, 2025 21:56
@ywrt ywrt closed this Aug 23, 2025

ywrt commented Aug 23, 2025

Contributor Author

Tests don't pass, so this simple fix can't be right. Withdrawn.

@cgarciae

Collaborator

@ywrt not sure why they don't pass but I think you have the right idea

@cgarciae cgarciae reopened this Aug 25, 2025
@cgarciae

Collaborator

oh it's just mypy, I think you just have to rebase
I ran them locally and it looks good

@cgarciae

Collaborator

seems like main is broken, will send a fix

vfdev-5 commented Aug 27, 2025

Collaborator

@ywrt can you please rebase your PR to move it forward? Thanks!

ywrt commented Aug 28, 2025

Contributor Author

Ok, I think this is rebased now?

vfdev-5 commented Aug 28, 2025

Collaborator

@ywrt can you please run the following on your machine to bring this PR up to date with the main branch:

git checkout patch-1
# assuming "origin" is pointing to git@github.com:google/flax.git
git pull origin main -r
# assuming "fork" is pointing to your fork: git@github.com:ywrt/flax.git
git push fork patch-1 -f

Thanks!

@vfdev-5 vfdev-5 left a comment

Collaborator

Thanks @ywrt !

@copybara-service copybara-service Bot merged commit 49cfa78 into google:main Sep 2, 2025
15 checks passed

Development

Successfully merging this pull request may close these issues.

optax.contrib.muon does not seem to work with nnx