
PINTO0309/rf-detr

 
 

RF-DETR: Real-Time SOTA Detection and Segmentation


RF-DETR is a real-time transformer architecture for object detection and instance segmentation developed by Roboflow. Built on a DINOv2 vision transformer backbone, RF-DETR delivers state-of-the-art accuracy and latency trade-offs on Microsoft COCO and RF100-VL.

RF-DETR supports both object detection and instance segmentation through a single, consistent API. The open-source rfdetr package and Apache-designated models are released under Apache 2.0, while Plus components (rfdetr_plus, including the RF-DETR-XL and RF-DETR-2XL detection models) are licensed under PML 1.0.


Install

Install the rfdetr package with pip in a Python 3.12 environment:

pip install rfdetr
Install from source

By installing RF-DETR from source, you can explore the most recent features and enhancements that have not yet been officially released. Please note that these updates are still in development and may not be as stable as the latest published release.

pip install https://github.com/roboflow/rf-detr/archive/refs/heads/develop.zip

Benchmarks

RF-DETR achieves state-of-the-art results in both object detection and instance segmentation, with benchmarks reported on Microsoft COCO and RF100-VL. The charts and tables below compare RF-DETR against other top real-time models across accuracy and latency for detection and segmentation. All latency numbers were measured on an NVIDIA T4 using TensorRT, FP16, and batch size 1. For full benchmarking methodology and reproducibility details, see roboflow/sab.

Detection

Figure: latency vs. accuracy for object detection.

Object detection benchmark numbers:

| Architecture | COCO AP50 | COCO AP50:95 | RF100-VL AP50 | RF100-VL AP50:95 | Latency (ms) | Params (M) | Resolution | License |
|---|---|---|---|---|---|---|---|---|
| RF-DETR-N | 67.6 | 48.4 | 85.0 | 57.7 | 2.3 | 30.5 | 384x384 | Apache 2.0 |
| RF-DETR-S | 72.1 | 53.0 | 86.7 | 60.2 | 3.5 | 32.1 | 512x512 | Apache 2.0 |
| RF-DETR-M | 73.6 | 54.7 | 87.4 | 61.2 | 4.4 | 33.7 | 576x576 | Apache 2.0 |
| RF-DETR-L | 75.1 | 56.5 | 88.2 | 62.2 | 6.8 | 33.9 | 704x704 | Apache 2.0 |
| RF-DETR-XL △ | 77.4 | 58.6 | 88.5 | 62.9 | 11.5 | 126.4 | 700x700 | PML 1.0 |
| RF-DETR-2XL △ | 78.5 | 60.1 | 89.0 | 63.2 | 17.2 | 126.9 | 880x880 | PML 1.0 |
| YOLO11-N | 52.0 | 37.4 | 81.4 | 55.3 | 2.5 | 2.6 | 640x640 | AGPL-3.0 |
| YOLO11-S | 59.7 | 44.4 | 82.3 | 56.2 | 3.2 | 9.4 | 640x640 | AGPL-3.0 |
| YOLO11-M | 64.1 | 48.6 | 82.5 | 56.5 | 5.1 | 20.1 | 640x640 | AGPL-3.0 |
| YOLO11-L | 64.9 | 49.9 | 82.2 | 56.5 | 6.5 | 25.3 | 640x640 | AGPL-3.0 |
| YOLO11-X | 66.1 | 50.9 | 81.7 | 56.2 | 10.5 | 56.9 | 640x640 | AGPL-3.0 |
| YOLO26-N | 55.8 | 40.3 | 76.7 | 52.0 | 1.7 | 2.6 | 640x640 | AGPL-3.0 |
| YOLO26-S | 64.3 | 47.7 | 82.7 | 57.0 | 2.6 | 9.4 | 640x640 | AGPL-3.0 |
| YOLO26-M | 69.7 | 52.5 | 84.4 | 58.7 | 4.4 | 20.1 | 640x640 | AGPL-3.0 |
| YOLO26-L | 71.1 | 54.1 | 85.0 | 59.3 | 5.7 | 25.3 | 640x640 | AGPL-3.0 |
| YOLO26-X | 74.0 | 56.9 | 85.6 | 60.0 | 9.6 | 56.9 | 640x640 | AGPL-3.0 |
| LW-DETR-T | 60.7 | 42.9 | 84.7 | 57.1 | 1.9 | 12.1 | 640x640 | Apache 2.0 |
| LW-DETR-S | 66.8 | 48.0 | 85.0 | 57.4 | 2.6 | 14.6 | 640x640 | Apache 2.0 |
| LW-DETR-M | 72.0 | 52.6 | 86.8 | 59.8 | 4.4 | 28.2 | 640x640 | Apache 2.0 |
| LW-DETR-L | 74.6 | 56.1 | 87.4 | 61.5 | 6.9 | 46.8 | 640x640 | Apache 2.0 |
| LW-DETR-X | 76.9 | 58.3 | 87.9 | 62.1 | 13.0 | 118.0 | 640x640 | Apache 2.0 |
| D-FINE-N | 60.2 | 42.7 | 84.4 | 58.2 | 2.1 | 3.8 | 640x640 | Apache 2.0 |
| D-FINE-S | 67.6 | 50.6 | 85.3 | 60.3 | 3.5 | 10.2 | 640x640 | Apache 2.0 |
| D-FINE-M | 72.6 | 55.0 | 85.5 | 60.6 | 5.4 | 19.2 | 640x640 | Apache 2.0 |
| D-FINE-L | 74.9 | 57.2 | 86.4 | 61.6 | 7.5 | 31.0 | 640x640 | Apache 2.0 |
| D-FINE-X | 76.8 | 59.3 | 86.9 | 62.2 | 11.5 | 62.0 | 640x640 | Apache 2.0 |

Segmentation

Figure: latency vs. accuracy for instance segmentation.

Instance segmentation benchmark numbers:

| Architecture | COCO AP50 | COCO AP50:95 | Latency (ms) | Params (M) | Resolution | License |
|---|---|---|---|---|---|---|
| RF-DETR-Seg-N | 63.0 | 40.3 | 3.4 | 33.6 | 312x312 | Apache 2.0 |
| RF-DETR-Seg-S | 66.2 | 43.1 | 4.4 | 33.7 | 384x384 | Apache 2.0 |
| RF-DETR-Seg-M | 68.4 | 45.3 | 5.9 | 35.7 | 432x432 | Apache 2.0 |
| RF-DETR-Seg-L | 70.5 | 47.1 | 8.8 | 36.2 | 504x504 | Apache 2.0 |
| RF-DETR-Seg-XL | 72.2 | 48.8 | 13.5 | 38.1 | 624x624 | Apache 2.0 |
| RF-DETR-Seg-2XL | 73.1 | 49.9 | 21.8 | 38.6 | 768x768 | Apache 2.0 |
| YOLOv8-N-Seg | 45.6 | 28.3 | 3.5 | 3.4 | 640x640 | AGPL-3.0 |
| YOLOv8-S-Seg | 53.8 | 34.0 | 4.2 | 11.8 | 640x640 | AGPL-3.0 |
| YOLOv8-M-Seg | 58.2 | 37.3 | 7.0 | 27.3 | 640x640 | AGPL-3.0 |
| YOLOv8-L-Seg | 60.5 | 39.0 | 9.7 | 46.0 | 640x640 | AGPL-3.0 |
| YOLOv8-XL-Seg | 61.3 | 39.5 | 14.0 | 71.8 | 640x640 | AGPL-3.0 |
| YOLOv11-N-Seg | 47.8 | 30.0 | 3.6 | 2.9 | 640x640 | AGPL-3.0 |
| YOLOv11-S-Seg | 55.4 | 35.0 | 4.6 | 10.1 | 640x640 | AGPL-3.0 |
| YOLOv11-M-Seg | 60.0 | 38.5 | 6.9 | 22.4 | 640x640 | AGPL-3.0 |
| YOLOv11-L-Seg | 61.5 | 39.5 | 8.3 | 27.6 | 640x640 | AGPL-3.0 |
| YOLOv11-XL-Seg | 62.4 | 40.1 | 13.7 | 62.1 | 640x640 | AGPL-3.0 |
| YOLO26-N-Seg | 54.3 | 34.7 | 2.31 | 2.7 | 640x640 | AGPL-3.0 |
| YOLO26-S-Seg | 62.4 | 40.2 | 3.47 | 10.4 | 640x640 | AGPL-3.0 |
| YOLO26-M-Seg | 67.8 | 44.0 | 6.32 | 23.6 | 640x640 | AGPL-3.0 |
| YOLO26-L-Seg | 69.8 | 45.5 | 7.58 | 28.0 | 640x640 | AGPL-3.0 |
| YOLO26-X-Seg | 71.6 | 46.8 | 12.92 | 62.8 | 640x640 | AGPL-3.0 |

Run Models

Detection

RF-DETR provides multiple model sizes, ranging from Nano to 2XLarge. To use a different model size, replace the class name in the code snippet below with another class from the table.

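import supervision as sv
from rfdetr import RFDETRMedium
from rfdetr.assets.coco_classes import COCO_CLASSES

# Load the pretrained Medium detection model.
model = RFDETRMedium()

# Run inference on an image URL; the result is a supervision Detections object.
detections = model.predict("https://media.roboflow.com/dog.jpg", threshold=0.5)

# Map COCO class ids to human-readable class names.
labels = [COCO_CLASSES[class_id] for class_id in detections.class_id]

# Draw boxes and labels on the source image kept in the detection metadata, then display it.
annotated_image = sv.BoxAnnotator().annotate(detections.metadata["source_image"], detections)
annotated_image = sv.LabelAnnotator().annotate(annotated_image, detections, labels)
sv.plot_image(annotated_image)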
Run RF-DETR with Inference

You can also run RF-DETR models using the Inference library. To switch model size, select the appropriate inference package alias from the table below.

import requests
import supervision as sv
from PIL import Image
from inference import get_model

# Load RF-DETR Medium through its Inference package alias.
model = get_model("rfdetr-medium")

image = Image.open(requests.get("https://media.roboflow.com/dog.jpg", stream=True).raw)
# infer() returns one response per input image; take the first one.
predictions = model.infer(image, confidence=0.5)[0]
detections = sv.Detections.from_inference(predictions)

# Draw boxes and class labels, then display the result.
annotated_image = sv.BoxAnnotator().annotate(image, detections)
annotated_image = sv.LabelAnnotator().annotate(annotated_image, detections)
sv.plot_image(annotated_image)
| Size | RF-DETR package class | Inference package alias | COCO AP50 | COCO AP50:95 | Latency (ms) | Params (M) | Resolution | License |
|---|---|---|---|---|---|---|---|---|
| N | RFDETRNano | rfdetr-nano | 67.6 | 48.4 | 2.3 | 30.5 | 384x384 | Apache 2.0 |
| S | RFDETRSmall | rfdetr-small | 72.1 | 53.0 | 3.5 | 32.1 | 512x512 | Apache 2.0 |
| M | RFDETRMedium | rfdetr-medium | 73.6 | 54.7 | 4.4 | 33.7 | 576x576 | Apache 2.0 |
| L | RFDETRLarge | rfdetr-large | 75.1 | 56.5 | 6.8 | 33.9 | 704x704 | Apache 2.0 |
| XL △ | RFDETRXLarge | rfdetr-xlarge | 77.4 | 58.6 | 11.5 | 126.4 | 700x700 | PML 1.0 |
| 2XL △ | RFDETR2XLarge | rfdetr-2xlarge | 78.5 | 60.1 | 17.2 | 126.9 | 880x880 | PML 1.0 |

△ Requires the rfdetr_plus extension: pip install "rfdetr[plus]" (quote the extra so that shells such as zsh do not treat the brackets as a glob pattern). See License for details.

Segmentation

RF-DETR supports instance segmentation with model sizes from Nano to 2XLarge. To use a different model size, replace the class name in the code snippet below with another class from the table.

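import supervision as sv
from rfdetr import RFDETRSegMedium
from rfdetr.assets.coco_classes import COCO_CLASSES

# Load the pretrained Medium instance segmentation model.
model = RFDETRSegMedium()

# Run inference on an image URL; the result carries boxes and masks.
detections = model.predict("https://media.roboflow.com/dog.jpg", threshold=0.5)

# Map COCO class ids to human-readable class names.
labels = [COCO_CLASSES[class_id] for class_id in detections.class_id]

# Draw masks and labels on the source image kept in the detection metadata, then display it.
annotated_image = sv.MaskAnnotator().annotate(detections.metadata["source_image"], detections)
annotated_image = sv.LabelAnnotator().annotate(annotated_image, detections, labels)
sv.plot_image(annotated_image)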
Run RF-DETR-Seg with Inference

You can also run RF-DETR-Seg models using the Inference library. To switch model size, select the appropriate inference package alias from the table below.

import requests
import supervision as sv
from PIL import Image
from inference import get_model

# Load RF-DETR-Seg Medium through its Inference package alias.
model = get_model("rfdetr-seg-medium")

image = Image.open(requests.get("https://media.roboflow.com/dog.jpg", stream=True).raw)
# infer() returns one response per input image; take the first one.
predictions = model.infer(image, confidence=0.5)[0]
detections = sv.Detections.from_inference(predictions)

# Draw masks and class labels, then display the result.
annotated_image = sv.MaskAnnotator().annotate(image, detections)
annotated_image = sv.LabelAnnotator().annotate(annotated_image, detections)
sv.plot_image(annotated_image)
| Size | RF-DETR package class | Inference package alias | COCO AP50 | COCO AP50:95 | Latency (ms) | Params (M) | Resolution | License |
|---|---|---|---|---|---|---|---|---|
| N | RFDETRSegNano | rfdetr-seg-nano | 63.0 | 40.3 | 3.4 | 33.6 | 312x312 | Apache 2.0 |
| S | RFDETRSegSmall | rfdetr-seg-small | 66.2 | 43.1 | 4.4 | 33.7 | 384x384 | Apache 2.0 |
| M | RFDETRSegMedium | rfdetr-seg-medium | 68.4 | 45.3 | 5.9 | 35.7 | 432x432 | Apache 2.0 |
| L | RFDETRSegLarge | rfdetr-seg-large | 70.5 | 47.1 | 8.8 | 36.2 | 504x504 | Apache 2.0 |
| XL | RFDETRSegXLarge | rfdetr-seg-xlarge | 72.2 | 48.8 | 13.5 | 38.1 | 624x624 | Apache 2.0 |
| 2XL | RFDETRSeg2XLarge | rfdetr-seg-2xlarge | 73.1 | 49.9 | 21.8 | 38.6 | 768x768 | Apache 2.0 |

Train Models

RF-DETR supports training for both object detection and instance segmentation. You can train models in Google Colab or directly on the Roboflow platform, and a step-by-step video fine-tuning tutorial is also available.


For local RF-DETR-Seg-XL training on a memory-constrained NVIDIA GPU, start with a micro-batch size of 1 (batch_size=1) and use gradient accumulation (grad_accum_steps=16) so that each optimizer step still sees an effective batch of 1 × 16 = 16 images:

Install this local checkout in editable mode before running the example. The deimv2_coco dataset option exists only in this repository version, not in older rfdetr packages that may already be installed in site-packages; printing rfdetr.__file__ confirms that Python imports the local checkout.

python -m pip install --force-reinstall "numpy==1.26.4" "pyarrow==14.0.1"
python -m pip install -e ".[train,loggers,kornia]"
python - <<'PY'
import rfdetr

print(rfdetr.__file__)
PY
python - <<'PY'
from rfdetr import RFDETRSegXLarge

model = RFDETRSegXLarge(num_classes=49, gradient_checkpointing=True)
model.train(
    dataset_dir="wholebody49",
    dataset_file="deimv2_coco",
    augmentation_profile="deimv2",
    output_dir="output",
    epochs=100,
    batch_size=1,
    grad_accum_steps=16,
    multi_scale=True,
    expanded_scales=True,
    multi_scale_max_offset=0,
    device="cuda",
)
PY
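The batch_size=1 and grad_accum_steps=16 settings above trade speed for memory: gradients from 16 micro-batches of one image are combined before each optimizer step. The pure-Python toy below (a one-parameter least-squares model, not RF-DETR's training loop) illustrates why this approximates a batch of 16. It assumes the common convention of dividing each micro-batch loss by the number of accumulation steps; the helper names are hypothetical.

```python
# Toy illustration of gradient accumulation (not RF-DETR's training loop).
# Model: y_hat = w * x with a mean-squared-error loss over a batch.

def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Sum micro-batch gradients, each scaled by 1 / grad_accum_steps.

    Assumes len(xs) is a multiple of micro_batch_size.
    """
    grad_accum_steps = len(xs) // micro_batch_size
    total = 0.0
    for step in range(grad_accum_steps):
        start = step * micro_batch_size
        chunk_x = xs[start:start + micro_batch_size]
        chunk_y = ys[start:start + micro_batch_size]
        total += grad_mse(w, chunk_x, chunk_y) / grad_accum_steps
    return total


batch_size, grad_accum_steps = 1, 16
effective_batch = batch_size * grad_accum_steps   # 16 samples per optimizer step

xs = [float(i) for i in range(1, effective_batch + 1)]
ys = [3.0 * x + 0.5 for x in xs]   # targets drawn from a known line
w = 1.0

full = grad_mse(w, xs, ys)                        # one batch of 16
accum = accumulated_grad(w, xs, ys, batch_size)   # 16 micro-batches of 1
print(effective_batch, full, accum)
```

The two gradients agree up to floating-point rounding because the loss is a per-sample average. Layers whose statistics depend on the batch itself (BatchNorm, for example) do not enjoy this equivalence when micro-batches contain a single image.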

The RFDETRSegXLarge training example above assumes that the DEIMv2 WholeBody dataset has been placed directly under the RF-DETR repository root as wholebody49/, with parquet annotations under wholebody49/annotations/. When dataset_file="deimv2_coco" and augmentation_profile="deimv2" are both set, WholeBody-specific defaults are applied: mask handling, segmentation evaluation, center targets, parquet preloading, and num_queries = num_select = 1240. Keep num_classes=49 explicit so that the detection head matches the 49 classes of the copied WholeBody49 dataset. The sample keeps RF-DETR multi-scale training enabled at resolution=624 but sets expanded_scales=True and multi_scale_max_offset=0, so RFDETRSegXLarge samples sizes from 504x504 through 624x624 and never uses scales larger than the base resolution.
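The 504x504 to 624x624 range quoted above can be reproduced with a short sketch. It rests on assumptions read off those numbers, not on RF-DETR's code: square training sizes sit on a 24-pixel grid, expanded scales span -5 to +5 grid steps around the base resolution, and multi_scale_max_offset caps the positive offset. candidate_scales is a hypothetical helper.

```python
# Hypothetical sketch of multi-scale size selection (assumptions, not RF-DETR code):
#   - sizes are multiples of an assumed 24-pixel grid
#   - expanded scales cover offsets -5..+5 grid steps around the base resolution
#   - max_offset caps how far above the base resolution sizes may go

def candidate_scales(resolution, grid=24, min_offset=-5, max_offset=5):
    """Return square sizes (pixels) within [min_offset, max_offset] grid steps of resolution."""
    base_steps = resolution // grid          # 624 // 24 = 26 grid steps
    return [
        (base_steps + offset) * grid
        for offset in range(min_offset, max_offset + 1)
        if base_steps + offset > 0           # never propose a non-positive size
    ]


print(candidate_scales(624, max_offset=0))  # [504, 528, 552, 576, 600, 624]
```

Under the same assumptions, leaving max_offset at 5 would add sizes up to 744x744, which is the larger-than-base range that multi_scale_max_offset=0 is described as avoiding.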

With augmentation_profile="deimv2", RF-DETR uses the DEIMv2-compatible CPU transform pipeline for the deimv2_coco dataset. By default, each training sample passes through the following steps:

  • RandomPhotometricDistort(p=0.5) from epoch 4 through epoch 89.
  • RandomZoomOut(p=0.5, side_range=[1.0, 1.5]) from epoch 4 through epoch 89.
  • RandomIoUCrop(p=0.8) from epoch 4 through epoch 89.
  • SanitizeBoundingBoxes(min_size=1) to remove invalid boxes while keeping labels, masks, mask_valid, and segm_eval_valid aligned.
  • RandomHorizontalFlipWithClass(p=0.5), including WholeBody left/right class-id swaps from class_flip_pairs.
  • Resize to the model resolution, so RFDETRSegXLarge uses 624x624.
  • A second SanitizeBoundingBoxes(min_size=1).
  • ConvertPILImage(dtype="float32", scale=True).
  • Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
  • ConvertBoxes(fmt="cxcywh", normalize=True).

Validation and test use only deterministic preprocessing: Resize, ConvertPILImage, Normalize, and ConvertBoxes.
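The epoch gating in the list above can be summarized as a small schedule. This is a hypothetical illustration rather than RF-DETR's implementation: it only returns transform names, assumes the "epoch 4 through epoch 89" bounds are inclusive, and ignores the per-sample probabilities (p=0.5, p=0.8) that each transform applies internally.

```python
# Hypothetical schedule of the default DEIMv2 training pipeline (names only).
AUG_START, AUG_STOP = 4, 89   # assumed inclusive bounds for the gated transforms

GATED = ["RandomPhotometricDistort", "RandomZoomOut", "RandomIoUCrop"]
ALWAYS = [
    "SanitizeBoundingBoxes",
    "RandomHorizontalFlipWithClass",
    "Resize",
    "SanitizeBoundingBoxes",
    "ConvertPILImage",
    "Normalize",
    "ConvertBoxes",
]
EVAL = ["Resize", "ConvertPILImage", "Normalize", "ConvertBoxes"]


def train_pipeline(epoch):
    """Transform names a training sample would pass through at this epoch."""
    gated = GATED if AUG_START <= epoch <= AUG_STOP else []
    return gated + ALWAYS


def eval_pipeline():
    """Validation and test use only deterministic preprocessing."""
    return list(EVAL)


for epoch in (0, 4, 89, 90):
    print(epoch, len(train_pipeline(epoch)))
```

Keeping the gated transforms in a separate list makes the warm-up (epochs 0 to 3) and the final clean epochs (90 onward) share the same deterministic tail as the middle epochs.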

ConvertPILImage, Normalize, and ConvertBoxes are the compatibility boundary with RF-DETR's training loop. They produce float tensor images normalized with ImageNet statistics and convert target boxes to normalized cxcywh, which is the format expected by the matcher, box losses, COCO evaluation callback, and postprocessing utilities. Because the DEIMv2 profile already performs this CPU-side finalization, RF-DETR skips Kornia GPU augmentation for augmentation_profile="deimv2" even if augmentation_backend="auto" or "gpu" is requested.
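The numeric contract of that finalization step can be illustrated without any framework. The sketch assumes input boxes are absolute-pixel xyxy (x_min, y_min, x_max, y_max) and pixel values are 0 to 255 integers; both are assumptions for illustration, and the function names are hypothetical, not RF-DETR or DEIMv2 APIs.

```python
# Hypothetical illustration of ConvertPILImage(scale=True) + Normalize + ConvertBoxes.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale 0-255 RGB to [0, 1], then subtract the ImageNet mean and divide by the std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def xyxy_to_normalized_cxcywh(box, width, height):
    """Convert an absolute xyxy box to center/size format normalized by image size."""
    x0, y0, x1, y1 = box
    cx = (x0 + x1) / 2.0 / width
    cy = (y0 + y1) / 2.0 / height
    w = (x1 - x0) / width
    h = (y1 - y0) / height
    return cx, cy, w, h


print(xyxy_to_normalized_cxcywh((10, 20, 110, 220), width=200, height=400))
# → (0.3, 0.3, 0.5, 0.5)
```

Normalized coordinates are unchanged by the square Resize, because resizing scales each axis and the normalization divides by the same scaled width and height.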

Mosaic, MixUp, and CopyBlend are available in the DEIMv2 profile but are disabled by default. Enable them explicitly with mosaic_prob, mixup_prob, or copyblend_prob:

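from rfdetr import RFDETRSegXLarge

model = RFDETRSegXLarge(num_classes=49, gradient_checkpointing=True)
model.train(
    dataset_dir="wholebody49",
    dataset_file="deimv2_coco",
    augmentation_profile="deimv2",
    # Mosaic, MixUp, and CopyBlend are disabled unless their probabilities are > 0.
    mosaic_prob=0.5,
    mixup_prob=0.15,
    copyblend_prob=0.15,
    # ...remaining arguments (output_dir, epochs, batch_size, etc.) as in the training example above
)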

When mosaic_prob > 0, Mosaic is inserted before the photometric/crop transforms and runs only from epoch 4 through epoch 28. If Mosaic is selected for a sample, RandomZoomOut and RandomIoUCrop are skipped for that sample to avoid conflicting geometric policies. Collate-time MixUp runs from epoch 4 through epoch 28, and CopyBlend runs from epoch 4 through epoch 49. RF-DETR intentionally does not import DEIMv2's collate-time base_size_repeat multi-scale resize; the existing RF-DETR resize, padding, multi_scale, square_resize_div_64, and block-size collate behavior remain authoritative.
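The interaction between Mosaic and the other geometric transforms, together with the separate epoch windows, can be sketched as follows. The helpers are hypothetical and assume inclusive epoch bounds; they decide only which augmentations are eligible, while RandomZoomOut, RandomIoUCrop, MixUp, and CopyBlend would still fire with their own probabilities.

```python
import random

# Assumed inclusive epoch windows, taken from the description above.
MOSAIC_WINDOW = (4, 28)      # Mosaic and collate-time MixUp
COPYBLEND_WINDOW = (4, 49)   # CopyBlend
GEOMETRIC_WINDOW = (4, 89)   # RandomZoomOut / RandomIoUCrop


def in_window(epoch, window):
    start, stop = window
    return start <= epoch <= stop


def sample_geometric_policy(epoch, mosaic_prob, rng):
    """Geometric transforms eligible for one sample (hypothetical helper)."""
    if mosaic_prob > 0 and in_window(epoch, MOSAIC_WINDOW) and rng.random() < mosaic_prob:
        # Mosaic was selected, so ZoomOut and IoUCrop are skipped for this sample.
        return ["Mosaic"]
    if in_window(epoch, GEOMETRIC_WINDOW):
        return ["RandomZoomOut", "RandomIoUCrop"]
    return []


def collate_time_augs(epoch, mixup_prob, copyblend_prob):
    """Batch-level augmentations eligible at this epoch (hypothetical helper)."""
    augs = []
    if mixup_prob > 0 and in_window(epoch, MOSAIC_WINDOW):
        augs.append("MixUp")
    if copyblend_prob > 0 and in_window(epoch, COPYBLEND_WINDOW):
        augs.append("CopyBlend")
    return augs


rng = random.Random(0)
print(sample_geometric_policy(10, mosaic_prob=1.0, rng=rng))       # ['Mosaic']
print(sample_geometric_policy(30, mosaic_prob=1.0, rng=rng))       # ['RandomZoomOut', 'RandomIoUCrop']
print(collate_time_augs(40, mixup_prob=0.15, copyblend_prob=0.15))  # ['CopyBlend']
```

Making the Mosaic decision first and returning early keeps the two geometric policies mutually exclusive per sample, which mirrors the conflict-avoidance rule described above.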

To automatically resume and retry when CUDA runs out of memory, wrap the training command in a small shell script. The resume command points at the latest full Lightning checkpoint in output_dir:

#!/usr/bin/env bash
set -u

export CUDA_VISIBLE_DEVICES=0
export DATASET_DIR="wholebody49"
export OUTPUT_DIR="output"
export MAX_OOM_RETRIES=20

run_and_log() {
    local log_file="$1"
    shift
    "$@" 2>&1 | tee "$log_file"
    return "${PIPESTATUS[0]}"
}

is_oom_log() {
    grep -Eqi 'out of memory|OutOfMemoryError|CUDNN_STATUS_ALLOC_FAILED' "$1"
}

train_initial() {
    python - <<PY
from rfdetr import RFDETRSegXLarge

model = RFDETRSegXLarge(num_classes=49, gradient_checkpointing=True)
model.train(
    dataset_dir="${DATASET_DIR}",
    dataset_file="deimv2_coco",
    augmentation_profile="deimv2",
    output_dir="${OUTPUT_DIR}",
    epochs=100,
    batch_size=1,
    grad_accum_steps=16,
    resolution=624,
    group_detr=1,
    multi_scale=True,
    expanded_scales=False,
    device="cuda",
    progress_bar="tqdm",
)
PY
}

train_resume() {
    if [ ! -f "${OUTPUT_DIR}/last.ckpt" ]; then
        echo "No ${OUTPUT_DIR}/last.ckpt found; rerunning the initial command."
        train_initial
        return $?
    fi

    python - <<PY
from rfdetr import RFDETRSegXLarge

model = RFDETRSegXLarge(num_classes=49, gradient_checkpointing=True)
model.train(
    dataset_dir="${DATASET_DIR}",
    dataset_file="deimv2_coco",
    augmentation_profile="deimv2",
    output_dir="${OUTPUT_DIR}",
    epochs=100,
    batch_size=1,
    grad_accum_steps=16,
    resolution=624,
    group_detr=1,
    multi_scale=True,
    expanded_scales=False,
    resume="${OUTPUT_DIR}/last.ckpt",
    device="cuda",
    progress_bar="tqdm",
)
PY
}

run_and_log train_initial.log train_initial
status=$?

if [ "$status" -ne 0 ] && is_oom_log train_initial.log; then
    for attempt in $(seq 1 "$MAX_OOM_RETRIES"); do
        log_file="train_resume_oom_retry_${attempt}.log"
        echo "OOM detected. Resume attempt ${attempt}/${MAX_OOM_RETRIES}..."
        run_and_log "$log_file" train_resume
        status=$?

        [ "$status" -eq 0 ] && break
        is_oom_log "$log_file" || break
    done
fi

echo "Training finished with status=$status"
exit "$status"

Documentation

Visit our documentation website to learn more about how to use RF-DETR.

License

Licensing is split by component:

  • The open-source rfdetr package and Apache-designated model weights are licensed under Apache License 2.0. See LICENSE.
  • Plus components, including the rfdetr_plus extension and RF-DETR-XL / RF-DETR-2XL detection models, are licensed under PML 1.0.

Acknowledgements

Our work is built upon LW-DETR, DINOv2, and Deformable DETR. Thanks to their authors for their excellent work!

Citation

If you find our work helpful for your research, please consider citing the following BibTeX entry.

@misc{rf-detr,
    title={RF-DETR: Neural Architecture Search for Real-Time Detection Transformers},
    author={Isaac Robinson and Peter Robicheaux and Matvei Popov and Deva Ramanan and Neehar Peri},
    year={2025},
    eprint={2511.09554},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2511.09554},
}

Contribute

We welcome and appreciate all contributions! If you notice any issues or bugs, have questions, or would like to suggest new features, please open an issue or pull request. By sharing your ideas and improvements, you help make RF-DETR better for everyone.

About

RF-DETR is a real-time object detection and segmentation model architecture developed by Roboflow, SOTA on COCO, designed for fine-tuning. [ICLR 2026]
