
[Performance] [CuDNN-Attention] CuDNN backend should return the output in the same stride order as input Query #138340

Open
drisspg opened this issue Oct 18, 2024 · 21 comments
Labels
high priority · module: cudnn (Related to torch.backends.cudnn and cuDNN support) · module: multi-headed-attention · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Comments

@drisspg (Contributor) commented Oct 18, 2024

Summary

The cuDNN SDPA backend returns its output contiguous in (B, H, S, D) instead of matching the stride order of the input query. This can have a large performance impact in real attention modules, because the transpose + view that normally follows SDPA then requires an extra copy.

The most common pattern (derived from nanoGPT):

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

from torch.nn.attention import sdpa_kernel, SDPBackend

@dataclass
class Config:
    n_embd: int = 512
    n_head: int = 8
    n_layer: int = 6
    n_ctx: int = 2048
    bias: bool = False

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

        # HERE, WE NEED THIS CONTIGUOUS TO BE A NO-OP
        # y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = y.transpose(1, 2).view(B, T, C)
        y = self.c_proj(y)
        return y

def test_attention(backend: SDPBackend):
    config = Config()
    Attention = CausalSelfAttention(config).to("cuda", dtype=torch.float16)
    sample_input = torch.randn(1, 2048, config.n_embd, device="cuda", dtype = torch.float16)
    with sdpa_kernel(backend):
        try:
            out = Attention(sample_input)
            print("ALL GOOD")
        except RuntimeError as e:
            print("❗ NOT GOOD ❗")
            print(e)

if __name__ == "__main__":
    width = 100
    print("SDPA-Flash".center(width, "-"))
    test_attention(SDPBackend.FLASH_ATTENTION)
    print("SDPA-CuDNN".center(width, "-"))
    test_attention(SDPBackend.CUDNN_ATTENTION)

Output

---------------------------------------------SDPA-Flash---------------------------------------------
ALL GOOD
---------------------------------------------SDPA-CuDNN---------------------------------------------
❗ NOT GOOD ❗
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
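For reference, the stride mismatch can be inspected directly. The following is a minimal sketch (shapes chosen to match the repro above; it assumes a CUDA build where both backends are available) that prints the output strides each backend produces for permuted inputs:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, S, D = 1, 8, 2048, 64
# Build q/k/v the way the module above does: projected in (B, S, H*D) and
# transposed to (B, H, S, D), so the inputs are permuted/non-contiguous.
qkv = torch.randn(B, S, 3 * H * D, device="cuda", dtype=torch.float16)
q, k, v = [t.view(B, S, H, D).transpose(1, 2) for t in qkv.split(H * D, dim=2)]

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION):
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Flash returns the output in the query's (B, S, H, D) memory order, so
    # transpose(1, 2) is contiguous and view() is free; the cuDNN backend (as
    # reported in this issue) returns a contiguous (B, H, S, D) tensor instead.
    print(backend, out.stride(), out.transpose(1, 2).is_contiguous())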

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @csarofeen @ptrblck @xwang233 @mikaylagawarecki

drisspg added the high priority and module: cudnn labels Oct 18, 2024
@Skylion007 (Collaborator) commented:

FYI @eqy @nWEIdia

drisspg added this to the 2.5.1 milestone Oct 18, 2024
@drisspg (Contributor, Author) commented Oct 18, 2024

Here is a trace of the exact striding behavior:

Forward

$1: f16[512, 1536] strides=[1, 512] = aten.t.default($0)
$3: f16[2048, 512] strides=[512, 1] = aten.view.default($2, [2048, 512])
$4: f16[2048, 1536] strides=[1536, 1] = aten.mm.default($3, $1)
$5: f16[1, 2048, 1536] strides=[3145728, 1536, 1] = aten._unsafe_view.default($4, [1, 2048, 1536])
[$6: f16[1, 2048, 512] strides=[3145728, 1536, 1], $7: f16[1, 2048, 512] strides=[3145728, 1536, 1], $8: f16[1, 2048, 512] strides=[3145728, 1536, 1]] = aten.split.Tensor($5, 512, 2)
$9: f16[1, 2048, 8, 64] strides=[3145728, 1536, 64, 1] = aten.view.default($7, [1, 2048, 8, 64])
$10: f16[1, 8, 2048, 64] strides=[3145728, 64, 1536, 1] = aten.transpose.int($9, 1, 2)
$11: f16[1, 2048, 8, 64] strides=[3145728, 1536, 64, 1] = aten.view.default($6, [1, 2048, 8, 64])
$12: f16[1, 8, 2048, 64] strides=[3145728, 64, 1536, 1] = aten.transpose.int($11, 1, 2)
$13: f16[1, 2048, 8, 64] strides=[3145728, 1536, 64, 1] = aten.view.default($8, [1, 2048, 8, 64])
$14: f16[1, 8, 2048, 64] strides=[3145728, 64, 1536, 1] = aten.transpose.int($13, 1, 2)
($15: f16[1, 8, 2048, 64] strides=[1048576, 131072, 64, 1], $16: f32[1, 8, 2048] strides=[16384, 2048, 1], None, None, 2048, 2048, $17: i64[] strides=[], $18: i64[] strides=[], None) = aten._scaled_dot_product_cudnn_attention.default($12, $10, $14, None, True, 0.0, True)
$19: f16[1, 8, 2048, 64] strides=[1048576, 131072, 64, 1] = aten.detach.default($15)
$20: f16[1, 2048, 8, 64] strides=[1048576, 64, 131072, 1] = aten.transpose.int($15, 1, 2)
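To make the trace concrete: $15 comes back fully contiguous in (B, H, S, D), so the transpose at $20 yields strides that view cannot collapse. A small sketch of just the stride arithmetic (no attention math; CPU tensors are enough):

import torch

B, H, S, D = 1, 8, 2048, 64

# What the cuDNN backend returns ($15): contiguous in (B, H, S, D).
cudnn_out = torch.empty(B, H, S, D, dtype=torch.float16)
print(cudnn_out.stride())          # (1048576, 131072, 64, 1)
t = cudnn_out.transpose(1, 2)      # shape (B, S, H, D), as in $20
print(t.stride())                  # (1048576, 64, 131072, 1)
# Merging (H, D) into one dim of size 512 needs stride(H) == D * stride(D) == 64,
# but it is 131072, so view() raises and a .contiguous() copy is required.
try:
    t.view(B, S, H * D)
except RuntimeError as e:
    print("view failed:", e)

# What flash attention returns: shape (B, H, S, D) laid out in (B, S, H, D)
# memory order, i.e. the same stride order as the query. Transpose + view is free.
flash_like = torch.empty(B, S, H, D, dtype=torch.float16).transpose(1, 2)
print(flash_like.stride())         # (1048576, 64, 512, 1)
print(flash_like.transpose(1, 2).view(B, S, H * D).shape)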

@drisspg (Contributor, Author) commented Oct 18, 2024

Relevant Backward Striding

$28: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.ones_like.default($27, pin_memory=False)
$29: f16[2048, 512] strides=[512, 1] = aten.view.default($28, [2048, 512])
$30: f16[512, 2048] strides=[1, 512] = aten.t.default($29)
$31: f16[512, 512] strides=[512, 1] = aten.mm.default($30, $25)
$32: f16[512, 512] strides=[1, 512] = aten.t.default($31)
$33: f16[512, 512] strides=[512, 1] = aten.t.default($24)
$34: f16[2048, 512] strides=[512, 1] = aten.mm.default($29, $33)
$35: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($34, [1, 2048, 512])
$36: f16[512, 512] strides=[512, 1] = aten.t.default($32)
$37: f16[512, 512] strides=[512, 1] = aten.detach.default($36)
$38: f16[512, 512] strides=[512, 1] = aten.detach.default($37)
$39: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.view.default($35, [1, 2048, 8, 64])
$40: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1] = aten.transpose.int($39, 1, 2)
$41: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1] = aten.detach.default($20)
($42: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1], $43: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1], $44: f16[1, 8, 2048, 64] strides=[1048576, 64, 512, 1]) = aten._scaled_dot_product_flash_attention_backward.default($40, $12, $10, $14, $41, $16, None, None, 2048, 2048, 0.0, True, $17, $18, scale=0.125)
$45: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.transpose.int($44, 1, 2)
$46: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($45, [1, 2048, 512])
$47: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.transpose.int($42, 1, 2)
$48: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($47, [1, 2048, 512])
$49: f16[1, 2048, 8, 64] strides=[1048576, 512, 64, 1] = aten.transpose.int($43, 1, 2)
$50: f16[1, 2048, 512] strides=[1048576, 512, 1] = aten.view.default($49, [1, 2048, 512])
$51: f16[1, 2048, 1536] strides=[3145728, 1536, 1] = aten.cat.default([$48, $50, $46], 2)

@drisspg (Contributor, Author) commented Oct 18, 2024

@ngimel suggests that we disable cuDNN until both the forward and backward ops can handle the permuted case.

@ngimel (Collaborator) commented Oct 18, 2024

Not necessarily disable it, but at least don't make it the default (it emits a lot of warnings, and I don't know what the perf implications of those contiguous calls are).
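Until that is sorted out, a user-side mitigation (a sketch, not a fix; assumes a PyTorch build where torch.backends.cuda.enable_cudnn_sdp is available) is to keep cuDNN out of the SDPA backend selection, either globally or per call site:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Global: drop cuDNN from the SDPA backend selection.
torch.backends.cuda.enable_cudnn_sdp(False)

# Or per call site: restrict the allowed backends for this region only.
q, k, v = (torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16) for _ in range(3))
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)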

@eqy (Collaborator) commented Oct 18, 2024

Sure, will add to #138354

@malfet (Contributor) commented Oct 18, 2024

How did we miss this during all the discussions? And is it still faster than the old one?

@drisspg (Contributor, Author) commented Oct 18, 2024

@malfet We don't have any H100 runners in CI/CD. This would only show up in E2E, torchbench-like testing, while most of the performance testing was done at the per-op level.

@drisspg (Contributor, Author) commented Oct 19, 2024

I have been testing this on a modified version of nanoGPT:
drisspg/nanoGPT#1

Trace on nightly:
CUDNN_ATTENTION_nightly.json

Trace on #138354, built against cuDNN 9.4 w/ CUDA 12.4:
CUDNN_ATTENTION_dev.json

Findings

  1. The good news: I verified that this PR does indeed remove the contiguous call in the reshaping after SDPA.

PyTorch: #138354
(trace screenshot)

vs.

PyTorch: nightly
(trace screenshot)

The contiguous call in the backward pass has also been removed:

PyTorch: #138354
(trace screenshot)

vs.

PyTorch: nightly
(trace screenshot)

BUT

For some reason, at least with my setup, the forward kernel is 4x slower when writing to the transposed output.

Forward

PyTorch: #138354, cuDNN 9.4 w/ CUDA 12.4, built locally:
Kernel name: cudnn_generated_fort_native_sdpa_sm90_flash_fprop_wgmma_f16_knob_31_64x256x64_4x1x1_kernel0_0
(profiler screenshot)

vs.

PyTorch: nightly, CUDA 12.4:
Kernel name: cudnn_generated_fort_native_sdpa_sm90_knob_7_64x128x64_4x1x1_kernel0_0
(profiler screenshot)

Backward

The backward case shows around the same speed.

PyTorch: #138354, cuDNN 9.4 w/ CUDA 12.4, built locally
Kernel name:
cudnn_generated_fort_native_sdpa_sm90_flash_bprop_wgmma_f16_knob_26_64x64x64_1x4x1_kernel0_0

PyTorch: nightly, CUDA 12.4
Kernel name:
cudnn_generated_fort_native_sdpa_sm90_knob_26_64x64x64_1x4x1_kernel0_0
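For anyone trying to reproduce the kernel-name comparison, a minimal profiling sketch (shapes are illustrative; the kernel names reported will depend on the cuDNN version and GPU):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.profiler import profile, ProfilerActivity

B, H, S, D = 1, 8, 2048, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out.sum().backward()

# The cuDNN fprop/bprop kernels (cudnn_generated_fort_native_sdpa_*) show up
# here along with their CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))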

@ngimel (Collaborator) commented Oct 19, 2024

How long does the added contiguous call take?

@eqy (Collaborator) commented Oct 19, 2024

I have a slight suspicion that the kernel choice is only due to the cuDNN version difference between the setups (9.4 vs. 9.1.0.70 in nightlies), as this is what I saw in my local testing. It might be interesting to see what in-between nightly builds produce if we can merge this.

The 4x slowdown is baffling, though, as it didn't seem that bad on my setup...

@Skylion007 (Collaborator) commented Oct 19, 2024

Per the 9.4.0 changelog: "Function names of attention kernels have been enhanced with more details on instruction and kernel type. For example, cudnn_generated_fort_native_sdpa_sm80_flash_fprop_wmma_f16_knob_32_64x64x64_4x1x1_kernel0_0." It could be the exact same kernel under a new name.

@Skylion007 (Collaborator) commented Oct 19, 2024

> I have a slight suspicion that the kernel choice is only due to the cuDNN version difference between the setups (9.4 vs. 9.1.0.70 in nightlies) [...] The 4x slowdown is baffling, though, as it didn't seem that bad on my setup...

@eqy Did your cudnn_frontend version match nightly?

malfet added the triaged label and removed the triage review label Oct 21, 2024

@drisspg (Contributor, Author) commented Oct 21, 2024

Update: I think something was wrong with my local cuDNN.
New settings:
cuDNN: 12-9.1.1.17
CUDA toolkit: 12-4
Sequence length: 2048

cuDNN on PR: iter 50: loss 2.5125, time 83.60ms, mfu 60.58%
cuDNN on nightly: iter 70: loss 2.4963, time 85.42ms, mfu 59.38%
FAv2 on nightly: iter 50: loss 2.5133, time 85.20ms, mfu 59.51%

Trace:
https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/drisspg_a6d15b1b-402b-42d0-a6a6-a7f686ccb76a_CUDNN_ATTENTION_dev.json

@ngimel (Collaborator) commented Oct 21, 2024

So the original invocation plus the overhead of the contiguous call is still faster than invoking with correctly strided output? I see ~1.9 ms in the new trace vs. supposedly 1.3 ms + 150 us. Should we then always call contiguous and remove the warning? Should we benchmark a wider variety of params: different seq lens, different head sizes, etc.?
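A sweep along those lines could look like the sketch below (illustrative parameter grid; it uses torch.utils.benchmark and only times the forward call, with inputs in the permuted nanoGPT layout):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.utils import benchmark

def time_sdpa(backend, B, H, S, D):
    # Allocate as (B, S, H, D) and transpose to (B, H, S, D), matching the repro.
    q, k, v = (torch.randn(B, S, H, D, device="cuda", dtype=torch.float16).transpose(1, 2)
               for _ in range(3))
    with sdpa_kernel(backend):
        timer = benchmark.Timer(
            stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True)",
            globals={"F": F, "q": q, "k": k, "v": v},
        )
        return timer.timeit(50).median

for S in (1024, 2048, 4096):
    for D in (64, 128):
        for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION):
            ms = time_sdpa(backend, B=4, H=16, S=S, D=D) * 1e3
            print(f"S={S:5d} D={D:3d} {backend}: {ms:.3f} ms")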

@ngimel (Collaborator) commented Oct 21, 2024

@eqy I'm curious: is there a fundamental reason cuDNN is faster on contiguous inputs? Permuted input is the standard implementation pattern, so it's pretty strange that cuDNN didn't optimize for it.

@eqy (Collaborator) commented Oct 22, 2024

The cuDNN team seems surprised by that claim, as they said BSHD is the common path as well.
I think there's a regression (on both contiguous and non-contiguous inputs) that happened sometime between 9.1.1 and 9.5 (maybe by 9.4, according to Driss's results).

Cross-posting data from a power-limited H100 (PCIe).

From 9.1.1:
iter 0: loss 4.2694, time 70112.43ms, mfu -100.00% (no patch)
iter 0: loss 4.2694, time 56938.34ms, mfu -100.00% (with this patch)

nsys nvprof kernel summaries, in respective order (Time %, total time ns, instances, avg, median, min, max, stddev, kernel name):
5.6 87523222 2424 36106.9 35872.0 34624 60768 2357.8 cudnn_generated_fort_native_sdpa_sm90_knob_7_64x128x64_4x1x1_kernel0_0
5.8 87536234 2424 36112.3 35872.0 34464 59936 2320.5 cudnn_generated_fort_native_sdpa_sm90_knob_7_64x128x64_4x1x1_kernel0_0

From 9.5:
iter 0: loss 4.2694, time 69078.68ms, mfu -100.00%
iter 0: loss 4.2694, time 70822.93ms, mfu -100.00%

6.0 93453883 2424 38553.6 37121.0 36096 181634 14053.9 cudnn_generated_fort_native_sdpa_sm90_flash_fprop_wgmma_f16_knob_31_64x256x64_4x1x1_kernel0_0
6.2 93363033 2424 38516.1 37089.0 35904 181410 14084.5 cudnn_generated_fort_native_sdpa_sm90_flash_fprop_wgmma_f16_knob_31_64x256x64_4x1x1_kernel0_0

But I also seem to be getting a CUTLASS kernel (1.1 16153656 404 39984.3 39968.0 39168 41088 337.4 void cutlass::Kernel2<cutlass_75_tensorop_bf16_s1688gemm_bf16_128x128_tn_align1>(T1::Params)) throwing a wrench into the expected perf.

Is the change surfacing some kind of CPU overhead?

@ngimel (Collaborator) commented Oct 22, 2024

Note how the max kernel time is much higher on 9.5. @drisspg's profiling results were also measuring kernel time (as reported by the profiler), not CPU overhead.

drisspg modified the milestone: 2.5.1 Oct 25, 2024
kit1980 modified the milestones: 2.5.1, 2.6.0 Oct 25, 2024
@kit1980 (Member) commented Oct 25, 2024

Removing the 2.5.1 milestone.
For 2.5.1, the issue was mitigated by #138522.

@Skylion007 (Collaborator) commented:

Great news: the new cuDNN PyPI binary that fixes the stride bug has been released. We can update if the manylinux CD upgrade is ready.
