
[nnx] preserve the function's type information in jit#4981

Merged
copybara-service[bot] merged 1 commit into main from jit-wrapped-types on Sep 25, 2025


@cgarciae cgarciae commented Sep 25, 2025

Collaborator

What does this PR do?

Uses typing.ParamSpec to properly annotate JitWrapped.__call__. This preserves the original function's call signature, so type checkers can correctly check both the arguments passed to a jit-wrapped function and the type of its result.

Comment thread flax/nnx/transforms/compilation.py Outdated
- ) -> JitWrapped: ...
+ ) -> JitWrapped[P, R]: ...
def jit(
fun: tp.Callable[..., tp.Any] | type[Missing] = Missing,

Member

Does the signature for fun have to be updated here, as well?

Collaborator Author

Yes, nice catch! Updated.

@jburnim jburnim left a comment

Member

LGTM!

@copybara-service copybara-service Bot merged commit 609a0ab into main Sep 25, 2025
18 checks passed
@copybara-service copybara-service Bot deleted the jit-wrapped-types branch September 25, 2025 22:37