This is the official implementation of the paper [RT-DETRv4: Painlessly Furthering Real-Time Object Detection with Vision Foundation Models](https://arxiv.org/abs/2510.25257).
RT-DETRv4 is the latest version of the state-of-the-art real-time object detector family RT-DETR. It introduces a cost-effective and adaptable distillation framework that leverages the powerful representations of Vision Foundation Models (VFMs) to enhance lightweight detectors.
RT-DETRv4 achieves new state-of-the-art results on the COCO dataset, outperforming previous real-time detectors.
| Model | AP | AP50 | AP75 | Latency (T4) | FPS (T4) | Config | Log | Checkpoint |
|---|---|---|---|---|---|---|---|---|
| RT-DETRv4-S | 49.8 | 67.1 | 54.0 | 3.66 ms | 273 | yml | log | ckpt |
| RT-DETRv4-M | 53.7 | 71.0 | 58.4 | 5.91 ms | 169 | yml | log | ckpt |
| RT-DETRv4-L | 55.4 | 73.0 | 60.3 | 8.07 ms | 124 | yml | log | ckpt |
| RT-DETRv4-X | 57.0 | 74.6 | 62.1 | 12.90 ms | 78 | yml | log | ckpt |
- [2025.11.17] Code, configs and checkpoints fully released! Thanks for your attention, and feel free to ask any questions!
- [2025.10.30] Repo created, and code will be open-sourced very soon!
This repository also supports the reproduction of DEIM, D-FINE, and RT-DETRv2. Simply run the corresponding configuration files.
```shell
conda create -n rtv4 python=3.11.9
conda activate rtv4
pip install -r requirements.txt
```

COCO2017 Dataset
- Download COCO2017 from OpenDataLab or COCO.
- Modify paths in coco_detection.yml
```yaml
train_dataloader:
  img_folder: /data/COCO2017/train2017/
  ann_file: /data/COCO2017/annotations/instances_train2017.json

val_dataloader:
  img_folder: /data/COCO2017/val2017/
  ann_file: /data/COCO2017/annotations/instances_val2017.json
```

Custom Dataset
To train on your custom dataset, you need to organize it in the COCO format. Follow the steps below to prepare your dataset:
- Set `remap_mscoco_category` to `False`: this prevents the automatic remapping of category IDs to the MSCOCO categories.

  ```yaml
  remap_mscoco_category: False
  ```
- Organize Images:
  Structure your dataset directories as follows:

  ```
  dataset/
  ├── images/
  │   ├── train/
  │   │   ├── image1.jpg
  │   │   ├── image2.jpg
  │   │   └── ...
  │   ├── val/
  │   │   ├── image1.jpg
  │   │   ├── image2.jpg
  │   │   └── ...
  └── annotations/
      ├── instances_train.json
      ├── instances_val.json
      └── ...
  ```

  - images/train/: contains all training images.
  - images/val/: contains all validation images.
  - annotations/: contains COCO-formatted annotation files.
- Convert Annotations to COCO Format:
  If your annotations are not already in COCO format, you'll need to convert them. You can use the following Python stub as a reference or use existing tools (a fuller sketch follows after this list):

  ```python
  import json

  def convert_to_coco(input_annotations, output_annotations):
      # Implement conversion logic here
      pass

  if __name__ == "__main__":
      convert_to_coco('path/to/your_annotations.json', 'dataset/annotations/instances_train.json')
  ```
- Update Configuration Files:
  Modify your custom_detection.yml.

  ```yaml
  task: detection

  evaluator:
    type: CocoEvaluator
    iou_types: ['bbox', ]

  num_classes: 777  # your dataset classes
  remap_mscoco_category: False

  train_dataloader:
    type: DataLoader
    dataset:
      type: CocoDetection
      img_folder: /data/yourdataset/train
      ann_file: /data/yourdataset/train/train.json
      return_masks: False
      transforms:
        type: Compose
        ops: ~
    shuffle: True
    num_workers: 4
    drop_last: True
    collate_fn:
      type: BatchImageCollateFunction

  val_dataloader:
    type: DataLoader
    dataset:
      type: CocoDetection
      img_folder: /data/yourdataset/val
      ann_file: /data/yourdataset/val/ann.json
      return_masks: False
      transforms:
        type: Compose
        ops: ~
    shuffle: False
    num_workers: 4
    drop_last: False
    collate_fn:
      type: BatchImageCollateFunction
  ```
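Picking up the conversion step above, here is a minimal, hypothetical sketch of what a COCO-format conversion might look like, plus an optional sanity check with pycocotools. The input record layout (`file_name`, `width`, `height`, `boxes` in xywh, `labels`) is an assumption about your source data, not part of this repo; adapt it to whatever you export, and make sure `num_classes` in custom_detection.yml matches the categories you write.

```python
# Hypothetical sketch: build a minimal COCO-format annotation file from a
# simple list of per-image records, then sanity-check it with pycocotools.
# The record layout below is an assumption about your source data.
import json
import os

def convert_to_coco(records, category_names, output_path):
    coco = {"images": [], "annotations": [], "categories": []}
    # Category id convention (0- or 1-based) should follow the repo's
    # custom-dataset notes when remap_mscoco_category is False.
    for cid, name in enumerate(category_names):
        coco["categories"].append({"id": cid, "name": name})
    ann_id = 0
    for img_id, rec in enumerate(records):
        coco["images"].append({"id": img_id, "file_name": rec["file_name"],
                               "width": rec["width"], "height": rec["height"]})
        for (x, y, w, h), label in zip(rec["boxes"], rec["labels"]):
            coco["annotations"].append({
                "id": ann_id, "image_id": img_id, "category_id": label,
                "bbox": [x, y, w, h],  # COCO uses [x_min, y_min, width, height]
                "area": w * h, "iscrowd": 0,
            })
            ann_id += 1
    with open(output_path, "w") as f:
        json.dump(coco, f)

if __name__ == "__main__":
    records = [{"file_name": "image1.jpg", "width": 640, "height": 480,
                "boxes": [[10, 20, 100, 150]], "labels": [0]}]
    os.makedirs("dataset/annotations", exist_ok=True)
    convert_to_coco(records, ["person"], "dataset/annotations/instances_train.json")

    # Optional: the file should load cleanly with pycocotools.
    from pycocotools.coco import COCO
    coco = COCO("dataset/annotations/instances_train.json")
    print(len(coco.getCatIds()), "categories,", len(coco.getImgIds()), "images")
```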
Our framework uses a pre-trained Vision Foundation Model (VFM) as the teacher. We use the ViT-B/16-LVD-1689M model from DINOv3.
Specify the paths to your local DINOv3 repository and the downloaded checkpoint in the model's configuration file ./configs/rtv4/rtv4_hgnetv2_${model}_coco.yml, under the teacher_model section:
```yaml
teacher_model:
  type: "DINOv3TeacherModel"
  dinov3_repo_path: dinov3/
  dinov3_weights_path: pretrain/dinov3_vitb16_pretrain_lvd1689m.pth
```

Update the dinov3_repo_path and dinov3_weights_path to match your local setup.
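If you want to confirm the teacher checkpoint loads correctly before launching training, a quick stand-alone check along the following lines may help. The `dinov3_vitb16` hub entrypoint name and its `weights` argument are assumptions based on the DINOv3 repository's hubconf; adjust them to whatever your local clone exposes.

```python
# Minimal sanity check (entrypoint name and weights kwarg are assumptions):
# load the DINOv3 ViT-B/16 teacher from a local clone via torch.hub and run
# a dummy forward pass.
import torch

teacher = torch.hub.load(
    "dinov3/",                 # dinov3_repo_path from the config
    "dinov3_vitb16",           # assumed hubconf entrypoint name
    source="local",
    weights="pretrain/dinov3_vitb16_pretrain_lvd1689m.pth",  # dinov3_weights_path
)
teacher.eval()

with torch.no_grad():
    out = teacher(torch.randn(1, 3, 224, 224))  # ViT-B/16 expects sizes divisible by 16
print(type(out), getattr(out, "shape", None))
```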
COCO2017
- Training

  ```shell
  CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --master_port=7777 --nproc_per_node=4 train.py -c configs/rtv4/rtv4_hgnetv2_${model}_coco.yml --use-amp --seed=0
  ```

- Testing

  ```shell
  CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --master_port=7777 --nproc_per_node=4 train.py -c configs/rtv4/rtv4_hgnetv2_${model}_coco.yml --test-only -r model.pth
  ```

- Tuning

  ```shell
  CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --master_port=7777 --nproc_per_node=4 train.py -c configs/rtv4/rtv4_hgnetv2_${model}_coco.yml --use-amp --seed=0 -t model.pth
  ```
Customizing Batch Size
For example, if you want to double the total batch size when training RT-DETRv4-L on COCO2017, here are the steps you should follow:
- Modify your dataloader.yml to increase the total_batch_size:

  ```yaml
  train_dataloader:
    total_batch_size: 64  # Previously it was 32, now doubled
  ```

- Modify your rtv4_hgnetv2_l_coco.yml. Here's how the key parameters should be adjusted (see the scaling sketch after this list):

  ```yaml
  optimizer:
    type: AdamW
    params:
      - params: '^(?=.*backbone)(?!.*norm|bn).*$'
        lr: 0.000025  # doubled, linear scaling law
      - params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn)).*$'
        weight_decay: 0.
    lr: 0.0005  # doubled, linear scaling law
    betas: [0.9, 0.999]
    weight_decay: 0.0001  # need a grid search

  ema:  # added EMA settings
    decay: 0.9998  # adjusted by 1 - (1 - decay) * 2
    warmups: 500  # halved

  lr_warmup_scheduler:
    warmup_duration: 250  # halved
  ```
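The adjustments above follow simple rules of thumb. As a rough illustration (the base values below are read off the comments in the config snippet for the 32-to-64 case and are examples, not authoritative defaults):

```python
# Rough illustration of how the hyper-parameters above scale with batch size.
def scale_hparams(base_lr, base_ema_decay, base_warmup_steps, base_bs, new_bs):
    k = new_bs / base_bs
    return {
        "lr": base_lr * k,                              # linear scaling law
        "ema_decay": 1 - (1 - base_ema_decay) * k,      # keep the EMA horizon in samples constant
        "warmup_steps": round(base_warmup_steps / k),   # fewer optimizer steps per epoch
    }

print(scale_hparams(base_lr=0.00025, base_ema_decay=0.9999,
                    base_warmup_steps=500, base_bs=32, new_bs=64))
# -> {'lr': 0.0005, 'ema_decay': 0.9998, 'warmup_steps': 250}
```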
Customizing Input Size
If you'd like to train RT-DETRv4 on COCO2017 with an input size of 320x320, follow these steps:
- Modify your dataloader.yml:

  ```yaml
  train_dataloader:
    dataset:
      transforms:
        ops:
          - {type: Resize, size: [320, 320], }
    collate_fn:
      base_size: 320

  val_dataloader:
    dataset:
      transforms:
        ops:
          - {type: Resize, size: [320, 320], }
  ```
- Modify your rtv4_base.yml (or the relevant base config file):

  ```yaml
  eval_spatial_size: [320, 320]
  ```
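When choosing a non-default input size, it is worth checking that it is divisible by the network's largest feature stride (assumed to be 32 for HGNetv2-style backbones with P3-P5 outputs); otherwise the multi-scale feature maps will not line up. A quick check:

```python
# Sanity check for a custom input size (assumption: max feature stride is 32).
def check_input_size(size, strides=(8, 16, 32)):
    h, w = size
    assert h % max(strides) == 0 and w % max(strides) == 0, \
        f"{size} is not divisible by stride {max(strides)}"
    return [(h // s, w // s) for s in strides]  # per-level feature map sizes

print(check_input_size([320, 320]))  # -> [(40, 40), (20, 20), (10, 10)]
```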
Deployment
- Setup

  ```shell
  pip install onnx onnxsim
  ```

- Export onnx

  ```shell
  python tools/deployment/export_onnx.py --check -c configs/rtv4/rtv4_hgnetv2_${model}_coco.yml -r model.pth
  ```

- Export tensorrt

  ```shell
  trtexec --onnx="model.onnx" --saveEngine="model.engine" --fp16
  ```
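Since the setup step installs onnx and onnxsim, you can also validate or re-simplify the exported model by hand before building the TensorRT engine. A minimal sketch (the file name model.onnx just follows the commands above):

```python
# Minimal sketch: check and (re-)simplify an exported ONNX model.
import onnx
from onnxsim import simplify

model = onnx.load("model.onnx")
onnx.checker.check_model(model)      # structural validity check

model_simp, ok = simplify(model)     # returns (simplified_model, success_flag)
assert ok, "onnxsim could not validate the simplified model"
onnx.save(model_simp, "model_simplified.onnx")

print([i.name for i in model.graph.input])  # inspect the expected input names
```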
Inference (Visualization)
- Setup

  ```shell
  pip install -r tools/inference/requirements.txt
  ```

- Inference (onnxruntime / tensorrt / torch)

  Inference on images and videos is now supported.

  ```shell
  python tools/inference/onnx_inf.py --onnx model.onnx --input image.jpg  # or video.mp4
  python tools/inference/trt_inf.py --trt model.engine --input image.jpg
  python tools/inference/torch_inf.py -c configs/rtv4/rtv4_hgnetv2_${model}_coco.yml -r model.pth --input image.jpg --device cuda:0
  ```
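If you prefer to embed inference into your own code rather than use the scripts above, a bare-bones onnxruntime loop might look like the following. The input names `images` and `orig_target_sizes`, the 640x640 input size, and the (labels, boxes, scores) output order are assumptions based on typical RT-DETR-style exports; verify them against your exported graph (e.g. with the snippet in the Deployment section).

```python
# Bare-bones onnxruntime sketch; the I/O names, input size, and output order
# below are assumptions -- check them against your exported model.
import numpy as np
import onnxruntime as ort
from PIL import Image

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

im = Image.open("image.jpg").convert("RGB")
orig_size = np.array([[im.width, im.height]], dtype=np.int64)
data = np.asarray(im.resize((640, 640)), dtype=np.float32).transpose(2, 0, 1)[None] / 255.0

labels, boxes, scores = session.run(
    None, {"images": data, "orig_target_sizes": orig_size}
)
keep = scores[0] > 0.5
print(labels[0][keep], boxes[0][keep])
```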
Benchmark
- Setup

  ```shell
  pip install -r tools/benchmark/requirements.txt
  ```

- Model FLOPs, MACs, and Params

  ```shell
  python tools/benchmark/get_info.py -c configs/rtv4/rtv4_hgnetv2_${model}_coco.yml
  ```

- TensorRT Latency

  ```shell
  python tools/benchmark/trt_benchmark.py --COCO_dir path/to/COCO2017 --engine_dir model.engine
  ```
Fiftyone Visualization
- Setup

  ```shell
  pip install fiftyone
  ```

- Voxel51 Fiftyone Visualization (fiftyone)

  ```shell
  python tools/visualization/fiftyone_vis.py -c configs/rtv4/rtv4_hgnetv2_${model}_coco.yml -r model.pth
  ```
Others
- Auto Resume Training

  ```shell
  bash tools/reference/safe_training.sh
  ```

- Converting Model Weights

  ```shell
  python tools/reference/convert_weight.py model.pth
  ```
If you find this work helpful, please consider citing:
```bibtex
@article{liao2025rtdetrv4,
  title={RT-DETRv4: Painlessly Furthering Real-Time Object Detection with Vision Foundation Models},
  author={Zijun Liao and Yian Zhao and Xin Shan and Yu Yan and Chang Liu and Lei Lu and Xiangyang Ji and Jie Chen},
  journal={arXiv preprint arXiv:2510.25257},
  year={2025}
}
```

Our work builds upon RT-DETR, D-FINE, DEIM, and the teacher model DINOv3. Thanks to these remarkable works!