Collaborator

@wonjoo-wj wonjoo-wj commented May 14, 2024

Support megacore_mode in paged_attention

JAX reference for megacore_mode: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L318
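For readers unfamiliar with the option, here is a minimal sketch of the kind of argument check a `megacore_mode` parameter implies. It assumes, based on my reading of the JAX reference kernel, that the accepted values are `None`, `"kv_head"`, and `"batch"`, and that the dimension split across the two cores must be even. `check_megacore_mode` is a hypothetical helper for illustration only and is not a function from this PR or from JAX.

```python
from typing import Optional

# Assumed set of accepted modes (None disables megacore partitioning).
VALID_MEGACORE_MODES = (None, "kv_head", "batch")


def check_megacore_mode(megacore_mode: Optional[str],
                        num_kv_heads: int,
                        batch_size: int) -> None:
    """Hypothetical validation helper; raises ValueError on a bad combination."""
    if megacore_mode not in VALID_MEGACORE_MODES:
        raise ValueError(
            f"megacore_mode must be one of {VALID_MEGACORE_MODES}, "
            f"got {megacore_mode!r}")
    # Assumption: work is split across two cores, so the split axis must be even.
    if megacore_mode == "kv_head" and num_kv_heads % 2 != 0:
        raise ValueError(
            "num_kv_heads must be even when megacore_mode is 'kv_head'")
    if megacore_mode == "batch" and batch_size % 2 != 0:
        raise ValueError(
            "batch_size must be even when megacore_mode is 'batch'")
```

Under these assumptions, `check_megacore_mode("kv_head", num_kv_heads=8, batch_size=3)` passes, while `check_megacore_mode("batch", num_kv_heads=8, batch_size=3)` raises `ValueError`.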

Test plan:

python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_megacore_modes

+ TPU CI

@wonjoo-wj
Collaborator Author

The test passes locally on my v4-8:

root@t1v-n-4989e8c7-w-0:~/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_megacore_modes
.
----------------------------------------------------------------------
Ran 1 test in 3.283s

OK
root@t1v-n-4989e8c7-w-0:~/pytorch/xla# 

I'll wait for TPU CI to verify the rest.

@wonjoo-wj wonjoo-wj requested review from alanwaketan and JackCaoG May 14, 2024 19:31
Collaborator

@alanwaketan alanwaketan left a comment

LGTM.

@wonjoo-wj
Collaborator Author

Thanks for the reviews, merging as all CIs are green.

@wonjoo-wj wonjoo-wj merged commit cbb9e21 into master May 14, 2024