
Added is_causal mask argument to flax.nnx.dot_product_attention#5093

Merged: copybara-service[bot] merged 1 commit into google:main from ibbyml:nnx-sdpa-parity on Dec 11, 2025

ibbyml (Contributor) commented Nov 17, 2025

What does this PR do?

  • Adds an is_causal argument to flax.nnx.dot_product_attention and dot_product_attention_weights.
  • Forwards is_causal to the jax.nn.dot_product_attention fast path when possible.
  • Implements manual causal-masking logic that:
    • Supports both self-attention and cross-attention.
    • Composes is_causal with any input mask using the combine_masks helper.
  • Adds three attention tests to check compatibility and correctness.

Checklist

  • This change is discussed in this discussion.
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

vfdev-5 (Collaborator) left a comment

Thanks for the PR @ibbyml !

Comment thread (outdated): flax/nnx/nn/attention.py
Comment thread (outdated): tests/nnx/nn/attention_test.py
ibbyml (Contributor, Author) commented Nov 17, 2025

Pushed the updated parameterized tests and removed the unnecessary formatting. The tests now cover self-attention with and without a padding mask as well as cross-attention with and without a padding mask. Happy to adjust anything else.

vfdev-5 (Collaborator) commented Nov 17, 2025


Thanks! By parameterized tests, I meant writing all 3 test cases as a single parameterized test. I don't think we need to parameterize on B, T, S, etc.
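One way to follow this suggestion is sketched below with absl's parameterized.named_parameters, the decorator style Flax's test suite uses: a single test body runs the self-attention and cross-attention scenarios, each with and without a padding mask, while the shapes stay fixed. This is a hypothetical sketch, not the test merged in this PR; the helper `causal_attention_weights` reimplements the masking directly instead of calling flax.nnx.dot_product_attention_weights, and it assumes top-left causal alignment for cross-attention.

```python
import jax
import jax.numpy as jnp
from absl.testing import parameterized


def causal_attention_weights(query, key, mask=None):
  """Hypothetical helper: causal softmax weights plus the mask actually used.

  Inputs use (batch, length, heads, head_dim); weights are shaped
  (batch, heads, q_len, kv_len). Assumes top-left causal alignment.
  """
  q_len, kv_len = query.shape[1], key.shape[1]
  causal = jnp.tril(jnp.ones((q_len, kv_len), dtype=jnp.bool_))[None, None]
  full_mask = causal if mask is None else jnp.logical_and(mask, causal)
  logits = jnp.einsum('bqhd,bkhd->bhqk', query, key) * query.shape[-1] ** -0.5
  logits = jnp.where(full_mask, logits, jnp.finfo(logits.dtype).min)
  return jax.nn.softmax(logits, axis=-1), full_mask


class CausalAttentionTest(parameterized.TestCase):

  # One test body covers every scenario; shapes are fixed, not parameterized.
  # kv_len == q_len stands in for self-attention, kv_len != q_len for cross.
  @parameterized.named_parameters(
      ('self_attention', 4, False),
      ('self_attention_padding', 4, True),
      ('cross_attention', 6, False),
      ('cross_attention_padding', 6, True),
  )
  def test_causal_mask(self, kv_len, use_padding):
    batch, q_len, heads, depth = 2, 4, 2, 8
    rng_q, rng_k = jax.random.split(jax.random.PRNGKey(0))
    query = jax.random.normal(rng_q, (batch, q_len, heads, depth))
    key = jax.random.normal(rng_k, (batch, kv_len, heads, depth))
    mask = None
    if use_padding:
      # Key-padding mask: treat the last key position as padding.
      mask = (jnp.arange(kv_len) < kv_len - 1)[None, None, None, :]
    weights, full_mask = causal_attention_weights(query, key, mask)
    # Masked (future or padded) positions receive exactly zero weight...
    self.assertTrue(jnp.all(jnp.where(full_mask, 0.0, weights) == 0.0))
    # ...and each query row is still a probability distribution.
    self.assertTrue(jnp.allclose(weights.sum(axis=-1), 1.0, atol=1e-5))
    # Query 0 may only attend to key 0.
    self.assertAlmostEqual(float(weights[0, 0, 0, 0]), 1.0, places=5)
```

In a real test module this class would typically be followed by the usual `if __name__ == '__main__': absltest.main()` guard.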

vfdev-5 (Collaborator) commented Nov 18, 2025

@ibbyml thanks for the updates! Please squash all commits into one; otherwise CI will fail when num_commits >= 5.

vfdev-5 (Collaborator) commented Nov 18, 2025

@ibbyml let's keep only your updates

ibbyml (Contributor, Author) commented Nov 18, 2025


Sorry about that; I accidentally pulled in everything. The new push should be squashed correctly.

vfdev-5 (Collaborator) left a comment

LGTM, thanks @ibbyml !

Comment thread flax/nnx/nn/attention.py
copybara-service[bot] merged commit be1db78 into google:main on Dec 11, 2025
18 checks passed
ibbyml deleted the nnx-sdpa-parity branch on December 12, 2025 19:07

3 participants