
Conversation

alanwaketan
Collaborator

Summary:
This pull request adds ab (attention bias) support to flash_attention; ab acts as a custom mask applied to the attention weights.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_ab
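For context, a minimal sketch of how the new argument might be exercised from user code. The torch_xla.experimental.custom_kernel.flash_attention import path and the tensor shapes below are assumptions for illustration, not taken from this thread:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# (batch, num_heads, seq_len, head_dim) -- illustrative shapes only.
q = torch.randn(4, 2, 128, 8, device=device)
k = torch.randn(4, 2, 128, 8, device=device)
v = torch.randn(4, 2, 128, 8, device=device)

# ab is added to the attention logits before the softmax, so it can act as a
# custom (additive) mask over the attention weights; large negative entries
# effectively mask positions out.
ab = torch.randn(4, 2, 128, 128, device=device)

out = flash_attention(q, k, v, ab=ab)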

@alanwaketan alanwaketan self-assigned this Aug 13, 2024
@ZhiyuLi-goog
Contributor

Thank you @alanwaketan!

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test_flash_attention_ab(self):
  jax.config.update("jax_default_matmul_precision", "highest")
Collaborator

lol, we should really make a context manager that takes care of this in this test.
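For illustration, one possible shape for such a helper, assuming it simply saves and restores the jax_default_matmul_precision flag (hypothetical code, not part of this PR):

import contextlib

import jax


@contextlib.contextmanager
def highest_matmul_precision():
  # Force the highest matmul precision for the duration of the block,
  # then restore whatever value was set before.
  old = jax.config.jax_default_matmul_precision
  jax.config.update("jax_default_matmul_precision", "highest")
  try:
    yield
  finally:
    jax.config.update("jax_default_matmul_precision", old)

# Usage inside a test:
#   with highest_matmul_precision():
#     ...run the flash attention comparison against the reference...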

@JackCaoG JackCaoG added the tpuci label Aug 14, 2024
Collaborator

@JackCaoG JackCaoG left a comment

Can you rerun the CI? The TPU CI was not enabled, so the test was not run.

@JackCaoG
Collaborator

actually let me just trigger it...

@JackCaoG JackCaoG merged commit 21a0b5a into master Aug 14, 2024
@JackCaoG JackCaoG deleted the alanwaketan/flash_ab branch August 14, 2024 21:19
@alanwaketan
Collaborator Author

Thanks, Jack.
