
fix MultiMetric GraphDef not hashable under jit #5492

Open
mohsinm-dev wants to merge 1 commit into google:main from mohsinm-dev:fix-multimetric-hashable-graphdef

@mohsinm-dev

Contributor

Fixes #5489.

MultiMetric._metric_names was stored as a list. Because it is static graph metadata, it becomes part of the GraphDef, so GraphDef.__hash__ fails (lists are unhashable) when the graphdef is passed as a static argument to jit. This PR stores it as a tuple instead.
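To illustrate why a list breaks hashing, here is a minimal pure-Python sketch. ToyGraphDef is a hypothetical stand-in for nnx.GraphDef, not Flax's actual class; it only assumes that the graphdef's hash covers its static metadata, the way a frozen dataclass's generated __hash__ hashes all of its fields.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyGraphDef:
    """Hypothetical stand-in for nnx.GraphDef (illustration only).

    frozen=True plus the default eq=True makes dataclasses generate a
    __hash__ that hashes a tuple of all fields, so every piece of static
    metadata must itself be hashable.
    """
    type_name: str
    metadata: Any  # e.g. the MultiMetric's metric names


def is_hashable(obj) -> bool:
    """Return True if hash(obj) succeeds, False if it raises TypeError."""
    try:
        hash(obj)
        return True
    except TypeError:
        return False


names = ["accuracy", "loss"]
before = ToyGraphDef("MultiMetric", names)        # list: behaviour before the fix
after = ToyGraphDef("MultiMetric", tuple(names))  # tuple: behaviour after the fix

print(is_hashable(before))  # False, hashing the inner list raises TypeError
print(is_hashable(after))   # True
```

Only the container type changes; the names and their order are the same, so equality between graphdefs built from the same metrics is unaffected.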

Store _metric_names as tuple instead of list so the static GraphDef
produced by nnx.split() is hashable when passed to jax.jit/nnx.jit.
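The effect on caching can be sketched with functools.lru_cache, which, like a static jit argument, has to hash its arguments to look up a cache entry. ToyMultiMetric, static_key and compile_for are hypothetical names used only for this illustration; this is not how nnx.MultiMetric or nnx.jit is implemented.

```python
import functools


class ToyMultiMetric:
    """Hypothetical, simplified stand-in for nnx.MultiMetric (not Flax's code)."""

    def __init__(self, **metrics):
        self._metrics = metrics
        # Store the names as an immutable tuple (the change this PR makes);
        # list(metrics) here would make static_key() unhashable.
        self._metric_names = tuple(metrics)

    def static_key(self):
        # Toy analogue of the static GraphDef produced by nnx.split().
        return ("ToyMultiMetric", self._metric_names)


trace_count = 0


@functools.lru_cache(maxsize=None)
def compile_for(static_key):
    """Stand-in for a jit cache: one 'trace' per distinct static key.

    lru_cache hashes its arguments, so an unhashable key raises TypeError
    before the function body runs.
    """
    global trace_count
    trace_count += 1
    return f"compiled for {static_key[1]}"


m = ToyMultiMetric(accuracy=None, loss=None)
compile_for(m.static_key())
compile_for(m.static_key())  # cache hit: equal tuples hash equally
print(trace_count)  # 1
```

With a list in the key, the second call could not even be looked up, which mirrors the failure reported in the linked issue.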

@vfdev-5 left a comment

Collaborator


I do not see any objections to doing it this way, which makes the MultiMetric graphdef hashable.
LGTM, thanks!


Development

Successfully merging this pull request may close these issues.

nnx.MultiMetric triggers JIT error when splitting the model into graph/state

2 participants