Merged

GFocal #3097

32 changes: 32 additions & 0 deletions configs/gfl/README.md
@@ -0,0 +1,32 @@
# Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection


## Introduction

We provide config files to reproduce the object detection results of the paper [Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection](https://arxiv.org/abs/2006.04388).

```
@article{li2020generalized,
title={Generalized Focal Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object Detection},
author={Li, Xiang and Wang, Wenhai and Wu, Lijun and Chen, Shuo and Hu, Xiaolin and Li, Jun and Tang, Jinhui and Yang, Jian},
journal={arXiv preprint arXiv:2006.04388},
year={2020}
}
```


## Results and Models

| Backbone | Style | Lr schd | Multi-scale Training | Inf time (fps) | box AP | Download |
|:-----------------:|:-------:|:-------:|:-------------------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | 1x | No | 19.5 | 40.2 | [model](https://drive.google.com/file/d/1lznguKfDocte6Ur-7wc1V31QxQZm4OQs/view?usp=sharing) &#124; [log](https://drive.google.com/file/d/1Wyia0lsSVNzomUlvtu95Um_GpOfwucN4/view?usp=sharing) |
| R-50 | pytorch | 2x | Yes | 19.5 | 42.9 | [model](https://drive.google.com/file/d/1RN19ndpKlnFGazor-C6NvOsyUlJVQIPI/view?usp=sharing) &#124; [log](https://drive.google.com/file/d/1U_XPe61qaYIn_3n-VM-1JTB_LM8NwNA9/view?usp=sharing) |
| R-101 | pytorch | 2x | Yes | 14.7 | 44.7 | [model](https://drive.google.com/file/d/1WKFcvv1kerYdMuSMVcRezRTk0FH5a6LK/view?usp=sharing) &#124; [log](https://drive.google.com/file/d/1sFnxPUPHM_PohelvCzkJJfQDnfkzmDqg/view?usp=sharing) |
| R-101-dcnv2 | pytorch | 2x | Yes | 12.9 | 47.1 | [model](https://drive.google.com/file/d/1Fp-nLJYPBsohI5JPWOEw9383oxbbxXXe/view?usp=sharing) &#124; [log](https://drive.google.com/file/d/13aiU_gFevQQaDapo8bg7rxi3qU-e4YLl/view?usp=sharing) |
| X-101-32x4d | pytorch | 2x | Yes | 12.1 | 45.9 | [model](https://drive.google.com/file/d/1LTVw8GSMbCGB6wDjqkou934Yl32pVGac/view?usp=sharing) &#124; [log](https://drive.google.com/file/d/10FsArE_cJHFhUtn7Og0Z-ZwzKZdwyeh_/view?usp=sharing) |
| X-101-32x4d-dcnv2 | pytorch | 2x | Yes | 10.7 | 48.2 | [model](https://drive.google.com/file/d/1ULjoJ8H71phrkFOKH4uCzqn9WZnGHAsd/view?usp=sharing) &#124; [log](https://drive.google.com/file/d/12JysUE3pBuIXSaprupFfRlE9_fQY6Mez/view?usp=sharing) |

[1] *1x and 2x mean the model is trained for 90K and 180K iterations, respectively.* \
[2] *All results are obtained with a single model and without any test-time augmentation such as multi-scale testing or flipping.* \
[3] *`dcnv2` denotes deformable convolutional networks v2.* \
[4] *FPS is measured on a single GeForce RTX 2080Ti GPU with a batch size of 1.*
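Note [1] states the schedules in iterations, while the configs below count epochs (mmdetection's `schedule_1x.py` trains for 12 epochs, and the 2x config sets `total_epochs = 24`). As a rough cross-check, the sketch below converts epochs to iterations. It assumes the 118,287 images of COCO train2017 and a total batch size of 16 (for example 8 GPUs with 2 images each); neither number is stated in this PR, and `epochs_to_iterations` is a hypothetical helper.

```python
import math

# Assumptions (not stated in this PR): COCO train2017 size and total batch size.
NUM_TRAIN_IMAGES = 118_287
TOTAL_BATCH_SIZE = 16  # e.g. 8 GPUs x 2 images per GPU


def epochs_to_iterations(epochs, num_images=NUM_TRAIN_IMAGES,
                         batch_size=TOTAL_BATCH_SIZE):
    """Hypothetical helper: iterations needed for `epochs` full passes."""
    iters_per_epoch = math.ceil(num_images / batch_size)  # last batch may be partial
    return epochs * iters_per_epoch


if __name__ == "__main__":
    for name, epochs in [("1x", 12), ("2x", 24)]:
        print(name, epochs_to_iterations(epochs))
```

Under these assumptions both schedules land close to the 90K and 180K figures in note [1]; the exact count may differ with annotation filtering and the sampler used.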
14 changes: 14 additions & 0 deletions configs/gfl/gfl_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py
@@ -0,0 +1,14 @@
_base_ = './gfl_r50_fpn_mstrain_2x_coco.py'
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
dcn=dict(type='DCN', deformable_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, True, True, True),
norm_eval=True,
style='pytorch'))
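The `dconv_c3-c5` part of the file name corresponds to `stage_with_dcn=(False, True, True, True)`: the four ResNet stages produce the C2 to C5 feature maps, and only the flagged stages use deformable convolution. The sketch below makes that mapping explicit; `dcn_stage_names` and the `c2` to `c5` labels are illustrative, not an mmdetection API.

```python
def dcn_stage_names(stage_with_dcn, first_level=2):
    """Hypothetical helper: map a stage_with_dcn tuple to C-level names.

    ResNet stage i (0-based) is assumed to output feature map C{i + 2}.
    """
    return [f"c{i + first_level}"
            for i, use_dcn in enumerate(stage_with_dcn) if use_dcn]


print(dcn_stage_names((False, True, True, True)))   # ['c3', 'c4', 'c5']
print(dcn_stage_names((False, False, True, True)))  # ['c4', 'c5']
```

The second call matches the `dconv_c4-c5` X-101 config in this PR.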
12 changes: 12 additions & 0 deletions configs/gfl/gfl_r101_fpn_mstrain_2x_coco.py
@@ -0,0 +1,12 @@
_base_ = './gfl_r50_fpn_mstrain_2x_coco.py'
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'))
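The R-101 configs only list the keys that differ from `gfl_r50_fpn_mstrain_2x_coco.py`; mmcv's `Config` merges nested dicts from the child file into its `_base_`, so `type='GFL'`, the neck and the head are inherited unchanged. The recursive merge below is a simplified sketch of that behaviour, not mmcv's implementation: it ignores the `_delete_=True` escape hatch and file loading, and `merge_cfg` is a hypothetical name. The `base` dict is trimmed for brevity.

```python
import copy


def merge_cfg(base, child):
    """Hypothetical, simplified config merge: child values win, dicts merge."""
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)  # recurse into sub-dicts
        else:
            merged[key] = copy.deepcopy(value)  # scalars and lists are replaced
    return merged


base = dict(
    type='GFL',
    pretrained='torchvision://resnet50',
    backbone=dict(type='ResNet', depth=50, frozen_stages=1),
    neck=dict(type='FPN', out_channels=256))
child = dict(
    pretrained='torchvision://resnet101',
    backbone=dict(type='ResNet', depth=101))

cfg = merge_cfg(base, child)
print(cfg['backbone'])  # depth overridden, frozen_stages inherited
print(cfg['neck'])      # inherited untouched from the base
```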
57 changes: 57 additions & 0 deletions configs/gfl/gfl_r50_fpn_1x_coco.py
@@ -0,0 +1,57 @@
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    type='GFL',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=5),
    bbox_head=dict(
        type='GFLHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        loss_cls=dict(
            type='QualityFocalLoss',
            use_sigmoid=True,
            beta=2.0,
            loss_weight=1.0),
        loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
        reg_max=16,
        loss_bbox=dict(type='GIoULoss', loss_weight=2.0)))
# training and testing settings
train_cfg = dict(
    assigner=dict(type='ATSSAssigner', topk=9),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
test_cfg = dict(
    nms_pre=1000,
    min_bbox_size=0,
    score_thr=0.05,
    nms=dict(type='nms', iou_thr=0.6),
    max_per_img=100)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
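With `ratios=[1.0]`, `octave_base_scale=8` and `scales_per_octave=1`, the anchor generator places one square anchor per location on each FPN level, with a side of 8 times the level stride. The sketch below reproduces that arithmetic under our reading of mmdetection's `AnchorGenerator` (base size equal to the stride when no base sizes are given, scales `octave_base_scale * 2 ** (i / scales_per_octave)`); it is an illustration, not the library code, and `anchor_sizes` is a hypothetical name.

```python
import math


def anchor_sizes(strides, ratios=(1.0,), octave_base_scale=8,
                 scales_per_octave=1):
    """Sketch: (width, height) of the base anchors on each level."""
    scales = [octave_base_scale * 2 ** (i / scales_per_octave)
              for i in range(scales_per_octave)]
    sizes = {}
    for stride in strides:
        level = []
        for ratio in ratios:
            h_ratio = math.sqrt(ratio)  # ratio is height / width
            w_ratio = 1 / h_ratio
            for scale in scales:
                level.append((stride * w_ratio * scale,
                              stride * h_ratio * scale))
        sizes[stride] = level
    return sizes


for stride, anchors in anchor_sizes([8, 16, 32, 64, 128]).items():
    print(stride, anchors)  # one square anchor per location, side = 8 * stride
```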
22 changes: 22 additions & 0 deletions configs/gfl/gfl_r50_fpn_mstrain_2x_coco.py
@@ -0,0 +1,22 @@
_base_ = './gfl_r50_fpn_1x_coco.py'
# learning policy
lr_config = dict(step=[16, 22])
total_epochs = 24
# multi-scale training
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Resize',
        img_scale=[(1333, 480), (1333, 800)],
        multiscale_mode='range',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(train=dict(pipeline=train_pipeline))
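In `'range'` mode with `keep_ratio=True`, each training image gets a random target scale: the long-edge limit stays at 1333 and the short-edge limit is drawn uniformly from 480 to 800, after which the image is rescaled by the largest factor that respects both limits. The sketch below restates that logic in plain Python, assuming it mirrors mmdetection's `Resize` and `mmcv.imrescale`; `sample_scale` and `rescaled_size` are hypothetical names.

```python
import random


def sample_scale(img_scales, rng):
    """Draw (long_edge, short_edge) uniformly between the given scales."""
    longs = [max(s) for s in img_scales]
    shorts = [min(s) for s in img_scales]
    long_edge = rng.randint(min(longs), max(longs))  # randint is inclusive
    short_edge = rng.randint(min(shorts), max(shorts))
    return long_edge, short_edge


def rescaled_size(h, w, scale):
    """Keep the aspect ratio; fit within both the long and short edge limits."""
    max_long, max_short = max(scale), min(scale)
    factor = min(max_long / max(h, w), max_short / min(h, w))
    return int(h * factor + 0.5), int(w * factor + 0.5)


rng = random.Random(0)
scale = sample_scale([(1333, 480), (1333, 800)], rng)
print(scale, rescaled_size(480, 640, scale))
```

With the largest short edge (800), a 480x640 image becomes 800x1067, while a very wide 400x1200 image is limited by the 1333 long edge instead and becomes 444x1333.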
17 changes: 17 additions & 0 deletions configs/gfl/gfl_x101_32x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py
@@ -0,0 +1,17 @@
_base_ = './gfl_r50_fpn_mstrain_2x_coco.py'
model = dict(
    type='GFL',
    pretrained='open-mmlab://resnext101_32x4d',
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=32,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        dcn=dict(type='DCN', deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        norm_eval=True,
        style='pytorch'))
15 changes: 15 additions & 0 deletions configs/gfl/gfl_x101_32x4d_fpn_mstrain_2x_coco.py
@@ -0,0 +1,15 @@
_base_ = './gfl_r50_fpn_mstrain_2x_coco.py'
model = dict(
    type='GFL',
    pretrained='open-mmlab://resnext101_32x4d',
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=32,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'))
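`groups=32` and `base_width=4` select the 32x4d variant: in the first stage each bottleneck's grouped 3x3 convolution has 32 groups of 4 channels, and the width doubles per stage. The sketch below computes per-stage widths with the rule `floor(planes * base_width / 64) * groups`, where `planes = 64 * 2 ** stage`; this is our reading of mmdetection's ResNeXt bottleneck, not a quote of it, and `resnext_widths` is a hypothetical name. The output channels (4x expansion) line up with the FPN `in_channels=[256, 512, 1024, 2048]` in the base config.

```python
import math


def resnext_widths(groups=32, base_width=4, base_channels=64,
                   num_stages=4, expansion=4):
    """Sketch: (grouped-conv width, output channels) for each stage."""
    stages = []
    for i in range(num_stages):
        planes = base_channels * 2 ** i
        width = math.floor(planes * (base_width / base_channels)) * groups
        stages.append((width, planes * expansion))
    return stages


for stage, (width, out_channels) in enumerate(resnext_widths(), start=1):
    print(f"stage {stage}: width {width}, out {out_channels}")
```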
10 changes: 6 additions & 4 deletions mmdet/core/bbox/__init__.py
@@ -7,15 +7,17 @@
 from .samplers import (BaseSampler, CombinedSampler,
                        InstanceBalancedPosSampler, IoUBalancedNegSampler,
                        PseudoSampler, RandomSampler, SamplingResult)
-from .transforms import (bbox2result, bbox2roi, bbox_flip, bbox_mapping,
-                         bbox_mapping_back, distance2bbox, roi2bbox)
+from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_flip,
+                         bbox_mapping, bbox_mapping_back, distance2bbox,
+                         roi2bbox)
 
 __all__ = [
     'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
     'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
     'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
     'SamplingResult', 'build_assigner', 'build_sampler', 'bbox_flip',
     'bbox_mapping', 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result',
-    'distance2bbox', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
-    'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner'
+    'distance2bbox', 'bbox2distance', 'build_bbox_coder', 'BaseBBoxCoder',
+    'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder',
+    'CenterRegionAssigner'
 ]
24 changes: 24 additions & 0 deletions mmdet/core/bbox/transforms.py
@@ -133,3 +133,27 @@ def distance2bbox(points, distance, max_shape=None):
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return torch.stack([x1, y1, x2, y2], -1)


def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode bounding boxes as distances from points to the box sides.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        bbox (Tensor): Shape (n, 4), "xyxy" format.
        max_dis (float, optional): Upper bound of the distance.
        eps (float): Small margin so that targets stay < max_dis (not <=).

    Returns:
        Tensor: Encoded distances, shape (n, 4), [left, top, right, bottom].
    """
    left = points[:, 0] - bbox[:, 0]
    top = points[:, 1] - bbox[:, 1]
    right = bbox[:, 2] - points[:, 0]
    bottom = bbox[:, 3] - points[:, 1]
    if max_dis is not None:
        left = left.clamp(min=0, max=max_dis - eps)
        top = top.clamp(min=0, max=max_dis - eps)
        right = right.clamp(min=0, max=max_dis - eps)
        bottom = bottom.clamp(min=0, max=max_dis - eps)
    return torch.stack([left, top, right, bottom], -1)
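The `eps` in `bbox2distance` matters for the Distribution Focal Loss configured above (`loss_dfl`, `reg_max=16`). As described in the GFL paper, a continuous target y is supervised through its two neighbouring integer bins, floor(y) and floor(y) + 1, each weighted by its distance to y. With bins 0 to `reg_max`, a target equal to `reg_max` would need a nonexistent bin `reg_max + 1`, so clamping to `max_dis - eps` keeps the right bin in range. The scalar sketch below illustrates this, assuming the head passes `reg_max` as `max_dis` and that distances are already in units of the stride; it is a plain-Python illustration, not mmdetection's `DistributionFocalLoss`, and all function names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def clamp_target(distance, reg_max, eps=0.1):
    """Mirror the clamp in bbox2distance: keep 0 <= target < reg_max."""
    return min(max(distance, 0.0), reg_max - eps)


def distribution_focal_loss(logits, target):
    """DFL for one box side: weighted cross-entropy on the two nearest bins."""
    probs = softmax(logits)              # len(logits) == reg_max + 1 bins
    left = int(math.floor(target))
    right = left + 1                     # in range only if target < reg_max
    w_left = right - target              # closer bin gets the larger weight
    w_right = target - left
    return -(w_left * math.log(probs[left]) + w_right * math.log(probs[right]))


reg_max = 16
target = clamp_target(16.0, reg_max)     # 15.9, so the right bin is 16
print(target, distribution_focal_loss([0.0] * (reg_max + 1), target))
```

With uniform logits the weights sum to 1, so the loss reduces to log(17); sharpening the distribution around the target lowers it.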
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -6,6 +6,7 @@
 from .fsaf_head import FSAFHead
 from .ga_retina_head import GARetinaHead
 from .ga_rpn_head import GARPNHead
+from .gfl_head import GFLHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
 from .nasfcos_head import NASFCOSHead
 from .pisa_retinanet_head import PISARetinaHead
@@ -20,5 +21,6 @@
     'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
     'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
     'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
-    'ATSSHead', 'FSAFHead', 'NASFCOSHead', 'PISARetinaHead', 'PISASSDHead'
+    'ATSSHead', 'FSAFHead', 'NASFCOSHead', 'PISARetinaHead', 'PISASSDHead',
+    'GFLHead'
 ]
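`GFLHead` regresses each box side as a discrete distribution over `reg_max + 1` bins rather than a single number. Following the GFL paper, the predicted distance is the expectation of that distribution; here we assume it is then scaled by the level stride and turned into a box the same way `distance2bbox` does. The sketch below shows that decoding for one location in plain Python; `integral` and `decode_box` are hypothetical names, not mmdetection APIs.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def integral(logits):
    """Expected bin index: sum_i P(i) * i over bins 0..reg_max."""
    return sum(i * p for i, p in enumerate(softmax(logits)))


def decode_box(point, side_logits, stride):
    """side_logits: four lists (left, top, right, bottom) of reg_max+1 logits."""
    left, top, right, bottom = (integral(side) * stride for side in side_logits)
    x, y = point
    return (x - left, y - top, x + right, y + bottom)  # "xyxy", as distance2bbox


reg_max, stride = 16, 8
uniform = [0.0] * (reg_max + 1)  # expectation is bin 8, i.e. 64 pixels here
print(decode_box((100.0, 100.0), [uniform] * 4, stride))
```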
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/anchor_head.py
@@ -61,7 +61,9 @@ def __init__(self,
         self.feat_channels = feat_channels
         self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
         # TODO better way to determine whether sample or not
-        self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
+        self.sampling = loss_cls['type'] not in [
+            'FocalLoss', 'GHMC', 'QualityFocalLoss'
+        ]
         if self.use_sigmoid_cls:
             self.cls_out_channels = num_classes
         else:
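Adding `'QualityFocalLoss'` to this list means GFL, like RetinaNet with `FocalLoss`, trains on all anchors without sampling, relying on the loss itself to down-weight easy negatives. In the GFL paper, QFL extends focal loss to a soft target y in [0, 1] (the IoU of the predicted box for positives, 0 for negatives) and scales the binary cross-entropy by |y - sigma|^beta, with beta = 2.0 in the R-50 config. The scalar sketch below illustrates the formula; it is not mmdetection's `QualityFocalLoss` implementation, and `quality_focal_loss` is a hypothetical name.

```python
import math


def quality_focal_loss(logit, target, beta=2.0):
    """QFL for one class score: |y - sigma|^beta * BCE(sigma, y)."""
    sigma = 1.0 / (1.0 + math.exp(-logit))
    bce = -(target * math.log(sigma) + (1.0 - target) * math.log(1.0 - sigma))
    return abs(target - sigma) ** beta * bce


# A confident negative costs little; a score equal to its IoU target costs 0.
print(quality_focal_loss(-3.0, 0.0))
print(quality_focal_loss(0.0, 0.5))
```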