Merged
Commits
25 commits
297c59d
supports for DETR transformer
v-qjqs Sep 28, 2020
107dbdf
unit test with small data to avoid out of memory in CI
v-qjqs Sep 28, 2020
78315d9
use batch size 1 for unit test to avoid out of memory
v-qjqs Sep 28, 2020
7835e61
move transformer into utils folder and use more small data for unit test
v-qjqs Sep 28, 2020
0e73a19
reformat docstring
v-qjqs Sep 28, 2020
f8c8b32
add more detailed docstring
v-qjqs Sep 29, 2020
d3bc592
Merge branch 'master' of https://github.com/open-mmlab/mmdetection in…
v-qjqs Sep 29, 2020
eabd8ea
reforamt
v-qjqs Sep 29, 2020
82f06e2
Merge pull request #3848 from v-qjqs/transformer
v-qjqs Sep 30, 2020
c03540e
reformat and add build_transformer (#3866)
v-qjqs Oct 9, 2020
431dfa4
Supports for DETR position embedding (#3850)
v-qjqs Oct 14, 2020
7eea638
Supports for DETR inference (#3941)
v-qjqs Oct 19, 2020
ece6ffb
Merge branch 'master' of github.com:open-mmlab/mmdetection into detr
ZwwWayne Oct 19, 2020
3c35f43
Supports for DETR hungarian matcher. (#3929)
v-qjqs Oct 19, 2020
41be486
Merge branch 'master' of github.com:open-mmlab/mmdetection into detr
ZwwWayne Nov 12, 2020
34e9572
Re-implements RandomCrop to support different crop_type (#4093)
v-qjqs Nov 23, 2020
e80d294
Supports for DETR training mode in process. (#3963)
v-qjqs Nov 26, 2020
68d092a
Merge branch 'master' of github.com:open-mmlab/mmdetection into detr
ZwwWayne Nov 26, 2020
a5fea3e
Supports DETR e150 config (#4197)
v-qjqs Nov 28, 2020
f91b58d
add comments on override option in Resize
ZwwWayne Nov 28, 2020
0c88157
add comments on override option in Resize
ZwwWayne Nov 28, 2020
a553bc3
position embedding to positional encoding
ZwwWayne Nov 28, 2020
934dddf
fix unit tests
ZwwWayne Nov 28, 2020
175f99b
fix registry name bug
ZwwWayne Nov 28, 2020
62f7ea0
rename file
ZwwWayne Nov 28, 2020
130 changes: 130 additions & 0 deletions configs/detr/detr_r50_8x4_150e_coco.py
@@ -0,0 +1,130 @@
_base_ = [
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
model = dict(
type='DETR',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(3, ),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch'),
bbox_head=dict(
type='TransformerHead',
num_classes=80,
in_channels=2048,
num_fcs=2,
transformer=dict(
type='Transformer',
embed_dims=256,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
feedforward_channels=2048,
dropout=0.1,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'),
num_fcs=2,
pre_norm=False,
return_intermediate_dec=True),
positional_encoding=dict(
type='SinePositionalEncoding', num_feats=128, normalize=True),
loss_cls=dict(
type='CrossEntropyLoss',
bg_cls_weight=0.1,
use_sigmoid=False,
loss_weight=1.0,
class_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='HungarianAssigner', cls_weight=1., bbox_weight=5.,
iou_weight=2.))
test_cfg = dict(max_per_img=100)
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default setting in mmdet.
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='AutoAugment',
policies=[[
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
multiscale_mode='value',
keep_ratio=True)
],
[
dict(
type='Resize',
img_scale=[(400, 1333), (500, 1333), (600, 1333)],
multiscale_mode='value',
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='Resize',
img_scale=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
multiscale_mode='value',
override=True,
keep_ratio=True)
]]),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=1),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
# test_pipeline, NOTE the Pad's size_divisor is different from the default
# setting (size_divisor=32). It has little effect on performance whether we
# use the default setting or size_divisor=1.
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=1),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='AdamW',
lr=0.0001,
weight_decay=0.0001,
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[100])
total_epochs = 150
dist_params = dict(_delete_=True, backend='nccl', port=29504)
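
Note on the optimizer block above: `paramwise_cfg` with `custom_keys={'backbone': dict(lr_mult=0.1)}` scales the backbone learning rate down by 10x relative to the transformer head, matching the original DETR training recipe. Below is a minimal plain-PyTorch sketch of the equivalent parameter grouping; the helper name `build_detr_optimizer` is hypothetical and not part of mmdetection or mmcv.

# Illustrative sketch only: what the paramwise_cfg above achieves.
# Backbone parameters get lr * 0.1; everything else keeps the base lr of 1e-4.
import torch

def build_detr_optimizer(model, base_lr=1e-4, weight_decay=1e-4):
    backbone_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 'backbone' mirrors the custom_keys entry in the config above
        (backbone_params if 'backbone' in name else other_params).append(param)
    return torch.optim.AdamW(
        [
            {'params': backbone_params, 'lr': base_lr * 0.1},  # lr_mult=0.1
            {'params': other_params, 'lr': base_lr},
        ],
        lr=base_lr,
        weight_decay=weight_decay)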
7 changes: 4 additions & 3 deletions mmdet/core/bbox/__init__.py
@@ -8,8 +8,9 @@
InstanceBalancedPosSampler, IoUBalancedNegSampler,
OHEMSampler, PseudoSampler, RandomSampler,
SamplingResult, ScoreHLRSampler)
from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_flip,
bbox_mapping, bbox_mapping_back, bbox_rescale,
from .transforms import (bbox2distance, bbox2result, bbox2roi,
bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
distance2bbox, roi2bbox)

__all__ = [
@@ -21,5 +22,5 @@
'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
'bbox_rescale'
'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh'
]
4 changes: 3 additions & 1 deletion mmdet/core/bbox/assigners/__init__.py
@@ -4,10 +4,12 @@
from .base_assigner import BaseAssigner
from .center_region_assigner import CenterRegionAssigner
from .grid_assigner import GridAssigner
from .hungarian_assigner import HungarianAssigner
from .max_iou_assigner import MaxIoUAssigner
from .point_assigner import PointAssigner

__all__ = [
'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner'
'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
'HungarianAssigner'
]
158 changes: 158 additions & 0 deletions mmdet/core/bbox/assigners/hungarian_assigner.py
@@ -0,0 +1,158 @@
import torch
from scipy.optimize import linear_sum_assignment

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from ..transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class HungarianAssigner(BaseAssigner):
"""Computes one-to-one matching between predictions and ground truth.

This class computes an assignment between the targets and the predictions
based on the costs. The costs are a weighted sum of three components:
classification cost, regression L1 cost and regression iou cost. The
targets don't include the no-object category, so generally there are
more predictions than targets. After the one-to-one matching, the
unmatched predictions are treated as background. Thus each query
prediction will be assigned `0` or a positive integer indicating the
ground truth index:

- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt

Args:
cls_weight (int | float, optional): The scale factor for classification
cost. Default 1.0.
bbox_weight (int | float, optional): The scale factor for regression
L1 cost. Default 1.0.
iou_weight (int | float, optional): The scale factor for regression
iou cost. Default 1.0.
iou_calculator (dict, optional): The config for the iou calculation.
Default type `BboxOverlaps2D`.
iou_mode (str, optional): "iou" (intersection over union), "iof"
(intersection over foreground), or "giou" (generalized
intersection over union). Default "giou".
"""

def __init__(self,
cls_weight=1.,
bbox_weight=1.,
iou_weight=1.,
iou_calculator=dict(type='BboxOverlaps2D'),
iou_mode='giou'):
# giou cost is used by default in the official DETR repo.
self.iou_mode = iou_mode
self.cls_weight = cls_weight
self.bbox_weight = bbox_weight
self.iou_weight = iou_weight
self.iou_calculator = build_iou_calculator(iou_calculator)

def assign(self,
bbox_pred,
cls_pred,
gt_bboxes,
gt_labels,
img_meta,
gt_bboxes_ignore=None,
eps=1e-7):
"""Computes one-to-one matching based on the weighted costs.

This method assigns each query prediction to a ground truth or
background. In `assigned_gt_inds`, -1 means don't care, 0 means
negative sample, and a positive number is the (1-based) index of the
assigned gt.
The assignment is done in the following steps; the order matters.

1. assign every prediction to -1
2. compute the weighted costs
3. do Hungarian matching on CPU based on the costs
4. assign all to 0 (background) first, then for each matched pair
between predictions and gts, treat this prediction as foreground
and assign the corresponding gt index (plus 1) to it.

Args:
bbox_pred (Tensor): Predicted boxes with normalized coordinates
(cx, cy, w, h), which are all in range [0, 1]. Shape
[num_query, 4].
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_bboxes (Tensor): Ground truth boxes with unnormalized
coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
img_meta (dict): Meta information for current image.
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`. Default None.
eps (int | float, optional): A value added to the denominator for
numerical stability. Default 1e-7.

Returns:
:obj:`AssignResult`: The assigned result.
"""
assert gt_bboxes_ignore is None, \
'Only case when gt_bboxes_ignore is None is supported.'
num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

# 1. assign -1 by default
assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
-1,
dtype=torch.long)
assigned_labels = bbox_pred.new_full((num_bboxes, ),
-1,
dtype=torch.long)
if num_gts == 0 or num_bboxes == 0:
# No ground truth or boxes, return empty assignment
if num_gts == 0:
# No ground truth, assign all to background
assigned_gt_inds[:] = 0
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)

# 2. compute the weighted costs
# classification cost.
# Following the official DETR repo, instead of the NLL used in the loss,
# the cost is approximated as 1 - cls_score[gt_label]. The 1 is a
# constant that doesn't change the matching, so it can be omitted.
cls_score = cls_pred.softmax(-1)
cls_cost = -cls_score[:, gt_labels] # [num_bboxes, num_gt]

# regression L1 cost
img_h, img_w, _ = img_meta['img_shape']
factor = torch.Tensor([img_w, img_h, img_w,
img_h]).unsqueeze(0).to(gt_bboxes.device)
gt_bboxes_normalized = gt_bboxes / factor
bbox_cost = torch.cdist(
bbox_pred, bbox_xyxy_to_cxcywh(gt_bboxes_normalized),
p=1) # [num_bboxes, num_gt]

# regression iou cost; giou is used by default in the official DETR repo.
bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
# overlaps: [num_bboxes, num_gt]
overlaps = self.iou_calculator(
bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
# The 1 is a constant that doesn't change the matching, so it is omitted.
iou_cost = -overlaps

# weighted sum of above three costs
cost = self.cls_weight * cls_cost + self.bbox_weight * bbox_cost
cost = cost + self.iou_weight * iou_cost

# 3. do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
matched_row_inds = torch.from_numpy(matched_row_inds).to(
bbox_pred.device)
matched_col_inds = torch.from_numpy(matched_col_inds).to(
bbox_pred.device)

# 4. assign backgrounds and foregrounds
# assign all indices to backgrounds first
assigned_gt_inds[:] = 0
# assign foregrounds based on matching results
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
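
For intuition, the cost construction and matching in `assign` above can be reproduced on toy data in a few lines. This is an illustrative sketch only: values are random, the giou term is omitted for brevity, and the cost weights are taken from the config earlier in this PR.

# Toy sketch of steps 2-4 above (classification + L1 cost, then Hungarian matching).
import torch
from scipy.optimize import linear_sum_assignment

num_query, num_gt, num_class = 5, 2, 80
cls_pred = torch.randn(num_query, num_class + 1)   # logits incl. background class
bbox_pred = torch.rand(num_query, 4)               # normalized (cx, cy, w, h)
gt_labels = torch.tensor([3, 17])
gt_bboxes_cxcywh = torch.rand(num_gt, 4)           # already normalized here

# classification cost: minus the softmax score of the gt class
# (the constant 1 in "1 - score" does not change the matching, so it is dropped)
cls_cost = -cls_pred.softmax(-1)[:, gt_labels]     # [num_query, num_gt]
# regression L1 cost between normalized (cx, cy, w, h) parameterizations
bbox_cost = torch.cdist(bbox_pred, gt_bboxes_cxcywh, p=1)

cost = 1.0 * cls_cost + 5.0 * bbox_cost            # cls_weight=1., bbox_weight=5.
row_inds, col_inds = linear_sum_assignment(cost.detach().numpy())

assigned_gt_inds = torch.zeros(num_query, dtype=torch.long)  # 0 = background
assigned_gt_inds[torch.from_numpy(row_inds)] = torch.from_numpy(col_inds) + 1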
28 changes: 28 additions & 0 deletions mmdet/core/bbox/transforms.py
@@ -194,3 +194,31 @@ def bbox_rescale(bboxes, scale_factor=1.0):
else:
rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
return rescaled_bboxes


def bbox_cxcywh_to_xyxy(bbox):
"""Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).

Args:
bbox (Tensor): Shape (n, 4) for bboxes.

Returns:
Tensor: Converted bboxes.
"""
cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1)
bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
return torch.cat(bbox_new, dim=-1)


def bbox_xyxy_to_cxcywh(bbox):
"""Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).

Args:
bbox (Tensor): Shape (n, 4) for bboxes.

Returns:
Tensor: Converted bboxes.
"""
x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
return torch.cat(bbox_new, dim=-1)
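
A quick sanity check for the two conversion helpers above; the import path follows the `mmdet/core/bbox/__init__.py` change in this PR, and the box values are illustrative.

import torch
from mmdet.core.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh

boxes_xyxy = torch.tensor([[10., 20., 50., 80.]])      # (x1, y1, x2, y2)
boxes_cxcywh = bbox_xyxy_to_cxcywh(boxes_xyxy)         # tensor([[30., 50., 40., 60.]])
assert torch.allclose(bbox_cxcywh_to_xyxy(boxes_cxcywh), boxes_xyxy)  # round trip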