[Performance] [CuDNN-Attention] CuDNN backend should return the output in the same stride order as input Query #138340
Here is a trace of the exact striding behavior:
Forward
$1: f16[512, 1536] strides=[1, 512] = aten.t.default($0)
$3: f16[2048, 512] strides=[512, 1] = aten.view.default($2, [2048, 512])
$4: f16[2048, 1536] strides=[1536, 1] = aten.mm.default($3, $1)
$5: f16[1, 2048, 1536] strides=[3145728, 1536, 1] = aten._unsafe_view.default($4, [1, 2048, 1536])
[$6: f16[1, 2048, 512] strides=[3145728, 1536, 1], $7: f16[1, 2048, 512] strides=[3145728, 1536, 1], $8: f16[1, 2048, 512] strides=[3145728, 1536, 1]] = aten.split.Tensor($5, 512, 2)
$9: f16[1, 2048, 8, 64] strides=[3145728, 1536, 64, 1] = aten.view.default($7, [1, 2048, 8, 64])
$10: f16[1, 8, 2048, 64] strides=[3145728, 64, 1536, 1] = aten.transpose.int($9, 1, 2)
$11: f16[1, 2048, 8, 64] strides=[3145728, 1536, 64, 1] = aten.view.default($6, [1, 2048, 8, 64])
$12: f16[1, 8, 2048, 64] strides=[3145728, 64, 1536, 1] = aten.transpose.int($11, 1, 2)
$13: f16[1, 2048, 8, 64] strides=[3145728, 1536, 64, 1] = aten.view.default($8, [1, 2048, 8, 64])
$14: f16[1, 8, 2048, 64] strides=[3145728, 64, 1536, 1] = aten.transpose.int($13, 1, 2)
($15: f16[1, 8, 2048, 64] strides=[1048576, 131072, 64, 1], $16: f32[1, 8, 2048] strides=[16384, 2048, 1], None, None, 2048, 2048, $17: i64[] strides=[], $18: i64[] strides=[], None) = aten._scaled_dot_product_cudnn_attention.default($12, $10, $14, None, True, 0.0, True)
$19: f16[1, 8, 2048, 64] strides=[1048576, 131072, 64, 1] = aten.detach.default($15)
$20: f16[1, 2048, 8, 64] strides=[1048576, 64, 131072, 1] = aten.transpose.int($15, 1, 2)
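For reference, a minimal sketch of the pattern the forward trace above corresponds to (shapes and dtypes taken from the $-numbered ops; which SDPA backend actually runs depends on the build and GPU):

```python
import torch

B, S, H, D = 1, 2048, 8, 64
E = H * D  # 512

x = torch.randn(B, S, E, device="cuda", dtype=torch.float16)
c_attn = torch.nn.Linear(E, 3 * E, bias=False, device="cuda", dtype=torch.float16)

qkv = c_attn(x)                          # [1, 2048, 1536], strides (3145728, 1536, 1)
q, k, v = qkv.split(E, dim=2)            # each [1, 2048, 512], still strided over the 1536-wide row
q, k, v = (t.view(B, S, H, D).transpose(1, 2) for t in (q, k, v))
print(q.stride())                        # (3145728, 64, 1536, 1) -- the permuted layout fed to SDPA

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.stride())                      # per the trace, the cuDNN backend returns BHSD-contiguous
                                         # (1048576, 131072, 64, 1) instead of the input's stride order
```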
Relevant Backward Striding
$28: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.ones_like.default($27, pin_memory=False)
$29: f16[2048, 512] strides=[512, 1] = aten.view.default($28, [2048, 512])
$30: f16[512, 2048] strides=[1, 512] = aten.t.default($29)
$31: f16[512, 512] strides=[512, 1] = aten.mm.default($30, $25)
$32: f16[512, 512] strides=[1, 512] = aten.t.default($31)
$33: f16[512, 512] strides=[512, 1] = aten.t.default($24)
$34: f16[2048, 512] strides=[512, 1] = aten.mm.default($29, $33)
$35: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($34, [1, 2048, 512])
$36: f16[512, 512] strides=[512, 1] = aten.t.default($32)
$37: f16[512, 512] strides=[512, 1] = aten.detach.default($36)
$38: f16[512, 512] strides=[512, 1] = aten.detach.default($37)
$39: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.view.default($35, [1, 2048, 8, 64])
$40: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1] = aten.transpose.int($39, 1, 2)
$41: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1] = aten.detach.default($20)
($42: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1], $43: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1], $44: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1]) = aten._scaled_dot_product_flash_attention_backward.default($40, $12, $10, $14, $41, $16, None, None, 2048, 2048, 0.0, True, $17, $18, scale=0.125)
$45: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.transpose.int($44, 1, 2)
$46: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($45, [1, 2048, 512])
$47: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.transpose.int($42, 1, 2)
$48: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($47, [1, 2048, 512])
$49: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.transpose.int($43, 1, 2)
$50: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($49, [1, 2048, 512])
$51: f16[1, 2048, 1536] strides=[3145728, 1536, 1] = aten.cat.default([$48, $50, $46], 2)
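A small sketch of the invariant the issue title asks for: the SDPA output should come back in the same stride order as the query so the usual transpose + view epilogue stays a free view. The `stride_order` helper below is illustrative, not an existing PyTorch API; on the affected cuDNN versions the assert is expected to fail.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

q, k, v = (torch.randn(1, 2048, 8, 64, device="cuda", dtype=torch.float16).transpose(1, 2)
           for _ in range(3))
out = sdpa(q, k, v, is_causal=True)   # backend chosen by the default dispatcher

def stride_order(t):
    # dims sorted from largest to smallest stride, e.g. (0, 2, 1, 3) for BSHD storage viewed as BHSD
    return tuple(sorted(range(t.dim()), key=lambda i: -t.stride(i)))

# Fails when the backend returns a BHSD-contiguous output for the permuted input.
assert stride_order(out) == stride_order(q), (stride_order(q), stride_order(out))
```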
@ngimel suggests that we disable cuDNN until both the forward and backward ops can handle the permuted case.
Not necessarily disable it, but at least not make it the default (it emits a lot of warnings, and I don't know what the perf implications of those contiguous calls are).
Sure, will add to #138354.
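For context, the existing global opt-out looks like this (a sketch, assuming a PyTorch build that has the torch.backends.cuda SDP toggles, 2.4+): a user can keep cuDNN SDPA out of the selection without the backend being removed.

```python
import torch

# Leave flash / memory-efficient / math SDPA enabled, but take cuDNN out of the running.
torch.backends.cuda.enable_cudnn_sdp(False)
print(torch.backends.cuda.cudnn_sdp_enabled())   # False
```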
How did we miss it during all the discussions? And is it still faster than the old one?
@malfet we don't have any H100 runners in CI/CD. This would only show up in E2E torchbench-like testing, while most of the performance testing was done at the per-op level.
I have been testing this on a modified version of nanoGPT, built against cuDNN 9.4 w/ CUDA 12.4.
Trace on Nightly: (screenshot elided)
Trace on #138354: (screenshot elided)
Findings
(kernel comparison screenshots elided)
BUT
For some reason, at least w/ my setup, the forward kernel is 4x slower when writing to transposed output.
Forward
PyTorch: (kernel timings elided)
Backward
The backwards case is showing around the same speed. PyTorch: (kernel timings elided)
How long is the added contiguous call?
I have a slight suspicion that the kernel choice is only due to the cuDNN version difference in the setups (9.4 vs. 9.1.0.70 in nightlies), as this is what I saw in my local testing, so it might be interesting to see what in-between nightly builds produce if we can merge this. The 4x slowdown is baffling though, as it didn't seem that bad on my setup...
@eqy Did your cudnn_frontend version match nightly? |
@ngimel The contiguous on 2048 seq-len takes around: (timing elided)
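For reference, a micro-benchmark sketch of how that number can be measured in isolation (sizes match the trace above; the figure itself was attached as a screenshot and is not reproduced here):

```python
import torch
from torch.utils import benchmark

# Permuted [1, 8, 2048, 64] fp16 tensor, as produced by the BSHD -> BHSD transpose.
g = torch.randn(1, 2048, 8, 64, device="cuda", dtype=torch.float16).transpose(1, 2)

t = benchmark.Timer(stmt="g.contiguous()", globals={"g": g})
print(t.timeit(100))   # time for the layout-restoring copy kernel
```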
Update: I think that something was wrong with my local cuDNN.
CuDNN on PR: (profile elided)
So original invocation + overhead from the added contiguous?
@eqy I'm curious, is there a fundamental reason cuDNN is faster on contiguous inputs? Permuted input is the standard implementation, so it's pretty strange cuDNN didn't optimize for it.
The cuDNN team seems surprised by that claim, as they said BSHD is the common path as well. From 9.5: (quoted excerpt elided)
Is the change surfacing some kind of CPU overhead? |
Note how the max kernel time is much higher on 9.5. @drisspg's profiling results were also measuring kernel time (as reported by the profiler), not CPU overhead.
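A sketch of how kernel time vs. CPU overhead can be separated with the profiler (the numbers quoted in this thread are kernel times):

```python
import torch
from torch.profiler import profile, ProfilerActivity
from torch.nn.functional import scaled_dot_product_attention as sdpa

q, k, v = (torch.randn(1, 2048, 8, 64, device="cuda", dtype=torch.float16).transpose(1, 2)
           for _ in range(3))

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    sdpa(q, k, v, is_causal=True)
    torch.cuda.synchronize()

# CUDA time column = kernel time; CPU time column = host-side overhead.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```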
Removing 2.5.1 milestone. |
Great news: the new cuDNN PyPI binary that fixes the stride bug has been released. We can update if the manylinux CD upgrade is ready.
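A quick way to confirm which cuDNN an installed wheel actually loads (sketch):

```python
import torch

# Prints the integer-encoded version of the cuDNN library PyTorch loaded,
# so you can verify the fixed PyPI binary is the one being picked up.
print(torch.backends.cudnn.version())
```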
Summary
This can have a large performance impact in real attention modules.
The most common pattern (derived from nanoGPT):
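The repro snippet itself did not survive this copy; a hedged reconstruction of the nanoGPT-style pattern it refers to (fused QKV projection → split → transpose → SDPA → transpose back → view), looping over the two backends, might look like this (assumes PyTorch 2.5+ with the cuDNN backend available):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention as sdpa

B, S, H, D = 1, 2048, 8, 64
E = H * D
x = torch.randn(B, S, E, device="cuda", dtype=torch.float16)
c_attn = torch.nn.Linear(E, 3 * E, bias=False, device="cuda", dtype=torch.float16)

for name, backend in [("SDPA-Flash", SDPBackend.FLASH_ATTENTION),
                      ("SDPA-CuDNN", SDPBackend.CUDNN_ATTENTION)]:
    print(f"{name:-^100}")
    q, k, v = c_attn(x).split(E, dim=2)
    q, k, v = (t.view(B, S, H, D).transpose(1, 2) for t in (q, k, v))
    with sdpa_kernel([backend]):
        out = sdpa(q, k, v, is_causal=True)
    try:
        out.transpose(1, 2).view(B, S, E)   # nanoGPT re-assembles heads with a view
        print("ALL GOOD")
    except RuntimeError as err:
        print("❗ NOT GOOD ❗", err)
```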
Output
---------------------------------------------SDPA-Flash---------------------------------------------
ALL GOOD
---------------------------------------------SDPA-CuDNN---------------------------------------------
❗ NOT GOOD ❗
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
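The view failure is purely a stride effect and can be shown without a GPU (a CPU-only sketch; the two tensors below stand in for the cuDNN-style and flash-style outputs):

```python
import torch

# Stand-in for the cuDNN output: contiguous in [B, H, S, D].
cudnn_like = torch.randn(1, 8, 2048, 64)
print(cudnn_like.transpose(1, 2).stride())            # (1048576, 64, 131072, 1)
try:
    cudnn_like.transpose(1, 2).view(1, 2048, 512)     # heads are no longer adjacent in memory
except RuntimeError as err:
    print(err)                                        # the "use .reshape(...)" error above

# Stand-in for the flash output: BSHD storage viewed as [B, H, S, D], matching the input.
flash_like = torch.randn(1, 2048, 8, 64).transpose(1, 2)
print(flash_like.transpose(1, 2).view(1, 2048, 512).shape)   # works: torch.Size([1, 2048, 512])
```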
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @csarofeen @ptrblck @xwang233 @mikaylagawarecki