4x performance regression for 3D convs with AMP on torch 2.9.0 #166122

@FabianIsensee

🐛 Describe the bug

Relative to torch 2.8.0, I see a performance regression when using 3D convolutions with AMP: torch 2.9.0 is about 4x slower. This is confirmed both with nnU-Net and with the small standalone benchmark below:

# 3D Convolution Micro-Benchmark (PyTorch) — FP16 AMP fix
import time
import torch as th
from itertools import product

def bench_conv3d(
    device="cuda" if th.cuda.is_available() else "cpu",
    dtype=th.float32,            # used when amp=False
    batch_size=1,
    sizes=((16, 32), (32, 64)),
    kernels=(3, 5),
    iters=20,
    warmup=10,
    amp=False,                   # autocast+GradScaler (GPU)
    require_grad=True,
):
    th.manual_seed(0)
    is_cuda = device.startswith("cuda")
    if is_cuda:
        th.backends.cudnn.benchmark = True

    results = []
    for (cin, cout), k in product(sizes, kernels):
        # Model dtype:
        # - with amp=True -> params MUST stay fp32
        # - with amp=False -> params follow `dtype`
        param_dtype = th.float32 if amp else dtype

        conv = th.nn.Conv3d(cin, cout, kernel_size=k, padding=k // 2, bias=True).to(device=device, dtype=param_dtype)

        # Input dtype:
        # - with amp=True -> fp32 input; autocast will cast to fp16/bf16
        # - with amp=False -> match `dtype`
        x_dtype = th.float32 if amp else dtype
        x = th.randn(batch_size, cin, 128, 128, 128, device=device, dtype=x_dtype, requires_grad=require_grad)

        opt = th.optim.SGD(conv.parameters(), lr=1e-3) if require_grad else None
        use_scaler = amp and is_cuda and require_grad
        scaler = th.cuda.amp.GradScaler(enabled=use_scaler)
        autocast_ctx = th.cuda.amp.autocast(enabled=amp and is_cuda)

        def loss_fn(y): return y.square().mean()

        # Warmup
        for _ in range(warmup):
            with autocast_ctx:
                y = conv(x)
                if require_grad:
                    loss = loss_fn(y)
            if require_grad:
                opt.zero_grad(set_to_none=True)
                if use_scaler:
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()
                else:
                    loss.backward()
                    opt.step()
            if is_cuda:
                th.cuda.synchronize()

        # Timed runs (note: bwd_ms below also includes the loss computation and the optimizer step, not just backward)
        fwd_ms = 0.0
        bwd_ms = 0.0
        for _ in range(iters):
            t0 = time.perf_counter()
            with autocast_ctx:
                y = conv(x)
            if is_cuda:
                th.cuda.synchronize()
            t1 = time.perf_counter()

            if require_grad:
                with autocast_ctx:
                    loss = loss_fn(y)
                opt.zero_grad(set_to_none=True)
                if use_scaler:
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()
                else:
                    loss.backward()
                    opt.step()
                if is_cuda:
                    th.cuda.synchronize()
            t2 = time.perf_counter()

            fwd_ms += (t1 - t0) * 1000
            bwd_ms += (t2 - t1) * 1000

        fwd_ms /= iters
        bwd_ms /= iters
        total_ms = fwd_ms + (bwd_ms if require_grad else 0.0)

        vox = batch_size * cout * 128 * 128 * 128
        vox_per_s = (vox / (fwd_ms / 1000.0)) if fwd_ms > 0 else float("nan")

        results.append({
            "device": device,
            "dtype": str(dtype).split(".")[-1] if not amp else "fp16(AMP)-params:fp32",
            "amp": amp,
            "N": batch_size,
            "Cin": cin,
            "Cout": cout,
            "k": k,
            "fwd_ms": round(fwd_ms, 2),
            "bwd_ms": round(bwd_ms, 2),
            "total_ms": round(total_ms, 2),
            "vox/s (fwd)": int(vox_per_s),
        })

    # Pretty print
    hdr = f"{'dev':<5} {'dtype':<20} {'amp':<4} {'N':<3} {'Cin':<4} {'Cout':<5} {'k':<2} {'fwd(ms)':<9} {'bwd(ms)':<9} {'total(ms)':<10} {'vox/s(fwd)':<12}"
    print(hdr)
    print("-" * len(hdr))
    for r in results:
        print(f"{r['device']:<5} {r['dtype']:<20} {str(r['amp']):<4} {r['N']:<3} {r['Cin']:<4} {r['Cout']:<5} {r['k']:<2} "
              f"{r['fwd_ms']:<9} {r['bwd_ms']:<9} {r['total_ms']:<10} {r['vox/s (fwd)']:<12}")

if __name__ == "__main__":
    if th.cuda.is_available():
        bench_conv3d(device="cuda", dtype=th.float32, amp=False)      # FP32
        print()
        bench_conv3d(device="cuda", dtype=th.float32, amp=True)       # AMP (params FP32)
    print()
    bench_conv3d(device="cpu", dtype=th.float32, amp=False)           # CPU FP32

torch 2.8.0:

[screenshot of benchmark output]

torch 2.9.0:

[screenshot of benchmark output]
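
To help narrow this down, here is a minimal profiling sketch (assuming a CUDA device; the shapes mirror the benchmark above) that prints which CUDA kernels the 3D conv dispatches to under AMP. Comparing the kernel names and times between the 2.8.0 and 2.9.0 envs should show whether a different cuDNN engine/algorithm is being selected:

import torch as th
from torch.profiler import profile, ProfilerActivity

print("cuDNN version:", th.backends.cudnn.version())

conv = th.nn.Conv3d(16, 32, kernel_size=3, padding=1, bias=True).cuda()
x = th.randn(1, 16, 128, 128, 128, device="cuda", requires_grad=True)

# One AMP forward/backward pass under the profiler; the kernel table may
# reveal whether 2.9.0 selects a slower conv3d algorithm than 2.8.0.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with th.autocast("cuda"):
        y = conv(x)
        loss = y.square().mean()
    loss.backward()
    th.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))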

I ran this on an RTX 4090 using fresh conda envs for torch 2.8.0 and 2.9.0, on Ubuntu 24.04.3 LTS.
Best,
Fabian

Versions

Collecting environment information...
PyTorch version: 2.9.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.14.0-33-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 570.195.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 9800X3D 8-Core Processor
CPU family: 26
Model: 68
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 83%
CPU max MHz: 5455.0000
CPU min MHz: 600.0000
BogoMIPS: 9381.75
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d amd_lbr_pmc_freeze
Virtualization: AMD-V
L1d cache: 384 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 96 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Ghostwrite: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; IBPB on VMEXIT only
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.3.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.0
[pip3] triton==3.5.0
[conda] numpy 2.3.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.9.0 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @jerryzh168 @ptrblck @eqy @mcarilli @leslie-fang-intel @jgong5

Metadata

Labels

high priority
module: amp (automated mixed precision)
module: cuda (related to torch.cuda and CUDA support in general)
module: performance (issues related to performance, either of kernel code or framework glue)
module: vllm
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
