
Wrong build of TensorRT 10.4 running DINOv3 ViT-B/16 on RTX A6000 #4742

@zakcory

Description


TensorRT 10.4 (shipped with Triton 24.09) miscompiles a DINOv3 ViT-B/16 model exported from PyTorch in fp32. The resulting engine produces output that is uncorrelated with the PyTorch reference (cosine similarity 0.018; max absolute difference 4.78, on output values normally in the ±2 range).
ONNX Runtime CPU agrees with PyTorch GPU to within fp32 numerical noise (cosine similarity 1.0000), confirming that the ONNX export is correct and the bug is in TensorRT's compilation.
The exact same wrong output is produced whether the engine is built via trtexec or polygraphy.

In fp16 the engine produces NaN end-to-end. Common workarounds (lowering the optimization level, restricting tactic sources, constraining all layers to fp32, ONNX constant folding, graph surgery on the attention Reshape) do not fix it.
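For context on the fp16 NaN: since the attention here exports as explicit MatMul → Softmax → MatMul (see architecture notes below), one plausible failure mode is the unnormalized softmax overflowing fp16's range before normalization. This is only an illustration of that mechanism, not a confirmed root cause:

```python
import numpy as np

# Illustration only: logit values that are harmless in fp32 can overflow
# fp16 (max finite value ~65504) inside exp(), before the softmax
# normalization, turning the attention output into NaN.
logits = np.array([90.0, 12.0, 5.0], dtype=np.float16)

with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(logits)   # exp(90) overflows to inf in fp16
    attn = e / e.sum()   # inf / inf -> nan

print(np.isinf(e).any(), np.isnan(attn).any())  # True True
```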

Environment

TensorRT Version: 10.4.0
NVIDIA GPU: RTX A6000
NVIDIA Driver Version: 580.95.05
CUDA Version: 12.6
CUDNN Version: 9.4.0
Operating System: Fedora 42 host / Ubuntu 22.04 inside container
Python Version: 3.10.12
PyTorch Version: 2.7.1

Relevant Files

Model link: DINOv3 ViTB16 ONNX FP32

I am also attaching a folder that includes:

  • real_input.npy — preprocessed test input, shape (1, 3, 512, 512) fp32
  • pytorch_reference.npy — PyTorch GPU output, shape (1, 768) fp32
  • ort_output.npy — ONNX Runtime CPU output, shape (1, 768) fp32
  • trt_output.npy — TRT engine output, shape (1, 768) fp32
  • trtexec_build.log — verbose trtexec build log

and the scripts used to produce these files:

Inputs, Outputs and References

Steps To Reproduce

Commands or scripts:

  1. Build the TRT engine:
trtexec \
  --onnx=dinov3_vitb16-fp32.onnx \
  --saveEngine=dinov3.trt \
  --shapes=images:1x3x512x512 \
  --verbose > trtexec_build.log
  2. Reproduce the divergence with polygraphy (fastest path):
polygraphy run dinov3_vitb16-fp32.onnx \
  --onnxrt --trt \
  --val-range images:[-2.0,2.0] \
  --atol 1e-3 --rtol 1e-3

Result:
[I] Error Metrics: output
[I] Minimum Required Tolerance: elemwise error | [abs=4.7117] OR [rel=2552.6] (requirements may be lower if both abs/rel tolerances are set)
[I] Absolute Difference | Stats: mean=0.59799, std-dev=0.47202, var=0.22281, median=0.47963, min=0.0047234 at (0, 679), max=4.7117 at (0, 668), avg-magnitude=0.59799, p90=1.165, p95=1.4547, p99=1.9758
[E] FAILED | Output: 'output' | Difference exceeds tolerance (rel=0.001, abs=0.001)

  3. Confirm against PyTorch ground truth using the attached real_input.npy and pytorch_reference.npy:
TRT vs PyTorch-GPU:   max_abs_diff=4.78    cosine_sim=0.0176
TRT vs ORT-CPU:       max_abs_diff=4.78    cosine_sim=0.0176
ORT-CPU vs PyTorch:   max_abs_diff=7.8e-6  cosine_sim=1.0000001

PyTorch GPU and ORT CPU agree to fp32 numerical precision; TRT diverges from both by the same amount. This indicates a TRT compilation bug, not an ONNX export issue.
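For reproducibility, the two metrics quoted above can be computed as follows (a minimal sketch; the helper name is mine, and the synthetic array stands in for the attached .npy files, which are the real inputs):

```python
import numpy as np

def compare(a: np.ndarray, b: np.ndarray) -> tuple[float, float]:
    """Return (max_abs_diff, cosine_sim) between two model outputs."""
    a = a.ravel().astype(np.float64)
    b = b.ravel().astype(np.float64)
    max_abs_diff = float(np.max(np.abs(a - b)))
    cosine_sim = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    return max_abs_diff, cosine_sim

# With the attached files this would be e.g.:
#   ref = np.load("pytorch_reference.npy"); trt = np.load("trt_output.npy")
rng = np.random.default_rng(0)
ref = rng.standard_normal(768).astype(np.float32)
print(compare(ref, ref))  # identical outputs: max diff 0, cosine ~1
```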

Have you tried the latest release?:

We are pinned to the TRT version shipped with Triton 24.09 (TRT 10.4) by deployment requirements: we must support the GPUs we already run.

Can this model run on other frameworks?:

Yes. As noted above, ONNX Runtime CPU produces the correct output (cosine similarity 1.0000001 vs. PyTorch GPU).

Additional context

Workarounds tried that did NOT fix it:

  • --builderOptimizationLevel 0, 1, 2 (default 3)
  • --fp16 (produces NaN end-to-end)
  • --noTF32
  • Polygraphy --tactic-sources CUBLAS, CUDNN
  • --precision-constraints obey with all-fp32 layer precision (iteration 1 of polygraphy debug precision already fails with all 4394 layers in fp32)
  • polygraphy surgeon sanitize --fold-constants --override-input-shapes images:[1,3,512,512]
  • onnx-graphsurgeon patch on /blocks.X/attn/Reshape to use static shape [0, 0, 3, 12, 64] with allowzero=0

Localization attempts:

polygraphy debug reduce --mode bisect converges on /blocks.3/attn/Reshape as the minimum failing subgraph. Standalone verification of that subgraph is unreliable due to random shape-tensor inputs causing ORT to reject the reduced model. Per-tensor comparison via polygraphy run --onnx-outputs ... --trt-outputs ... produces inconsistent results: marking different sets of intermediate tensors changes which engine TRT builds, and tensors that pass when marked alone fail when marked alongside others. This made layer-by-layer localization impractical.

Model architecture notes:

DINOv3 ViT-B/16 with 4 register tokens (1029 sequence length at 512×512 input: 1024 patch tokens + CLS + 4 register tokens). Uses RoPE positional encoding implemented via Chunk(2) → Neg → Cat → Mul → Add rotate-half pattern. Attention is implemented manually (not as a single SDPA op), so attention exports as MatMul → Softmax → MatMul.
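The rotate-half pattern named above (Chunk(2) → Neg → Cat → Mul → Add) is the standard RoPE formulation; a minimal numpy sketch of what those ONNX ops compute, with a complex-rotation cross-check (names illustrative; real RoPE uses per-position, per-frequency cos/sin tables rather than a single scalar angle):

```python
import numpy as np

def rope_rotate_half(x, cos, sin):
    """RoPE via the rotate-half pattern as exported to ONNX:
    Chunk(2) -> Neg -> Cat builds [-x2, x1]; Mul/Add combine with cos/sin."""
    x1, x2 = np.split(x, 2, axis=-1)                 # Chunk(2)
    rotated = np.concatenate((-x2, x1), axis=-1)     # Neg + Cat
    return x * cos + rotated * sin                   # Mul + Mul + Add

head_dim, theta = 64, 0.3
x = np.random.default_rng(0).standard_normal(head_dim).astype(np.float32)
cos, sin = np.float32(np.cos(theta)), np.float32(np.sin(theta))
out = rope_rotate_half(x, cos, sin)

# Cross-check: pairing (x1, x2) as x1 + i*x2 and multiplying by
# exp(i*theta) rotates each pair by theta -- same result.
x1, x2 = np.split(x, 2)
z = (x1 + 1j * x2) * np.exp(1j * theta)
assert np.allclose(out, np.concatenate((z.real, z.imag)), atol=1e-6)
```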

Labels: Module:ONNX (Issues relating to ONNX usage and import)
