
[SDPA-CUDNN] Make CuDNN Attention Opt in #138522

Closed
wants to merge 4 commits

Conversation

drisspg (Contributor) commented Oct 21, 2024

Stack from ghstack (oldest at bottom):

Summary

Currently, `cudnn_order` specifies that on H100, with a new enough cuDNN version (we ship 9.1 in OSS builds), the cuDNN attention backend is tried first. We have already encountered a few bugs since the release of 2.5:

  1. SDPA: CUDNN backend error w/ q_seq_len = 1 #138529 (a minimal sketch of this case follows the list)
  2. RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph. huggingface/diffusers#9704
  3. [cuDNN][SDPA] Match query's memory layout ordering for output in cuDNN SDPA #138354
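
For context, here is a minimal sketch of the q_seq_len = 1 case from the first issue above. The tensor shapes are illustrative assumptions based on the issue title, not a confirmed repro:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Hypothetical decode-style shapes: batch=2, heads=8, q_seq_len=1, head_dim=64.
q = torch.randn(2, 8, 1, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Forcing the cuDNN backend on a shape like this surfaced the error in #138529.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```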

In light of the above, we are making the cuDNN backend opt-in rather than enabled by default.

Opting in is easy with the context manager for choosing backends, e.g.:

```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# q, k, v are query/key/value tensors prepared elsewhere.
# Restrict SDPA to the cuDNN backend within this block.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR moves the cuDNN backend to the lowest precedence in the backend list, meaning the Math backend will always be chosen ahead of it unless it is explicitly disabled (which is what the context manager does).
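
To make the new behavior concrete, here is a minimal sketch (my illustration, not code from this PR) showing that `sdpa_kernel` also accepts a list of backends, so callers can allow cuDNN alongside the other fused backends rather than exclusively:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative inputs; the fused backends expect half precision on CUDA.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Enable cuDNN attention alongside flash and memory-efficient attention;
# SDPA dispatches to the highest-priority enabled backend that supports the inputs.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION,
                  SDPBackend.FLASH_ATTENTION,
                  SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
```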

Cc @atalman

cc @mikaylagawarecki

pytorch-bot (bot) commented Oct 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138522

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0357a75 with merge base 7786869:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request Oct 21, 2024 (ghstack-source-id: 58f70e7943348cd11908451eddc150a2c1b22cde)
drisspg added a commit that referenced this pull request Oct 21, 2024 (ghstack-source-id: 0864209b86022cefec2f8ad2029be0b4facb96f2)
drisspg added this to the 2.5.1 milestone Oct 22, 2024
drisspg (Contributor, Author) commented Oct 22, 2024

@pytorchbot merge

pytorch-bot added the ciflow/trunk label Oct 22, 2024
pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

drisspg added a commit that referenced this pull request Oct 22, 2024 (ghstack-source-id: 28290c32b53b82f8e49f67b44312a42435ad006b)
pytorchmergebot (Collaborator) commented:

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team (raised by workflow job)

drisspg added a commit that referenced this pull request Oct 22, 2024 (ghstack-source-id: 30a6d89d5c096b1da8bf95be136bd491d2bdf6de)
drisspg (Contributor, Author) commented Oct 22, 2024

@pytorchbot merge

pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator) commented:
This PR (#138522) was merged in 9a9a0ab but it is still open, likely due to a GitHub bug, so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra.

atalman (Contributor) commented Oct 22, 2024

@pytorchbot cherry-pick --onto release/2.5 -c critical

pytorchbot pushed a commit that referenced this pull request Oct 22, 2024

Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)
pytorchbot (Collaborator) commented:
Cherry picking #138522

The cherry-pick PR is at #138587; it is recommended to link a critical cherry-pick PR to an issue.

Details for Dev Infra team (raised by workflow job)

huydhn (Contributor) commented Oct 22, 2024

@drisspg I think this change is failing on Windows https://github.com/pytorch/pytorch/actions/runs/11464507611/job/31902166240#step:15:21915

test_transformers.py::TestSDPACudaOnlyCUDA::test_fused_sdp_choice_type_dense_cuda (GH job link, HUD commit link)

Could you help take a look?

drisspg (Contributor, Author) commented Oct 22, 2024

@huydhn can I forward-fix this? #138641

It should fix the failure, but I don't have a Windows machine to test on.

kit1980 pushed a commit that referenced this pull request Oct 22, 2024

[SDPA-CUDNN] Make CuDNN Attention Opt in (#138522)

Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)

Co-authored-by: drisspg <drisspguessous@gmail.com>
SamGinzburg pushed a commit that referenced this pull request Oct 28, 2024

Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet
Labels: ciflow/trunk, Merged, module: multi-headed-attention, topic: not user facing