diff --git a/configs/detr/detr_r50_8x4_150e_coco.py b/configs/detr/detr_r50_8x4_150e_coco.py new file mode 100644 index 00000000000..324b1a89035 --- /dev/null +++ b/configs/detr/detr_r50_8x4_150e_coco.py @@ -0,0 +1,130 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py' +] +model = dict( + type='DETR', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(3, ), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + bbox_head=dict( + type='TransformerHead', + num_classes=80, + in_channels=2048, + num_fcs=2, + transformer=dict( + type='Transformer', + embed_dims=256, + num_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + feedforward_channels=2048, + dropout=0.1, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + num_fcs=2, + pre_norm=False, + return_intermediate_dec=True), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + loss_cls=dict( + type='CrossEntropyLoss', + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + class_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=5.0), + loss_iou=dict(type='GIoULoss', loss_weight=2.0))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='HungarianAssigner', cls_weight=1., bbox_weight=5., + iou_weight=2.)) +test_cfg = dict(max_per_img=100) +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different +# from the default setting in mmdet. +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict( + type='AutoAugment', + policies=[[ + dict( + type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + multiscale_mode='value', + keep_ratio=True) + ], + [ + dict( + type='Resize', + img_scale=[(400, 1333), (500, 1333), (600, 1333)], + multiscale_mode='value', + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + multiscale_mode='value', + override=True, + keep_ratio=True) + ]]), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=1), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +# test_pipeline, NOTE the Pad's size_divisor is different from the default +# setting (size_divisor=32). While there is little effect on the performance +# whether we use the default setting or use size_divisor=1. 
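+# NOTE: TransformerHead builds per-image padding masks from `img_shape`
+# (see forward_single), so regions padded by `Pad` are masked out in the
+# attention regardless of the divisor chosen here.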
+test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=1), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +# optimizer +optimizer = dict( + type='AdamW', + lr=0.0001, + weight_decay=0.0001, + paramwise_cfg=dict( + custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)})) +optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2)) +# learning policy +lr_config = dict(policy='step', step=[100]) +total_epochs = 150 +dist_params = dict(_delete_=True, backend='nccl', port=29504) diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index 1826cbf39f3..6fc1f831307 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -8,8 +8,9 @@ InstanceBalancedPosSampler, IoUBalancedNegSampler, OHEMSampler, PseudoSampler, RandomSampler, SamplingResult, ScoreHLRSampler) -from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_flip, - bbox_mapping, bbox_mapping_back, bbox_rescale, +from .transforms import (bbox2distance, bbox2result, bbox2roi, + bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping, + bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh, distance2bbox, roi2bbox) __all__ = [ @@ -21,5 +22,5 @@ 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance', 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner', - 'bbox_rescale' + 'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh' ] diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py index 9ba7e1adcf6..b8f0f48d8cf 100644 --- a/mmdet/core/bbox/assigners/__init__.py +++ b/mmdet/core/bbox/assigners/__init__.py @@ -4,10 +4,12 @@ from .base_assigner import BaseAssigner from .center_region_assigner import CenterRegionAssigner from .grid_assigner import GridAssigner +from .hungarian_assigner import HungarianAssigner from .max_iou_assigner import MaxIoUAssigner from .point_assigner import PointAssigner __all__ = [ 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', - 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner' + 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner', + 'HungarianAssigner' ] diff --git a/mmdet/core/bbox/assigners/hungarian_assigner.py b/mmdet/core/bbox/assigners/hungarian_assigner.py new file mode 100644 index 00000000000..224609300f6 --- /dev/null +++ b/mmdet/core/bbox/assigners/hungarian_assigner.py @@ -0,0 +1,158 @@ +import torch +from scipy.optimize import linear_sum_assignment + +from ..builder import BBOX_ASSIGNERS +from ..iou_calculators import build_iou_calculator +from ..transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + + +@BBOX_ASSIGNERS.register_module() +class HungarianAssigner(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth. + + This class computes an assignment between the targets and the predictions + based on the costs. 
+    The costs are a weighted sum of three components: classification cost,
+    regression L1 cost and regression IoU cost. The targets don't include
+    the no_object, so generally there are more predictions than targets.
+    After the one-to-one matching, the un-matched are treated as
+    backgrounds. Thus each query prediction will be assigned with `0` or a
+    positive integer indicating the ground truth index:
+
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        cls_weight (int | float, optional): The scale factor for
+            classification cost. Default 1.0.
+        bbox_weight (int | float, optional): The scale factor for regression
+            L1 cost. Default 1.0.
+        iou_weight (int | float, optional): The scale factor for regression
+            iou cost. Default 1.0.
+        iou_calculator (dict, optional): The config for the iou calculation.
+            Default type `BboxOverlaps2D`.
+        iou_mode (str, optional): "iou" (intersection over union), "iof"
+            (intersection over foreground), or "giou" (generalized
+            intersection over union). Default "giou".
+    """
+
+    def __init__(self,
+                 cls_weight=1.,
+                 bbox_weight=1.,
+                 iou_weight=1.,
+                 iou_calculator=dict(type='BboxOverlaps2D'),
+                 iou_mode='giou'):
+        # by default, the GIoU cost is used in the official DETR repo.
+        self.iou_mode = iou_mode
+        self.cls_weight = cls_weight
+        self.bbox_weight = bbox_weight
+        self.iou_weight = iou_weight
+        self.iou_calculator = build_iou_calculator(iou_calculator)
+
+    def assign(self,
+               bbox_pred,
+               cls_pred,
+               gt_bboxes,
+               gt_labels,
+               img_meta,
+               gt_bboxes_ignore=None,
+               eps=1e-7):
+        """Computes one-to-one matching based on the weighted costs.
+
+        This method assigns each query prediction to a ground truth or
+        background. The `assigned_gt_inds` with -1 means don't care,
+        0 means negative sample, and a positive number is the index
+        (1-based) of the assigned gt.
+        The assignment is done in the following steps, the order matters.
+
+        1. assign every prediction to -1
+        2. compute the weighted costs
+        3. do Hungarian matching on CPU based on the costs
+        4. assign all to 0 (background) first, then for each matched pair
+           between predictions and gts, treat this prediction as foreground
+           and assign the corresponding gt index (plus 1) to it.
+
+        Args:
+            bbox_pred (Tensor): Predicted boxes with normalized coordinates
+                (cx, cy, w, h), which are all in range [0, 1]. Shape
+                [num_query, 4].
+            cls_pred (Tensor): Predicted classification logits, shape
+                [num_query, num_class].
+            gt_bboxes (Tensor): Ground truth boxes with unnormalized
+                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+            img_meta (dict): Meta information for current image.
+            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                labelled as `ignored`. Default None.
+            eps (int | float, optional): A value added to the denominator for
+                numerical stability. Default 1e-7.
+
+        Returns:
+            :obj:`AssignResult`: The assigned result.
+        """
+        assert gt_bboxes_ignore is None, \
+            'Only case when gt_bboxes_ignore is None is supported.'
+        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
+
+        # 1. assign -1 by default
+        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+                                              -1,
+                                              dtype=torch.long)
+        assigned_labels = bbox_pred.new_full((num_bboxes, ),
+                                             -1,
+                                             dtype=torch.long)
+        if num_gts == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            if num_gts == 0:
+                # No ground truth, assign all to background
+                assigned_gt_inds[:] = 0
+            return AssignResult(
+                num_gts, assigned_gt_inds, None, labels=assigned_labels)
+
+        # 2. compute the weighted costs
+        # classification cost.
+        # Following the official DETR repo, in contrast to the loss where
+        # NLL is used, we approximate it by 1 - cls_score[gt_label].
+        # The 1 is a constant that doesn't change the matching,
+        # so it can be omitted.
+        cls_score = cls_pred.softmax(-1)
+        cls_cost = -cls_score[:, gt_labels]  # [num_bboxes, num_gt]
+
+        # regression L1 cost
+        img_h, img_w, _ = img_meta['img_shape']
+        factor = torch.Tensor([img_w, img_h, img_w,
+                               img_h]).unsqueeze(0).to(gt_bboxes.device)
+        gt_bboxes_normalized = gt_bboxes / factor
+        bbox_cost = torch.cdist(
+            bbox_pred, bbox_xyxy_to_cxcywh(gt_bboxes_normalized),
+            p=1)  # [num_bboxes, num_gt]
+
+        # regression iou cost, GIoU is used by default in the official DETR.
+        bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
+        # overlaps: [num_bboxes, num_gt]
+        overlaps = self.iou_calculator(
+            bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
+        # The 1 is a constant that doesn't change the matching, so omitted.
+        iou_cost = -overlaps
+
+        # weighted sum of above three costs
+        cost = self.cls_weight * cls_cost + self.bbox_weight * bbox_cost
+        cost = cost + self.iou_weight * iou_cost
+
+        # 3. do Hungarian matching on CPU using linear_sum_assignment
+        cost = cost.detach().cpu()
+        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+        matched_row_inds = torch.from_numpy(matched_row_inds).to(
+            bbox_pred.device)
+        matched_col_inds = torch.from_numpy(matched_col_inds).to(
+            bbox_pred.device)
+
+        # 4. assign backgrounds and foregrounds
+        # assign all indices to backgrounds first
+        assigned_gt_inds[:] = 0
+        # assign foregrounds based on matching results
+        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
+        return AssignResult(
+            num_gts, assigned_gt_inds, None, labels=assigned_labels)
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
index 5fbe1dcdd94..102db0d1f38 100644
--- a/mmdet/core/bbox/transforms.py
+++ b/mmdet/core/bbox/transforms.py
@@ -194,3 +194,31 @@ def bbox_rescale(bboxes, scale_factor=1.0):
     else:
         rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
     return rescaled_bboxes
+
+
+def bbox_cxcywh_to_xyxy(bbox):
+    """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).
+
+    Args:
+        bbox (Tensor): Shape (n, 4) for bboxes.
+
+    Returns:
+        Tensor: Converted bboxes.
+    """
+    cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1)
+    bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
+    return torch.cat(bbox_new, dim=-1)
+
+
+def bbox_xyxy_to_cxcywh(bbox):
+    """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
+
+    Args:
+        bbox (Tensor): Shape (n, 4) for bboxes.
+
+    Returns:
+        Tensor: Converted bboxes.
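+
+    Example:
+        >>> import torch
+        >>> # a 20x40 box centered at (20, 30)
+        >>> bbox_xyxy_to_cxcywh(torch.Tensor([[10., 10., 30., 50.]]))
+        tensor([[20., 30., 20., 40.]])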
+ """ + x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1) + bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)] + return torch.cat(bbox_new, dim=-1) diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 5f68a317dad..f9ffda51beb 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -54,6 +54,12 @@ class Resize(object): backend (str): Image resize backend, choices are 'cv2' and 'pillow'. These two backends generates slightly different results. Defaults to 'cv2'. + override (bool, optional): Whether to override `scale` and + `scale_factor` so as to call resize twice. Default False. If True, + after the first resizing, the existed `scale` and `scale_factor` + will be ignored so the second resizing can be allowed. + This option is a work-around for multiple times of resize in DETR. + Defaults to False. """ def __init__(self, @@ -62,7 +68,8 @@ def __init__(self, ratio_range=None, keep_ratio=True, bbox_clip_border=True, - backend='cv2'): + backend='cv2', + override=False): if img_scale is None: self.img_scale = None else: @@ -83,6 +90,8 @@ def __init__(self, self.multiscale_mode = multiscale_mode self.ratio_range = ratio_range self.keep_ratio = keep_ratio + # TODO: refactor the override option in Resize + self.override = override self.bbox_clip_border = bbox_clip_border @staticmethod @@ -280,8 +289,14 @@ def __call__(self, results): else: self._random_scale(results) else: - assert 'scale_factor' not in results, ( - 'scale and scale_factor cannot be both set.') + if not self.override: + assert 'scale_factor' not in results, ( + 'scale and scale_factor cannot be both set.') + else: + results.pop('scale') + if 'scale_factor' in results: + results.pop('scale_factor') + self._random_scale(results) self._resize_img(results) self._resize_bboxes(results) @@ -572,15 +587,29 @@ def __repr__(self): class RandomCrop(object): """Random crop the image & bboxes & masks. + The absolute `crop_size` is sampled based on `crop_type` and `image_size`, + then the cropped results are generated. + Args: - crop_size (tuple): Expected size after cropping, (h, w). - allow_negative_crop (bool): Whether to allow a crop that does not - contain any bbox area. Default to False. + crop_size (tuple): The relative ratio or absolute pixels of + height and width. + crop_type (str, optional): one of "relative_range", "relative", + "absolute", "absolute_range". "relative" randomly crops + (h * crop_size[0], w * crop_size[1]) part from an input of size + (h, w). "relative_range" uniformly samples relative crop size from + range [crop_size[0], 1] and [crop_size[1], 1] for height and width + respectively. "absolute" crops from an input with absolute size + (crop_size[0], crop_size[1]). "absolute_range" uniformly samples + crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w + in range [crop_size[0], min(w, crop_size[1])]. Default "absolute". + allow_negative_crop (bool, optional): Whether to allow a crop that does + not contain any bbox area. Default False. bbox_clip_border (bool, optional): Whether clip the objects outside the border of the image. Defaults to True. Note: - - If the image is smaller than the crop size, return the original image + - If the image is smaller than the absolute crop size, return the + original image. - The keys for bboxes, labels and masks must be aligned. 
That is, `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and @@ -591,10 +620,21 @@ class RandomCrop(object): def __init__(self, crop_size, + crop_type='absolute', allow_negative_crop=False, bbox_clip_border=True): - assert crop_size[0] > 0 and crop_size[1] > 0 + if crop_type not in [ + 'relative_range', 'relative', 'absolute', 'absolute_range' + ]: + raise ValueError(f'Invalid crop_type {crop_type}.') + if crop_type in ['absolute', 'absolute_range']: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance( + crop_size[1], int) + else: + assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 self.crop_size = crop_size + self.crop_type = crop_type self.allow_negative_crop = allow_negative_crop self.bbox_clip_border = bbox_clip_border # The key correspondence from bboxes to labels and masks. @@ -607,26 +647,29 @@ def __init__(self, 'gt_bboxes_ignore': 'gt_masks_ignore' } - def __call__(self, results): - """Call function to randomly crop images, bounding boxes, masks, - semantic segmentation maps. + def _crop_data(self, results, crop_size, allow_negative_crop): + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. Args: results (dict): Result dict from loading pipeline. + crop_size (tuple): Expected absolute size after cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. Default to False. Returns: dict: Randomly cropped results, 'img_shape' key in result dict is updated according to crop size. """ - + assert crop_size[0] > 0 and crop_size[1] > 0 for key in results.get('img_fields', ['img']): img = results[key] - margin_h = max(img.shape[0] - self.crop_size[0], 0) - margin_w = max(img.shape[1] - self.crop_size[1], 0) + margin_h = max(img.shape[0] - crop_size[0], 0) + margin_w = max(img.shape[1] - crop_size[1], 0) offset_h = np.random.randint(0, margin_h + 1) offset_w = np.random.randint(0, margin_w + 1) - crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] - crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] # crop the image img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] @@ -646,9 +689,9 @@ def __call__(self, results): valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & ( bboxes[:, 3] > bboxes[:, 1]) # If the crop does not contain any gt-bbox area and - # self.allow_negative_crop is False, skip this image. + # allow_negative_crop is False, skip this image. if (key == 'gt_bboxes' and not valid_inds.any() - and not self.allow_negative_crop): + and not allow_negative_crop): return None results[key] = bboxes[valid_inds, :] # label fields. e.g. gt_labels and gt_labels_ignore @@ -669,8 +712,57 @@ def __call__(self, results): return results + def _get_crop_size(self, image_size): + """Randomly generates the absolute crop size based on `crop_type` and + `image_size`. + + Args: + image_size (tuple): (h, w). + + Returns: + crop_size (tuple): (crop_h, crop_w) in absolute pixels. 
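+
+        Example:
+            With ``crop_type='absolute_range'`` and ``crop_size=(384, 600)``
+            (the DETR config above), a 480x640 image yields ``crop_h`` in
+            [384, 480] and ``crop_w`` in [384, 600].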
+ """ + h, w = image_size + if self.crop_type == 'absolute': + return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + elif self.crop_type == 'absolute_range': + assert self.crop_size[0] <= self.crop_size[1] + crop_h = np.random.randint( + min(h, self.crop_size[0]), + min(h, self.crop_size[1]) + 1) + crop_w = np.random.randint( + min(w, self.crop_size[0]), + min(w, self.crop_size[1]) + 1) + return crop_h, crop_w + elif self.crop_type == 'relative': + crop_h, crop_w = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + elif self.crop_type == 'relative_range': + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + def __call__(self, results): + """Call function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + image_size = results['img'].shape[:2] + crop_size = self._get_crop_size(image_size) + results = self._crop_data(results, crop_size, self.allow_negative_crop) + return results + def __repr__(self): - repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}), ' + repr_str = self.__class__.__name__ + repr_str += f'(crop_size={self.crop_size}, ' + repr_str += f'crop_type={self.crop_type}, ' + repr_str += f'allow_negative_crop={self.allow_negative_crop}, ' repr_str += f'bbox_clip_border={self.bbox_clip_border})' return repr_str diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index 518cd3cc67e..74750546b58 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -21,6 +21,7 @@ from .rpn_head import RPNHead from .sabl_retina_head import SABLRetinaHead from .ssd_head import SSDHead +from .transformer_head import TransformerHead from .vfnet_head import VFNetHead from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead from .yolo_head import YOLOV3Head @@ -32,5 +33,5 @@ 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead', 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead', 'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead', - 'SABLRetinaHead', 'CentripetalHead', 'VFNetHead' + 'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'TransformerHead' ] diff --git a/mmdet/models/dense_heads/transformer_head.py b/mmdet/models/dense_heads/transformer_head.py new file mode 100644 index 00000000000..da3b035609b --- /dev/null +++ b/mmdet/models/dense_heads/transformer_head.py @@ -0,0 +1,655 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, Linear, build_activation_layer +from mmcv.runner import force_fp32 + +from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, + build_assigner, build_sampler, multi_apply, + reduce_mean) +from mmdet.models.utils import (FFN, build_positional_encoding, + build_transformer) +from ..builder import HEADS, build_loss +from .anchor_free_head import AnchorFreeHead + + +@HEADS.register_module() +class TransformerHead(AnchorFreeHead): + """Implements the DETR transformer head. + + See `paper: End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_classes (int): Number of categories excluding the background. + in_channels (int): Number of channels in the input feature map. 
+        num_fcs (int, optional): Number of fully-connected layers used in
+            `FFN`, which is then used for the regression head. Default 2.
+        transformer (dict, optional): Config for transformer.
+        positional_encoding (dict, optional): Config for position encoding.
+        loss_cls (dict, optional): Config of the classification loss.
+            Default `CrossEntropyLoss`.
+        loss_bbox (dict, optional): Config of the regression loss.
+            Default `L1Loss`.
+        loss_iou (dict, optional): Config of the regression iou loss.
+            Default `GIoULoss`.
+        train_cfg (dict, optional): Training config of transformer head.
+        test_cfg (dict, optional): Testing config of transformer head.
+
+    Example:
+        >>> import torch
+        >>> self = TransformerHead(80, 2048)
+        >>> x = torch.rand(1, 2048, 32, 32)
+        >>> mask = torch.ones(1, 32, 32).to(x.dtype)
+        >>> mask[:, :16, :15] = 0
+        >>> all_cls_scores, all_bbox_preds = self(x, mask)
+    """
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 num_fcs=2,
+                 transformer=dict(
+                     type='Transformer',
+                     embed_dims=256,
+                     num_heads=8,
+                     num_encoder_layers=6,
+                     num_decoder_layers=6,
+                     feedforward_channels=2048,
+                     dropout=0.1,
+                     act_cfg=dict(type='ReLU', inplace=True),
+                     norm_cfg=dict(type='LN'),
+                     num_fcs=2,
+                     pre_norm=False,
+                     return_intermediate_dec=True),
+                 positional_encoding=dict(
+                     type='SinePositionalEncoding',
+                     num_feats=128,
+                     normalize=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss',
+                     bg_cls_weight=0.1,
+                     use_sigmoid=False,
+                     loss_weight=1.0,
+                     class_weight=1.0),
+                 loss_bbox=dict(type='L1Loss', loss_weight=5.0),
+                 loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+                 train_cfg=dict(
+                     assigner=dict(
+                         type='HungarianAssigner',
+                         cls_weight=1.,
+                         bbox_weight=5.,
+                         iou_weight=2.,
+                         iou_calculator=dict(type='BboxOverlaps2D'),
+                         iou_mode='giou')),
+                 test_cfg=dict(max_per_img=100),
+                 **kwargs):
+        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+        # since it brings inconvenience when the initialization of
+        # `AnchorFreeHead` is called.
+        super(AnchorFreeHead, self).__init__()
+        use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+        assert not use_sigmoid_cls, 'setting use_sigmoid_cls as True is ' \
+            'not supported in DETR, since background is needed for the ' \
+            'matching process.'
+        assert 'embed_dims' in transformer \
+            and 'num_feats' in positional_encoding
+        num_feats = positional_encoding['num_feats']
+        embed_dims = transformer['embed_dims']
+        assert num_feats * 2 == embed_dims, 'embed_dims should' \
+            f' be exactly 2 times of num_feats. Found {embed_dims}' \
+            f' and {num_feats}.'
+        assert test_cfg is not None and 'max_per_img' in test_cfg
+
+        class_weight = loss_cls.get('class_weight', None)
+        if class_weight is not None:
+            assert isinstance(class_weight, float), 'Expected ' \
+                'class_weight to have type float. Found ' \
+                f'{type(class_weight)}.'
+            # NOTE following the official DETR repo, bg_cls_weight means
+            # relative classification weight of the no-object class.
+            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
+            assert isinstance(bg_cls_weight, float), 'Expected ' \
+                'bg_cls_weight to have type float. Found ' \
+                f'{type(bg_cls_weight)}.'
+            class_weight = torch.ones(num_classes + 1) * class_weight
+            # set the background class as the last index
+            class_weight[num_classes] = bg_cls_weight
+            loss_cls.update({'class_weight': class_weight})
+            if 'bg_cls_weight' in loss_cls:
+                loss_cls.pop('bg_cls_weight')
+            self.bg_cls_weight = bg_cls_weight
+
+        if train_cfg:
+            assert 'assigner' in train_cfg, 'assigner should be provided '\
+                'when train_cfg is set.'
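+            # The matcher must optimize the same weighted cost that the
+            # losses are trained with, hence the weight-consistency checks
+            # below.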
+            assigner = train_cfg['assigner']
+            assert loss_cls['loss_weight'] == assigner['cls_weight'], \
+                'The classification weight for loss and matcher should be ' \
+                'exactly the same.'
+            assert loss_bbox['loss_weight'] == assigner['bbox_weight'], \
+                'The regression L1 weight for loss and matcher should be ' \
+                'exactly the same.'
+            assert loss_iou['loss_weight'] == assigner['iou_weight'], \
+                'The regression iou weight for loss and matcher should be ' \
+                'exactly the same.'
+            self.assigner = build_assigner(assigner)
+            # DETR sampling=False, so use PseudoSampler
+            sampler_cfg = dict(type='PseudoSampler')
+            self.sampler = build_sampler(sampler_cfg, context=self)
+        self.num_classes = num_classes
+        self.cls_out_channels = num_classes + 1
+        self.in_channels = in_channels
+        self.num_fcs = num_fcs
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.use_sigmoid_cls = use_sigmoid_cls
+        self.embed_dims = embed_dims
+        self.num_query = test_cfg['max_per_img']
+        self.fp16_enabled = False
+        self.loss_cls = build_loss(loss_cls)
+        self.loss_bbox = build_loss(loss_bbox)
+        self.loss_iou = build_loss(loss_iou)
+        self.act_cfg = transformer.get('act_cfg',
+                                       dict(type='ReLU', inplace=True))
+        self.activate = build_activation_layer(self.act_cfg)
+        self.positional_encoding = build_positional_encoding(
+            positional_encoding)
+        self.transformer = build_transformer(transformer)
+        self._init_layers()
+
+    def _init_layers(self):
+        """Initialize layers of the transformer head."""
+        self.input_proj = Conv2d(
+            self.in_channels, self.embed_dims, kernel_size=1)
+        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+        self.reg_ffn = FFN(
+            self.embed_dims,
+            self.embed_dims,
+            self.num_fcs,
+            self.act_cfg,
+            dropout=0.0,
+            add_residual=False)
+        self.fc_reg = Linear(self.embed_dims, 4)
+        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
+
+    def init_weights(self, distribution='uniform'):
+        """Initialize weights of the transformer head."""
+        # The initialization for transformer is important
+        self.transformer.init_weights()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                              strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        """load checkpoints."""
+        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+        # since `AnchorFreeHead._load_from_state_dict` should not be
+        # called here. Invoking the default `Module._load_from_state_dict`
+        # is enough.
+        super(AnchorFreeHead,
+              self)._load_from_state_dict(state_dict, prefix, local_metadata,
+                                          strict, missing_keys,
+                                          unexpected_keys, error_msgs)
+
+    def forward(self, feats, img_metas):
+        """Forward function.
+
+        Args:
+            feats (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+            img_metas (list[dict]): List of image information.
+
+        Returns:
+            tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
+
+                - all_cls_scores_list (list[Tensor]): Classification scores \
+                    for each scale level. Each is a 4D-tensor with shape \
+                    [nb_dec, bs, num_query, cls_out_channels]. Note \
+                    `cls_out_channels` should include background.
+                - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
+                    outputs for each scale level. Each is a 4D-tensor with \
+                    normalized coordinate format (cx, cy, w, h) and shape \
+                    [nb_dec, bs, num_query, 4].
+        """
+        num_levels = len(feats)
+        img_metas_list = [img_metas for _ in range(num_levels)]
+        return multi_apply(self.forward_single, feats, img_metas_list)
+
+    def forward_single(self, x, img_metas):
+        """Forward function for a single feature level.
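+
+        The feature map is projected to `embed_dims` by a 1x1 conv and fed
+        to the transformer together with the padding masks and positional
+        encoding.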
+
+        Args:
+            x (Tensor): Input feature from backbone's single stage, shape
+                [bs, c, h, w].
+            img_metas (list[dict]): List of image information.
+
+        Returns:
+            all_cls_scores (Tensor): Outputs from the classification head,
+                shape [nb_dec, bs, num_query, cls_out_channels]. Note
+                cls_out_channels should include background.
+            all_bbox_preds (Tensor): Sigmoid outputs from the regression
+                head with normalized coordinate format (cx, cy, w, h).
+                Shape [nb_dec, bs, num_query, 4].
+        """
+        # construct binary masks which are used for the transformer.
+        # NOTE following the official DETR repo, non-zero values represent
+        # ignored positions, while zero values mean valid positions.
+        batch_size = x.size(0)
+        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+        masks = x.new_ones((batch_size, input_img_h, input_img_w))
+        for img_id in range(batch_size):
+            img_h, img_w, _ = img_metas[img_id]['img_shape']
+            masks[img_id, :img_h, :img_w] = 0
+
+        x = self.input_proj(x)
+        # interpolate masks to have the same spatial shape as x
+        masks = F.interpolate(
+            masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
+        # position encoding
+        pos_embed = self.positional_encoding(masks)  # [bs, embed_dim, h, w]
+        # outs_dec: [nb_dec, bs, num_query, embed_dim]
+        outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
+                                       pos_embed)
+
+        all_cls_scores = self.fc_cls(outs_dec)
+        all_bbox_preds = self.fc_reg(self.activate(
+            self.reg_ffn(outs_dec))).sigmoid()
+        return all_cls_scores, all_bbox_preds
+
+    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+    def loss(self,
+             all_cls_scores_list,
+             all_bbox_preds_list,
+             gt_bboxes_list,
+             gt_labels_list,
+             img_metas,
+             gt_bboxes_ignore=None):
+        """Loss function.
+
+        Only outputs from the last feature level are used for computing
+        losses by default.
+
+        Args:
+            all_cls_scores_list (list[Tensor]): Classification outputs
+                for each feature level. Each is a 4D-tensor with shape
+                [nb_dec, bs, num_query, cls_out_channels].
+            all_bbox_preds_list (list[Tensor]): Sigmoid regression
+                outputs for each feature level. Each is a 4D-tensor with
+                normalized coordinate format (cx, cy, w, h) and shape
+                [nb_dec, bs, num_query, 4].
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            img_metas (list[dict]): List of image meta information.
+            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+                which can be ignored for each image. Default None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        # NOTE by default, only the outputs from the last feature level
+        # are used.
+        all_cls_scores = all_cls_scores_list[-1]
+        all_bbox_preds = all_bbox_preds_list[-1]
+        assert gt_bboxes_ignore is None, \
+            'Only gt_bboxes_ignore set to None is supported.'
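+
+        # The same targets are repeated for each decoder layer below, so
+        # that auxiliary losses can be computed on every intermediate
+        # decoder output.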
+        num_dec_layers = len(all_cls_scores)
+        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+        all_gt_bboxes_ignore_list = [
+            gt_bboxes_ignore for _ in range(num_dec_layers)
+        ]
+        img_metas_list = [img_metas for _ in range(num_dec_layers)]
+
+        losses_cls, losses_bbox, losses_iou = multi_apply(
+            self.loss_single, all_cls_scores, all_bbox_preds,
+            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
+            all_gt_bboxes_ignore_list)
+
+        loss_dict = dict()
+        # loss from the last decoder layer
+        loss_dict['loss_cls'] = losses_cls[-1]
+        loss_dict['loss_bbox'] = losses_bbox[-1]
+        loss_dict['loss_iou'] = losses_iou[-1]
+        # loss from other decoder layers
+        num_dec_layer = 0
+        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
+                                                       losses_bbox[:-1],
+                                                       losses_iou[:-1]):
+            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
+            num_dec_layer += 1
+        return loss_dict
+
+    def loss_single(self,
+                    cls_scores,
+                    bbox_preds,
+                    gt_bboxes_list,
+                    gt_labels_list,
+                    img_metas,
+                    gt_bboxes_ignore_list=None):
+        """Loss function for outputs from a single decoder layer of a single
+        feature level.
+
+        Args:
+            cls_scores (Tensor): Box score logits from a single decoder layer
+                for all images. Shape [bs, num_query, cls_out_channels].
+            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+                for all images, with normalized coordinate (cx, cy, w, h) and
+                shape [bs, num_query, 4].
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            img_metas (list[dict]): List of image meta information.
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+
+        Returns:
+            tuple[Tensor]: Classification, regression L1 and regression IoU
+                losses for outputs from a single decoder layer.
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, + gt_bboxes_list, gt_labels_list, + img_metas, gt_bboxes_ignore_list) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes accross all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(img_metas, bbox_preds): + img_h, img_w, _ = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + return loss_cls, loss_bbox, loss_iou + + def get_targets(self, + cls_scores_list, + bbox_preds_list, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore_list=None): + """"Compute regression and classification targets for a batch image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + img_metas (list[dict]): List of image meta information. + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all \ + images. + - bbox_targets_list (list[Tensor]): BBox targets for all \ + images. + - bbox_weights_list (list[Tensor]): BBox weights for all \ + images. 
+                - num_total_pos (int): Number of positive samples in all \
+                    images.
+                - num_total_neg (int): Number of negative samples in all \
+                    images.
+        """
+        assert gt_bboxes_ignore_list is None, \
+            'Only gt_bboxes_ignore set to None is supported.'
+        num_imgs = len(cls_scores_list)
+        gt_bboxes_ignore_list = [
+            gt_bboxes_ignore_list for _ in range(num_imgs)
+        ]
+
+        (labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
+             self._get_target_single, cls_scores_list, bbox_preds_list,
+             gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
+        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+        return (labels_list, label_weights_list, bbox_targets_list,
+                bbox_weights_list, num_total_pos, num_total_neg)
+
+    def _get_target_single(self,
+                           cls_score,
+                           bbox_pred,
+                           gt_bboxes,
+                           gt_labels,
+                           img_meta,
+                           gt_bboxes_ignore=None):
+        """Compute regression and classification targets for one image.
+
+        Outputs from a single decoder layer of a single feature level are
+        used.
+
+        Args:
+            cls_score (Tensor): Box score logits from a single decoder layer
+                for one image. Shape [num_query, cls_out_channels].
+            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+                for one image, with normalized coordinate (cx, cy, w, h) and
+                shape [num_query, 4].
+            gt_bboxes (Tensor): Ground truth bboxes for one image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (Tensor): Ground truth class indices for one image
+                with shape (num_gts, ).
+            img_meta (dict): Meta information for one image.
+            gt_bboxes_ignore (Tensor, optional): Bounding boxes
+                which can be ignored. Default None.
+
+        Returns:
+            tuple[Tensor]: a tuple containing the following for one image.
+
+                - labels (Tensor): Labels of each image.
+                - label_weights (Tensor): Label weights of each image.
+                - bbox_targets (Tensor): BBox targets of each image.
+                - bbox_weights (Tensor): BBox weights of each image.
+                - pos_inds (Tensor): Sampled positive indices for each image.
+                - neg_inds (Tensor): Sampled negative indices for each image.
+        """
+
+        num_bboxes = bbox_pred.size(0)
+        # assigner and sampler
+        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
+                                             gt_labels, img_meta,
+                                             gt_bboxes_ignore)
+        sampling_result = self.sampler.sample(assign_result, bbox_pred,
+                                              gt_bboxes)
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+
+        # label targets
+        labels = gt_bboxes.new_full((num_bboxes, ),
+                                    self.num_classes,
+                                    dtype=torch.long)
+        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+        label_weights = gt_bboxes.new_ones(num_bboxes)
+
+        # bbox targets
+        bbox_targets = torch.zeros_like(bbox_pred)
+        bbox_weights = torch.zeros_like(bbox_pred)
+        bbox_weights[pos_inds] = 1.0
+        img_h, img_w, _ = img_meta['img_shape']
+
+        # DETR regresses the relative position of boxes (cxcywh) in the
+        # image. Thus the learning target should be normalized by the image
+        # size, and the box format converted from the default
+        # (x1, y1, x2, y2) to (cx, cy, w, h).
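+        # e.g. with img_shape (400, 600, 3), a gt box (0, 0, 300, 200) is
+        # divided by factor (600, 400, 600, 400) to give (0, 0, 0.5, 0.5),
+        # i.e. (cx, cy, w, h) = (0.25, 0.25, 0.5, 0.5) after conversion.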
+        factor = bbox_pred.new_tensor([img_w, img_h, img_w,
+                                       img_h]).unsqueeze(0)
+        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
+        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
+        bbox_targets[pos_inds] = pos_gt_bboxes_targets
+        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+                neg_inds)
+
+    # over-write because img_metas are needed as inputs for bbox_head.
+    def forward_train(self,
+                      x,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels=None,
+                      gt_bboxes_ignore=None,
+                      proposal_cfg=None,
+                      **kwargs):
+        """Forward function for training mode.
+
+        Args:
+            x (list[Tensor]): Features from backbone.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+            proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+                if None, test_cfg would be used.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert proposal_cfg is None, '"proposal_cfg" must be None'
+        outs = self(x, img_metas)
+        if gt_labels is None:
+            loss_inputs = outs + (gt_bboxes, img_metas)
+        else:
+            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+        return losses
+
+    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+    def get_bboxes(self,
+                   all_cls_scores_list,
+                   all_bbox_preds_list,
+                   img_metas,
+                   rescale=False):
+        """Transform network outputs for a batch into bbox predictions.
+
+        Args:
+            all_cls_scores_list (list[Tensor]): Classification outputs
+                for each feature level. Each is a 4D-tensor with shape
+                [nb_dec, bs, num_query, cls_out_channels].
+            all_bbox_preds_list (list[Tensor]): Sigmoid regression
+                outputs for each feature level. Each is a 4D-tensor with
+                normalized coordinate format (cx, cy, w, h) and shape
+                [nb_dec, bs, num_query, 4].
+            img_metas (list[dict]): Meta information of each image.
+            rescale (bool, optional): If True, return boxes in original
+                image space. Default False.
+
+        Returns:
+            list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
+                The first item is an (n, 5) tensor, where the first 4 columns \
+                are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
+                5-th column is a score between 0 and 1. The second item is a \
+                (n,) tensor where each item is the predicted class label of \
+                the corresponding box.
+        """
+        # NOTE by default, only outputs from the last feature level and the
+        # last decoder layer are used.
+        cls_scores = all_cls_scores_list[-1][-1]
+        bbox_preds = all_bbox_preds_list[-1][-1]
+
+        result_list = []
+        for img_id in range(len(img_metas)):
+            cls_score = cls_scores[img_id]
+            bbox_pred = bbox_preds[img_id]
+            img_shape = img_metas[img_id]['img_shape']
+            scale_factor = img_metas[img_id]['scale_factor']
+            proposals = self._get_bboxes_single(cls_score, bbox_pred,
+                                                img_shape, scale_factor,
+                                                rescale)
+            result_list.append(proposals)
+        return result_list
+
+    def _get_bboxes_single(self,
+                           cls_score,
+                           bbox_pred,
+                           img_shape,
+                           scale_factor,
+                           rescale=False):
+        """Transform outputs from the last decoder layer into bbox predictions
+        for each image.
+
+        Args:
+            cls_score (Tensor): Box score logits from the last decoder layer
+                for each image. Shape [num_query, cls_out_channels].
+            bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
+                for each image, with coordinate format (cx, cy, w, h) and
+                shape [num_query, 4].
+            img_shape (tuple[int]): Shape of input image, (height, width, 3).
+            scale_factor (ndarray, optional): Scale factor of the image
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+            rescale (bool, optional): If True, return boxes in original image
+                space. Default False.
+
+        Returns:
+            tuple[Tensor]: Results of detected bboxes and labels.
+
+                - det_bboxes: Predicted bboxes with shape [num_query, 5], \
+                    where the first 4 columns are bounding box positions \
+                    (tl_x, tl_y, br_x, br_y) and the 5-th column are scores \
+                    between 0 and 1.
+                - det_labels: Predicted labels of the corresponding box with \
+                    shape [num_query].
+        """
+        assert len(cls_score) == len(bbox_pred)
+        # exclude background
+        scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
+        det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
+        det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
+        det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
+        det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
+        det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
+        if rescale:
+            det_bboxes /= det_bboxes.new_tensor(scale_factor)
+        det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)
+        return det_bboxes, det_labels
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index d9e92f56f41..e819f7fd0b8 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -2,6 +2,7 @@
 from .base import BaseDetector
 from .cascade_rcnn import CascadeRCNN
 from .cornernet import CornerNet
+from .detr import DETR
 from .fast_rcnn import FastRCNN
 from .faster_rcnn import FasterRCNN
 from .fcos import FCOS
@@ -29,5 +30,5 @@
     'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
     'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
     'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
-    'YOLOV3', 'YOLACT', 'VFNet'
+    'YOLOV3', 'YOLACT', 'VFNet', 'DETR'
 ]
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 0ddac24a909..627e3912fce 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -61,7 +61,6 @@ def extract_feats(self, imgs):
         assert isinstance(imgs, list)
         return [self.extract_feat(img) for img in imgs]
 
-    @abstractmethod
     def forward_train(self, imgs, img_metas, **kwargs):
         """
         Args:
@@ -74,7 +73,12 @@ def forward_train(self, imgs, img_metas, **kwargs):
                 :class:`mmdet.datasets.pipelines.Collect`.
             kwargs (keyword arguments): Specific to concrete implementation.
         """
-        pass
+        # NOTE the batched image size information may be useful, e.g.
+        # in DETR, this is needed for the construction of masks, which is
+        # then used for the transformer_head.
+        batch_input_shape = tuple(imgs[0].size()[-2:])
+        for img_meta in img_metas:
+            img_meta['batch_input_shape'] = batch_input_shape
 
     async def async_simple_test(self, img, img_metas, **kwargs):
         raise NotImplementedError
@@ -136,6 +140,14 @@ def forward_test(self, imgs, img_metas, **kwargs):
             raise ValueError(f'num of augmentations ({len(imgs)}) '
                              f'!= num of image meta ({len(img_metas)})')
 
+        # NOTE the batched image size information may be useful, e.g.
+        # in DETR, this is needed for the construction of masks, which is
+        # then used for the transformer_head.
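+        # Regions padded beyond each image's `img_shape` are later masked
+        # out in `TransformerHead.forward_single`.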
+        for img, img_meta in zip(imgs, img_metas):
+            batch_size = len(img_meta)
+            for img_id in range(batch_size):
+                img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])
+
         if num_augs == 1:
             # proposals (List[List[Tensor]]): the outer list indicates
             # test-time augs (multiscale, flip, etc.) and the inner list
diff --git a/mmdet/models/detectors/detr.py b/mmdet/models/detectors/detr.py
new file mode 100644
index 00000000000..5ff82a280da
--- /dev/null
+++ b/mmdet/models/detectors/detr.py
@@ -0,0 +1,46 @@
+from mmdet.core import bbox2result
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class DETR(SingleStageDetector):
+    r"""Implementation of `DETR: End-to-End Object Detection with
+    Transformers <https://arxiv.org/abs/2005.12872>`_"""
+
+    def __init__(self,
+                 backbone,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(DETR, self).__init__(backbone, None, bbox_head, train_cfg,
+                                   test_cfg, pretrained)
+
+    def simple_test(self, img, img_metas, rescale=False):
+        """Test function without test time augmentation.
+
+        Args:
+            img (Tensor): Batched input images.
+            img_metas (list[dict]): List of image information.
+            rescale (bool, optional): Whether to rescale the results.
+                Defaults to False.
+
+        Returns:
+            list[list[np.ndarray]]: BBox results of each image and classes.
+                The outer list corresponds to each image. The inner list
+                corresponds to each class.
+        """
+        batch_size = len(img_metas)
+        assert batch_size == 1, 'Currently only batch_size 1 for inference ' \
+            f'mode is supported. Found batch_size {batch_size}.'
+        x = self.extract_feat(img)
+        outs = self.bbox_head(x, img_metas)
+        bbox_list = self.bbox_head.get_bboxes(
+            *outs, img_metas, rescale=rescale)
+
+        bbox_results = [
+            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+            for det_bboxes, det_labels in bbox_list
+        ]
+        return bbox_results
diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
index 3932c9afcce..96c4acac08f 100644
--- a/mmdet/models/detectors/single_stage.py
+++ b/mmdet/models/detectors/single_stage.py
@@ -89,6 +89,7 @@ def forward_train(self,
         Returns:
             dict[str, Tensor]: A dictionary of loss components.
""" + super(SingleStageDetector, self).forward_train(img, img_metas) x = self.extract_feat(img) losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore) diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py index 8bdcce28ea4..1fc09b5e8fe 100644 --- a/mmdet/models/utils/__init__.py +++ b/mmdet/models/utils/__init__.py @@ -1,4 +1,16 @@ +from .builder import build_positional_encoding, build_transformer from .gaussian_target import gaussian_radius, gen_gaussian_target +from .positional_encoding import (LearnedPositionalEncoding, + SinePositionalEncoding) from .res_layer import ResLayer +from .transformer import (FFN, MultiheadAttention, Transformer, + TransformerDecoder, TransformerDecoderLayer, + TransformerEncoder, TransformerEncoderLayer) -__all__ = ['ResLayer', 'gaussian_radius', 'gen_gaussian_target'] +__all__ = [ + 'ResLayer', 'gaussian_radius', 'gen_gaussian_target', 'MultiheadAttention', + 'FFN', 'TransformerEncoderLayer', 'TransformerEncoder', + 'TransformerDecoderLayer', 'TransformerDecoder', 'Transformer', + 'build_transformer', 'build_positional_encoding', 'SinePositionalEncoding', + 'LearnedPositionalEncoding' +] diff --git a/mmdet/models/utils/builder.py b/mmdet/models/utils/builder.py new file mode 100644 index 00000000000..f362d1c92ca --- /dev/null +++ b/mmdet/models/utils/builder.py @@ -0,0 +1,14 @@ +from mmcv.utils import Registry, build_from_cfg + +TRANSFORMER = Registry('Transformer') +POSITIONAL_ENCODING = Registry('Position encoding') + + +def build_transformer(cfg, default_args=None): + """Builder for Transformer.""" + return build_from_cfg(cfg, TRANSFORMER, default_args) + + +def build_positional_encoding(cfg, default_args=None): + """Builder for Position Encoding.""" + return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args) diff --git a/mmdet/models/utils/positional_encoding.py b/mmdet/models/utils/positional_encoding.py new file mode 100644 index 00000000000..9bda2bbdbfc --- /dev/null +++ b/mmdet/models/utils/positional_encoding.py @@ -0,0 +1,150 @@ +import math + +import torch +import torch.nn as nn +from mmcv.cnn import uniform_init + +from .builder import POSITIONAL_ENCODING + + +@POSITIONAL_ENCODING.register_module() +class SinePositionalEncoding(nn.Module): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Default 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Default False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Default 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Default 1e-6. 
+ """ + + def __init__(self, + num_feats, + temperature=10000, + normalize=False, + scale=2 * math.pi, + eps=1e-6): + super(SinePositionalEncoding, self).__init__() + if normalize: + assert isinstance(scale, (float, int)), 'when normalize is set,' \ + 'scale should be provided and in float or int type, ' \ + f'found {type(scale)}' + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + + def forward(self, mask): + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange( + self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'temperature={self.temperature}, ' + repr_str += f'normalize={self.normalize}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'eps={self.eps})' + return repr_str + + +@POSITIONAL_ENCODING.register_module() +class LearnedPositionalEncoding(nn.Module): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Default 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Default 50. + """ + + def __init__(self, num_feats, row_num_embed=50, col_num_embed=50): + super(LearnedPositionalEncoding, self).__init__() + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + self.init_weights() + + def init_weights(self): + """Initialize the learnable weights.""" + uniform_init(self.row_embed) + uniform_init(self.col_embed) + + def forward(self, mask): + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. 
+ """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = torch.cat( + (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat( + 1, w, 1)), + dim=-1).permute(2, 0, + 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'row_num_embed={self.row_num_embed}, ' + repr_str += f'col_num_embed={self.col_num_embed})' + return repr_str diff --git a/mmdet/models/utils/transformer.py b/mmdet/models/utils/transformer.py new file mode 100644 index 00000000000..f94b183b5b8 --- /dev/null +++ b/mmdet/models/utils/transformer.py @@ -0,0 +1,744 @@ +import torch +import torch.nn as nn +from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer, + xavier_init) + +from .builder import TRANSFORMER + + +class MultiheadAttention(nn.Module): + """A warpper for torch.nn.MultiheadAttention. + + This module implements MultiheadAttention with residual connection, + and positional encoding used in DETR is also passed as input. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. Same as + `nn.MultiheadAttention`. + dropout (float): A Dropout layer on attn_output_weights. Default 0.0. + """ + + def __init__(self, embed_dims, num_heads, dropout=0.0): + super(MultiheadAttention, self).__init__() + assert embed_dims % num_heads == 0, 'embed_dims must be ' \ + f'divisible by num_heads. got {embed_dims} and {num_heads}.' + self.embed_dims = embed_dims + self.num_heads = num_heads + self.dropout = dropout + self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout) + self.dropout = nn.Dropout(dropout) + + def forward(self, + x, + key=None, + value=None, + residual=None, + query_pos=None, + key_pos=None, + attn_mask=None, + key_padding_mask=None): + """Forward function for `MultiheadAttention`. + + Args: + x (Tensor): The input query with shape [num_query, bs, + embed_dims]. Same in `nn.MultiheadAttention.forward`. + key (Tensor): The key tensor with shape [num_key, bs, + embed_dims]. Same in `nn.MultiheadAttention.forward`. + Default None. If None, the `query` will be used. + value (Tensor): The value tensor with same shape as `key`. + Same in `nn.MultiheadAttention.forward`. Default None. + If None, the `key` will be used. + residual (Tensor): The tensor used for addition, with the + same shape as `x`. Default None. If None, `x` will be used. + query_pos (Tensor): The positional encoding for query, with + the same shape as `x`. Default None. If not None, it will + be added to `x` before forward function. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. Default None. If not None, it will + be added to `key` before forward function. If None, and + `query_pos` has the same shape as `key`, then `query_pos` + will be used for `key_pos`. + attn_mask (Tensor): ByteTensor mask with shape [num_query, + num_key]. Same in `nn.MultiheadAttention.forward`. + Default None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_key]. + Same in `nn.MultiheadAttention.forward`. Default None. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. 
+ """ + query = x + if key is None: + key = query + if value is None: + value = key + if residual is None: + residual = x + if key_pos is None: + if query_pos is not None and key is not None: + if query_pos.shape == key.shape: + key_pos = query_pos + if query_pos is not None: + query = query + query_pos + if key_pos is not None: + key = key + key_pos + out = self.attn( + query, + key, + value=value, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask)[0] + + return residual + self.dropout(out) + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(embed_dims={self.embed_dims}, ' + repr_str += f'num_heads={self.num_heads}, ' + repr_str += f'dropout={self.dropout})' + return repr_str + + +class FFN(nn.Module): + """Implements feed-forward networks (FFNs) with residual connection. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. + feedforward_channels (int): The hidden dimension of FFNs. + num_fcs (int): The number of fully-connected layers in FFNs. + act_cfg (dict): The activation config for FFNs. + dropout (float): Probability of an element to be zeroed. Default 0.0. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + dropout=0.0, + add_residual=True): + super(FFN, self).__init__() + assert num_fcs >= 2, 'num_fcs should be no less ' \ + f'than 2. got {num_fcs}.' + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.dropout = dropout + self.activate = build_activation_layer(act_cfg) + + layers = nn.ModuleList() + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + nn.Sequential( + Linear(in_channels, feedforward_channels), self.activate, + nn.Dropout(dropout))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + self.layers = nn.Sequential(*layers) + self.dropout = nn.Dropout(dropout) + self.add_residual = add_residual + + def forward(self, x, residual=None): + """Forward function for `FFN`.""" + out = self.layers(x) + if not self.add_residual: + return out + if residual is None: + residual = x + return residual + self.dropout(out) + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(embed_dims={self.embed_dims}, ' + repr_str += f'feedforward_channels={self.feedforward_channels}, ' + repr_str += f'num_fcs={self.num_fcs}, ' + repr_str += f'act_cfg={self.act_cfg}, ' + repr_str += f'dropout={self.dropout}, ' + repr_str += f'add_residual={self.add_residual})' + return repr_str + + +class TransformerEncoderLayer(nn.Module): + """Implements one encoder layer in DETR transformer. + + Args: + embed_dims (int): The feature dimension. Same as `FFN`. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + dropout (float): Probability of an element to be zeroed. Default 0.0. + order (tuple[str]): The order for encoder layer. Valid examples are + ('selfattn', 'norm', 'ffn', 'norm') and ('norm', 'selfattn', + 'norm', 'ffn'). Default ('selfattn', 'norm', 'ffn', 'norm'). + act_cfg (dict): The activation config for FFNs. Defalut ReLU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization. + num_fcs (int): The number of fully-connected layers for FFNs. + Default 2. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + dropout=0.0, + order=('selfattn', 'norm', 'ffn', 'norm'), + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + num_fcs=2): + super(TransformerEncoderLayer, self).__init__() + assert isinstance(order, tuple) and len(order) == 4 + assert set(order) == set(['selfattn', 'norm', 'ffn']) + self.embed_dims = embed_dims + self.num_heads = num_heads + self.feedforward_channels = feedforward_channels + self.dropout = dropout + self.order = order + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.num_fcs = num_fcs + self.pre_norm = order[0] == 'norm' + self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout) + self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg, + dropout) + self.norms = nn.ModuleList() + self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) + self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) + + def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None): + """Forward function for `TransformerEncoderLayer`. + + Args: + x (Tensor): The input query with shape [num_key, bs, + embed_dims]. Same in `MultiheadAttention.forward`. + pos (Tensor): The positional encoding for query. Default None. + Same as `query_pos` in `MultiheadAttention.forward`. + attn_mask (Tensor): ByteTensor mask with shape [num_key, + num_key]. Same in `MultiheadAttention.forward`. Default None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_key]. + Same in `MultiheadAttention.forward`. Default None. + + Returns: + Tensor: forwarded results with shape [num_key, bs, embed_dims]. + """ + norm_cnt = 0 + inp_residual = x + for layer in self.order: + if layer == 'selfattn': + # self attention + query = key = value = x + x = self.self_attn( + query, + key, + value, + inp_residual if self.pre_norm else None, + query_pos=pos, + key_pos=pos, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask) + inp_residual = x + elif layer == 'norm': + x = self.norms[norm_cnt](x) + norm_cnt += 1 + elif layer == 'ffn': + x = self.ffn(x, inp_residual if self.pre_norm else None) + return x + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(embed_dims={self.embed_dims}, ' + repr_str += f'num_heads={self.num_heads}, ' + repr_str += f'feedforward_channels={self.feedforward_channels}, ' + repr_str += f'dropout={self.dropout}, ' + repr_str += f'order={self.order}, ' + repr_str += f'act_cfg={self.act_cfg}, ' + repr_str += f'norm_cfg={self.norm_cfg}, ' + repr_str += f'num_fcs={self.num_fcs})' + return repr_str + + +class TransformerDecoderLayer(nn.Module): + """Implements one decoder layer in DETR transformer. + + Args: + embed_dims (int): The feature dimension. Same as + `TransformerEncoderLayer`. + num_heads (int): Parallel attention heads. + feedforward_channels (int): Same as `TransformerEncoderLayer`. + dropout (float): Same as `TransformerEncoderLayer`. Default 0.0. + order (tuple[str]): The order for decoder layer. Valid examples are + ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') and + ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn'). + Default the former. + act_cfg (dict): Same as `TransformerEncoderLayer`. Defalut ReLU. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization. + num_fcs (int): The number of fully-connected layers in FFNs. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + dropout=0.0, + order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', + 'norm'), + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + num_fcs=2): + super(TransformerDecoderLayer, self).__init__() + assert isinstance(order, tuple) and len(order) == 6 + assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn']) + self.embed_dims = embed_dims + self.num_heads = num_heads + self.feedforward_channels = feedforward_channels + self.dropout = dropout + self.order = order + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.num_fcs = num_fcs + self.pre_norm = order[0] == 'norm' + self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout) + self.multihead_attn = MultiheadAttention(embed_dims, num_heads, + dropout) + self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg, + dropout) + self.norms = nn.ModuleList() + # 3 norm layers in official DETR's TransformerDecoderLayer + for _ in range(3): + self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1]) + + def forward(self, + x, + memory, + memory_pos=None, + query_pos=None, + memory_attn_mask=None, + target_attn_mask=None, + memory_key_padding_mask=None, + target_key_padding_mask=None): + """Forward function for `TransformerDecoderLayer`. + + Args: + x (Tensor): Input query with shape [num_query, bs, embed_dims]. + memory (Tensor): Tensor got from `TransformerEncoder`, with shape + [num_key, bs, embed_dims]. + memory_pos (Tensor): The positional encoding for `memory`. Default + None. Same as `key_pos` in `MultiheadAttention.forward`. + query_pos (Tensor): The positional encoding for `query`. Default + None. Same as `query_pos` in `MultiheadAttention.forward`. + memory_attn_mask (Tensor): ByteTensor mask for `memory`, with + shape [num_key, num_key]. Same as `attn_mask` in + `MultiheadAttention.forward`. Default None. + target_attn_mask (Tensor): ByteTensor mask for `x`, with shape + [num_query, num_query]. Same as `attn_mask` in + `MultiheadAttention.forward`. Default None. + memory_key_padding_mask (Tensor): ByteTensor for `memory`, with + shape [bs, num_key]. Same as `key_padding_mask` in + `MultiheadAttention.forward`. Default None. + target_key_padding_mask (Tensor): ByteTensor for `x`, with shape + [bs, num_query]. Same as `key_padding_mask` in + `MultiheadAttention.forward`. Default None. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. 
+ """ + norm_cnt = 0 + inp_residual = x + for layer in self.order: + if layer == 'selfattn': + query = key = value = x + x = self.self_attn( + query, + key, + value, + inp_residual if self.pre_norm else None, + query_pos, + key_pos=query_pos, + attn_mask=target_attn_mask, + key_padding_mask=target_key_padding_mask) + inp_residual = x + elif layer == 'norm': + x = self.norms[norm_cnt](x) + norm_cnt += 1 + elif layer == 'multiheadattn': + query = x + key = value = memory + x = self.multihead_attn( + query, + key, + value, + inp_residual if self.pre_norm else None, + query_pos, + key_pos=memory_pos, + attn_mask=memory_attn_mask, + key_padding_mask=memory_key_padding_mask) + inp_residual = x + elif layer == 'ffn': + x = self.ffn(x, inp_residual if self.pre_norm else None) + return x + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(embed_dims={self.embed_dims}, ' + repr_str += f'num_heads={self.num_heads}, ' + repr_str += f'feedforward_channels={self.feedforward_channels}, ' + repr_str += f'dropout={self.dropout}, ' + repr_str += f'order={self.order}, ' + repr_str += f'act_cfg={self.act_cfg}, ' + repr_str += f'norm_cfg={self.norm_cfg}, ' + repr_str += f'num_fcs={self.num_fcs})' + return repr_str + + +class TransformerEncoder(nn.Module): + """Implements the encoder in DETR transformer. + + Args: + num_layers (int): The number of `TransformerEncoderLayer`. + embed_dims (int): Same as `TransformerEncoderLayer`. + num_heads (int): Same as `TransformerEncoderLayer`. + feedforward_channels (int): Same as `TransformerEncoderLayer`. + dropout (float): Same as `TransformerEncoderLayer`. Default 0.0. + order (tuple[str]): Same as `TransformerEncoderLayer`. + act_cfg (dict): Same as `TransformerEncoderLayer`. Defalut ReLU. + norm_cfg (dict): Same as `TransformerEncoderLayer`. Default + layer normalization. + num_fcs (int): Same as `TransformerEncoderLayer`. Default 2. + """ + + def __init__(self, + num_layers, + embed_dims, + num_heads, + feedforward_channels, + dropout=0.0, + order=('selfattn', 'norm', 'ffn', 'norm'), + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + num_fcs=2): + super(TransformerEncoder, self).__init__() + assert isinstance(order, tuple) and len(order) == 4 + assert set(order) == set(['selfattn', 'norm', 'ffn']) + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_heads = num_heads + self.feedforward_channels = feedforward_channels + self.dropout = dropout + self.order = order + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.num_fcs = num_fcs + self.pre_norm = order[0] == 'norm' + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TransformerEncoderLayer(embed_dims, num_heads, + feedforward_channels, dropout, order, + act_cfg, norm_cfg, num_fcs)) + self.norm = build_norm_layer(norm_cfg, + embed_dims)[1] if self.pre_norm else None + + def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None): + """Forward function for `TransformerEncoder`. + + Args: + x (Tensor): Input query. Same in `TransformerEncoderLayer.forward`. + pos (Tensor): Positional encoding for query. Default None. + Same in `TransformerEncoderLayer.forward`. + attn_mask (Tensor): ByteTensor attention mask. Default None. + Same in `TransformerEncoderLayer.forward`. + key_padding_mask (Tensor): Same in + `TransformerEncoderLayer.forward`. Default None. + + Returns: + Tensor: Results with shape [num_key, bs, embed_dims]. 
+ """ + for layer in self.layers: + x = layer(x, pos, attn_mask, key_padding_mask) + if self.norm is not None: + x = self.norm(x) + return x + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_layers={self.num_layers}, ' + repr_str += f'embed_dims={self.embed_dims}, ' + repr_str += f'num_heads={self.num_heads}, ' + repr_str += f'feedforward_channels={self.feedforward_channels}, ' + repr_str += f'dropout={self.dropout}, ' + repr_str += f'order={self.order}, ' + repr_str += f'act_cfg={self.act_cfg}, ' + repr_str += f'norm_cfg={self.norm_cfg}, ' + repr_str += f'num_fcs={self.num_fcs})' + return repr_str + + +class TransformerDecoder(nn.Module): + """Implements the decoder in DETR transformer. + + Args: + num_layers (int): The number of `TransformerDecoderLayer`. + embed_dims (int): Same as `TransformerDecoderLayer`. + num_heads (int): Same as `TransformerDecoderLayer`. + feedforward_channels (int): Same as `TransformerDecoderLayer`. + dropout (float): Same as `TransformerDecoderLayer`. Default 0.0. + order (tuple[str]): Same as `TransformerDecoderLayer`. + act_cfg (dict): Same as `TransformerDecoderLayer`. Defalut ReLU. + norm_cfg (dict): Same as `TransformerDecoderLayer`. Default + layer normalization. + num_fcs (int): Same as `TransformerDecoderLayer`. Default 2. + """ + + def __init__(self, + num_layers, + embed_dims, + num_heads, + feedforward_channels, + dropout=0.0, + order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', + 'norm'), + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + num_fcs=2, + return_intermediate=False): + super(TransformerDecoder, self).__init__() + assert isinstance(order, tuple) and len(order) == 6 + assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn']) + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_heads = num_heads + self.feedforward_channels = feedforward_channels + self.dropout = dropout + self.order = order + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.num_fcs = num_fcs + self.return_intermediate = return_intermediate + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + TransformerDecoderLayer(embed_dims, num_heads, + feedforward_channels, dropout, order, + act_cfg, norm_cfg, num_fcs)) + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, + x, + memory, + memory_pos=None, + query_pos=None, + memory_attn_mask=None, + target_attn_mask=None, + memory_key_padding_mask=None, + target_key_padding_mask=None): + """Forward function for `TransformerDecoder`. + + Args: + x (Tensor): Input query. Same in `TransformerDecoderLayer.forward`. + memory (Tensor): Same in `TransformerDecoderLayer.forward`. + memory_pos (Tensor): Same in `TransformerDecoderLayer.forward`. + Default None. + query_pos (Tensor): Same in `TransformerDecoderLayer.forward`. + Default None. + memory_attn_mask (Tensor): Same in + `TransformerDecoderLayer.forward`. Default None. + target_attn_mask (Tensor): Same in + `TransformerDecoderLayer.forward`. Default None. + memory_key_padding_mask (Tensor): Same in + `TransformerDecoderLayer.forward`. Default None. + target_key_padding_mask (Tensor): Same in + `TransformerDecoderLayer.forward`. Default None. + + Returns: + Tensor: Results with shape [num_query, bs, embed_dims]. 
+ """ + intermediate = [] + for layer in self.layers: + x = layer(x, memory, memory_pos, query_pos, memory_attn_mask, + target_attn_mask, memory_key_padding_mask, + target_key_padding_mask) + if self.return_intermediate: + intermediate.append(self.norm(x)) + if self.norm is not None: + x = self.norm(x) + if self.return_intermediate: + intermediate.pop() + intermediate.append(x) + if self.return_intermediate: + return torch.stack(intermediate) + return x.unsqueeze(0) + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_layers={self.num_layers}, ' + repr_str += f'embed_dims={self.embed_dims}, ' + repr_str += f'num_heads={self.num_heads}, ' + repr_str += f'feedforward_channels={self.feedforward_channels}, ' + repr_str += f'dropout={self.dropout}, ' + repr_str += f'order={self.order}, ' + repr_str += f'act_cfg={self.act_cfg}, ' + repr_str += f'norm_cfg={self.norm_cfg}, ' + repr_str += f'num_fcs={self.num_fcs}, ' + repr_str += f'return_intermediate={self.return_intermediate})' + return repr_str + + +@TRANSFORMER.register_module() +class Transformer(nn.Module): + """Implements the DETR transformer. + + Following the official DETR implementation, this module copy-paste + from torch.nn.Transformer with modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + + See `paper: End-to-End Object Detection with Transformers + `_ for details. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. Same as + `nn.MultiheadAttention`. + num_encoder_layers (int): Number of `TransformerEncoderLayer`. + num_decoder_layers (int): Number of `TransformerDecoderLayer`. + feedforward_channels (int): The hidden dimension for FFNs used in both + encoder and decoder. + dropout (float): Probability of an element to be zeroed. Default 0.0. + act_cfg (dict): Activation config for FFNs used in both encoder + and decoder. Defalut ReLU. + norm_cfg (dict): Config dict for normalization used in both encoder + and decoder. Default layer normalization. + num_fcs (int): The number of fully-connected layers in FFNs, which is + used for both encoder and decoder. + pre_norm (bool): Whether the normalization layer is ordered + first in the encoder and decoder. Default False. + return_intermediate_dec (bool): Whether to return the intermediate + output from each TransformerDecoderLayer or only the last + TransformerDecoderLayer. Default False. If False, the returned + `hs` has shape [num_decoder_layers, bs, num_query, embed_dims]. + If True, the returned `hs` will have shape [1, bs, num_query, + embed_dims]. 
+ """ + + def __init__(self, + embed_dims=512, + num_heads=8, + num_encoder_layers=6, + num_decoder_layers=6, + feedforward_channels=2048, + dropout=0.0, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + num_fcs=2, + pre_norm=False, + return_intermediate_dec=False): + super(Transformer, self).__init__() + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_encoder_layers = num_encoder_layers + self.num_decoder_layers = num_decoder_layers + self.feedforward_channels = feedforward_channels + self.dropout = dropout + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.num_fcs = num_fcs + self.pre_norm = pre_norm + self.return_intermediate_dec = return_intermediate_dec + if self.pre_norm: + encoder_order = ('norm', 'selfattn', 'norm', 'ffn') + decoder_order = ('norm', 'selfattn', 'norm', 'multiheadattn', + 'norm', 'ffn') + else: + encoder_order = ('selfattn', 'norm', 'ffn', 'norm') + decoder_order = ('selfattn', 'norm', 'multiheadattn', 'norm', + 'ffn', 'norm') + self.encoder = TransformerEncoder(num_encoder_layers, embed_dims, + num_heads, feedforward_channels, + dropout, encoder_order, act_cfg, + norm_cfg, num_fcs) + self.decoder = TransformerDecoder(num_decoder_layers, embed_dims, + num_heads, feedforward_channels, + dropout, decoder_order, act_cfg, + norm_cfg, num_fcs, + return_intermediate_dec) + + def init_weights(self, distribution='uniform'): + """Initialize the transformer weights.""" + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, 'weight') and m.weight.dim() > 1: + xavier_init(m, distribution=distribution) + + def forward(self, x, mask, query_embed, pos_embed): + """Forward function for `Transformer`. + + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. 
+ """ + bs, c, h, w = x.shape + x = x.flatten(2).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat( + 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.flatten(1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder( + x, pos=pos_embed, attn_mask=None, key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + target, + memory, + memory_pos=pos_embed, + query_pos=query_embed, + memory_attn_mask=None, + target_attn_mask=None, + memory_key_padding_mask=mask, + target_key_padding_mask=None) + out_dec = out_dec.transpose(1, 2) + memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) + return out_dec, memory + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(embed_dims={self.embed_dims}, ' + repr_str += f'num_heads={self.num_heads}, ' + repr_str += f'num_encoder_layers={self.num_encoder_layers}, ' + repr_str += f'num_decoder_layers={self.num_decoder_layers}, ' + repr_str += f'feedforward_channels={self.feedforward_channels}, ' + repr_str += f'dropout={self.dropout}, ' + repr_str += f'act_cfg={self.act_cfg}, ' + repr_str += f'norm_cfg={self.norm_cfg}, ' + repr_str += f'num_fcs={self.num_fcs}, ' + repr_str += f'pre_norm={self.pre_norm}, ' + repr_str += f'return_intermediate_dec={self.return_intermediate_dec})' + return repr_str diff --git a/setup.cfg b/setup.cfg index 64a65cec5b4..873406e8f19 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,seaborn,six,terminaltables,torch +known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,robustness_eval,scipy,seaborn,six,terminaltables,torch no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_assigner.py b/tests/test_assigner.py index 107dc2a30c7..8e2d4b7e288 100644 --- a/tests/test_assigner.py +++ b/tests/test_assigner.py @@ -7,8 +7,8 @@ import torch from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner, - CenterRegionAssigner, MaxIoUAssigner, - PointAssigner) + CenterRegionAssigner, HungarianAssigner, + MaxIoUAssigner, PointAssigner) def test_max_iou_assigner(): @@ -376,3 +376,37 @@ def test_center_region_assigner_with_empty_gts(): assert len(assign_result.gt_inds) == 2 expected_gt_inds = torch.LongTensor([0, 0]) assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_hungarian_match_assigner(): + self = HungarianAssigner() + assert self.iou_mode == 'giou' + + # test no gt bboxes + bbox_pred = torch.rand((10, 4)) + cls_pred = torch.rand((10, 81)) + gt_bboxes = torch.empty((0, 4)).float() + gt_labels = torch.empty((0, )).long() + img_meta = dict(img_shape=(10, 8, 3)) + assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels, + img_meta) + assert torch.all(assign_result.gt_inds == 0) + assert torch.all(assign_result.labels == -1) + + # test with gt bboxes + gt_bboxes = torch.FloatTensor([[0, 0, 5, 7], [3, 5, 7, 8]]) + gt_labels = torch.LongTensor([1, 20]) + assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels, + img_meta) + assert torch.all(assign_result.gt_inds > -1) + assert (assign_result.gt_inds > 0).sum() == 
gt_bboxes.size(0) + assert (assign_result.labels > -1).sum() == gt_bboxes.size(0) + + # test iou mode + self = HungarianAssigner(iou_mode='iou') + assert self.iou_mode == 'iou' + assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels, + img_meta) + assert torch.all(assign_result.gt_inds > -1) + assert (assign_result.gt_inds > 0).sum() == gt_bboxes.size(0) + assert (assign_result.labels > -1).sum() == gt_bboxes.size(0) diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index dcbb531f88c..90a11ea3682 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -227,6 +227,81 @@ def area(bboxes): assert (area(results['gt_bboxes']) <= area(gt_bboxes)).all() assert (area(results['gt_bboxes_ignore']) <= area(gt_bboxes_ignore)).all() + # test assertion for invalid crop_type + with pytest.raises(ValueError): + transform = dict( + type='RandomCrop', crop_size=(1, 1), crop_type='unknown') + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid crop_size + with pytest.raises(AssertionError): + transform = dict( + type='RandomCrop', crop_type='relative', crop_size=(0, 0)) + build_from_cfg(transform, PIPELINES) + + def _construct_toy_data(): + img = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8) + img = np.stack([img, img, img], axis=-1) + results = dict() + # image + results['img'] = img + results['img_shape'] = img.shape + results['img_fields'] = ['img'] + # bboxes + results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore'] + results['gt_bboxes'] = np.array([[0., 0., 2., 1.]], dtype=np.float32) + results['gt_bboxes_ignore'] = np.array([[2., 0., 3., 1.]], + dtype=np.float32) + # labels + results['gt_labels'] = np.array([1], dtype=np.int64) + return results + + # test crop_type "relative_range" + results = _construct_toy_data() + transform = dict( + type='RandomCrop', + crop_type='relative_range', + crop_size=(0.3, 0.7), + allow_negative_crop=True) + transform_module = build_from_cfg(transform, PIPELINES) + results_transformed = transform_module(copy.deepcopy(results)) + h, w = results_transformed['img_shape'][:2] + assert int(2 * 0.3 + 0.5) <= h <= int(2 * 1 + 0.5) + assert int(4 * 0.7 + 0.5) <= w <= int(4 * 1 + 0.5) + + # test crop_type "relative" + transform = dict( + type='RandomCrop', + crop_type='relative', + crop_size=(0.3, 0.7), + allow_negative_crop=True) + transform_module = build_from_cfg(transform, PIPELINES) + results_transformed = transform_module(copy.deepcopy(results)) + h, w = results_transformed['img_shape'][:2] + assert h == int(2 * 0.3 + 0.5) and w == int(4 * 0.7 + 0.5) + + # test crop_type "absolute" + transform = dict( + type='RandomCrop', + crop_type='absolute', + crop_size=(1, 2), + allow_negative_crop=True) + transform_module = build_from_cfg(transform, PIPELINES) + results_transformed = transform_module(copy.deepcopy(results)) + h, w = results_transformed['img_shape'][:2] + assert h == 1 and w == 2 + + # test crop_type "absolute_range" + transform = dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(1, 20), + allow_negative_crop=True) + transform_module = build_from_cfg(transform, PIPELINES) + results_transformed = transform_module(copy.deepcopy(results)) + h, w = results_transformed['img_shape'][:2] + assert 1 <= h <= 2 and 1 <= w <= 4 + def test_min_iou_random_crop(): diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index 32ea93a0b2a..0244f02e0ad 100644 --- a/tests/test_models/test_forward.py +++ 
b/tests/test_models/test_forward.py @@ -379,3 +379,59 @@ def test_yolact_forward(): rescale=True, return_loss=False) batch_results.append(result) + + +def test_detr_forward(): + model, train_cfg, test_cfg = _get_detector_cfg( + 'detr/detr_r50_8x4_150e_coco.py') + model['pretrained'] = None + + from mmdet.models import build_detector + detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg) + + input_shape = (1, 3, 550, 550) + mm_inputs = _demo_mm_inputs(input_shape) + + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + + # Test forward train with non-empty truth batch + detector.train() + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward train with an empty truth batch + mm_inputs = _demo_mm_inputs(input_shape, num_items=[0]) + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_bboxes = mm_inputs['gt_bboxes'] + gt_labels = mm_inputs['gt_labels'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + return_loss=True) + assert isinstance(losses, dict) + loss, _ = detector._parse_losses(losses) + assert float(loss.item()) > 0 + + # Test forward test + detector.eval() + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + rescale=True, + return_loss=False) + batch_results.append(result) diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 7dee4639685..2632b6690b2 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -6,8 +6,9 @@ from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps from mmdet.models.dense_heads import (AnchorHead, CornerHead, FCOSHead, FSAFHead, GuidedAnchorHead, PAAHead, - SABLRetinaHead, VFNetHead, YOLACTHead, - YOLACTProtonet, YOLACTSegmHead, paa_head) + SABLRetinaHead, TransformerHead, + VFNetHead, YOLACTHead, YOLACTProtonet, + YOLACTSegmHead, paa_head) from mmdet.models.dense_heads.paa_head import levels_to_images from mmdet.models.roi_heads.bbox_heads import BBoxHead, SABLHead from mmdet.models.roi_heads.mask_heads import FCNMaskHead, MaskIoUHead @@ -1228,3 +1229,85 @@ def test_yolact_head_loss(): one_gt_mask_loss = sum(one_gt_mask_loss['loss_mask']) assert one_gt_segm_loss.item() > 0, 'segm loss should be non-zero' assert one_gt_mask_loss.item() > 0, 'mask loss should be non-zero' + + +def test_transformer_head_loss(): + """Tests transformer head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3), + 'batch_input_shape': (s, s) + }] + train_cfg = dict( + assigner=dict( + type='HungarianAssigner', + cls_weight=1., + bbox_weight=5., + iou_weight=2., + iou_calculator=dict(type='BboxOverlaps2D'), + iou_mode='giou')) + transformer_cfg = dict( + type='Transformer', + embed_dims=4, + num_heads=1, + num_encoder_layers=1, + num_decoder_layers=1, + feedforward_channels=1, + dropout=0.1, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + num_fcs=2, + pre_norm=False, + return_intermediate_dec=True) + positional_encoding_cfg = dict( + type='SinePositionalEncoding', num_feats=2, normalize=True) + self
= TransformerHead( + num_classes=4, + in_channels=1, + num_fcs=2, + train_cfg=train_cfg, + transformer=transformer_cfg, + positional_encoding=positional_encoding_cfg) + self.init_weights() + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + cls_scores, bbox_preds = self.forward(feat, img_metas) + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + for key, loss in empty_gt_losses.items(): + if 'cls' in key: + assert loss.item() > 0, 'cls loss should be non-zero' + elif 'bbox' in key: + assert loss.item( + ) == 0, 'there should be no box loss when there are no true boxes' + elif 'iou' in key: + assert loss.item( + ) == 0, 'there should be no iou loss when there are no true boxes' + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + for loss in one_gt_losses.values(): + assert loss.item( + ) > 0, 'cls loss, or box loss, or iou loss should be non-zero' + + # test forward_train + self.forward_train(feat, img_metas, gt_bboxes, gt_labels) + + # test inference mode + self.get_bboxes(cls_scores, bbox_preds, img_metas, rescale=True) diff --git a/tests/test_models/test_position_encoding.py b/tests/test_models/test_position_encoding.py new file mode 100644 index 00000000000..94fdd479a47 --- /dev/null +++ b/tests/test_models/test_position_encoding.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from mmdet.models.utils import (LearnedPositionalEncoding, + SinePositionalEncoding) + + +def test_sine_positional_encoding(num_feats=16, batch_size=2): + # test invalid type of scale + with pytest.raises(AssertionError): + module = SinePositionalEncoding( + num_feats, scale=(3., ), normalize=True) + + module = SinePositionalEncoding(num_feats) + h, w = 10, 6 + mask = torch.rand(batch_size, h, w) > 0.5 + assert not module.normalize + out = module(mask) + assert out.shape == (batch_size, num_feats * 2, h, w) + + # set normalize + module = SinePositionalEncoding(num_feats, normalize=True) + assert module.normalize + out = module(mask) + assert out.shape == (batch_size, num_feats * 2, h, w) + + +def test_learned_positional_encoding(num_feats=16, + row_num_embed=10, + col_num_embed=10, + batch_size=2): + module = LearnedPositionalEncoding(num_feats, row_num_embed, col_num_embed) + assert module.row_embed.weight.shape == (row_num_embed, num_feats) + assert module.col_embed.weight.shape == (col_num_embed, num_feats) + h, w = 10, 6 + mask = torch.rand(batch_size, h, w) > 0.5 + out = module(mask) + assert out.shape == (batch_size, num_feats * 2, h, w) diff --git a/tests/test_models/test_transformer.py b/tests/test_models/test_transformer.py new file mode 100644 index 00000000000..0e21549ae8b --- /dev/null +++ b/tests/test_models/test_transformer.py @@ -0,0 +1,523 @@ +from unittest.mock import patch + +import pytest +import torch + +from mmdet.models.utils import (FFN, MultiheadAttention, Transformer, + TransformerDecoder, TransformerDecoderLayer, + TransformerEncoder, 
TransformerEncoderLayer) + + +def _ffn_forward(self, x, residual=None): + if residual is None: + residual = x + residual_str = residual.split('_')[-1] + if '(residual' in residual_str: + residual_str = residual_str.split('(residual')[0] + return x + '_ffn(residual={})'.format(residual_str) + + +def _multihead_attention_forward(self, + x, + key=None, + value=None, + residual=None, + query_pos=None, + key_pos=None, + attn_mask=None, + key_padding_mask=None, + selfattn=True): + if residual is None: + residual = x + residual_str = residual.split('_')[-1] + if '(residual' in residual_str: + residual_str = residual_str.split('(residual')[0] + attn_str = 'selfattn' if selfattn else 'multiheadattn' + return x + '_{}(residual={})'.format(attn_str, residual_str) + + +def _encoder_layer_forward(self, + x, + pos=None, + attn_mask=None, + key_padding_mask=None): + norm_cnt = 0 + inp_residual = x + for layer in self.order: + if layer == 'selfattn': + x = self.self_attn( + x, + x, + x, + inp_residual if self.pre_norm else None, + query_pos=pos, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask) + inp_residual = x + elif layer == 'norm': + x = x + '_norm{}'.format(norm_cnt) + norm_cnt += 1 + elif layer == 'ffn': + x = self.ffn(x, inp_residual if self.pre_norm else None) + else: + raise ValueError(f'Unsupported layer type {layer}.') + return x + + +def _decoder_layer_forward(self, + x, + memory, + memory_pos=None, + query_pos=None, + memory_attn_mask=None, + target_attn_mask=None, + memory_key_padding_mask=None, + target_key_padding_mask=None): + norm_cnt = 0 + inp_residual = x + for layer in self.order: + if layer == 'selfattn': + x = self.self_attn( + x, + x, + x, + inp_residual if self.pre_norm else None, + query_pos, + attn_mask=target_attn_mask, + key_padding_mask=target_key_padding_mask) + inp_residual = x + elif layer == 'norm': + x = x + '_norm{}'.format(norm_cnt) + norm_cnt += 1 + elif layer == 'multiheadattn': + x = self.multihead_attn( + x, + memory, + memory, + inp_residual if self.pre_norm else None, + query_pos, + key_pos=memory_pos, + attn_mask=memory_attn_mask, + key_padding_mask=memory_key_padding_mask, + selfattn=False) + inp_residual = x + elif layer == 'ffn': + x = self.ffn(x, inp_residual if self.pre_norm else None) + else: + raise ValueError(f'Unsupported layer type {layer}.') + return x + + +def test_multihead_attention(embed_dims=8, + num_heads=2, + dropout=0.1, + num_query=5, + num_key=10, + batch_size=1): + module = MultiheadAttention(embed_dims, num_heads, dropout) + # self attention + query = torch.rand(num_query, batch_size, embed_dims) + out = module(query) + assert out.shape == (num_query, batch_size, embed_dims) + + # set key + key = torch.rand(num_key, batch_size, embed_dims) + out = module(query, key) + assert out.shape == (num_query, batch_size, embed_dims) + + # set residual + residual = torch.rand(num_query, batch_size, embed_dims) + out = module(query, key, key, residual) + assert out.shape == (num_query, batch_size, embed_dims) + + # set query_pos and key_pos + query_pos = torch.rand(num_query, batch_size, embed_dims) + key_pos = torch.rand(num_key, batch_size, embed_dims) + out = module(query, key, None, residual, query_pos, key_pos) + assert out.shape == (num_query, batch_size, embed_dims) + + # set key_padding_mask + key_padding_mask = torch.rand(batch_size, num_key) > 0.5 + out = module(query, key, None, residual, query_pos, key_pos, None, + key_padding_mask) + assert out.shape == (num_query, batch_size, embed_dims) + + # set attn_mask + attn_mask = 
torch.rand(num_query, num_key) > 0.5 + out = module(query, key, key, residual, query_pos, key_pos, attn_mask, + key_padding_mask) + assert out.shape == (num_query, batch_size, embed_dims) + + +def test_ffn(embed_dims=8, feedforward_channels=8, num_fcs=2, batch_size=1): + # test invalid num_fcs + with pytest.raises(AssertionError): + module = FFN(embed_dims, feedforward_channels, 1) + + module = FFN(embed_dims, feedforward_channels, num_fcs) + x = torch.rand(batch_size, embed_dims) + out = module(x) + assert out.shape == (batch_size, embed_dims) + # set residual + residual = torch.rand(batch_size, embed_dims) + out = module(x, residual) + assert out.shape == (batch_size, embed_dims) + + # test case with no residual + module = FFN(embed_dims, feedforward_channels, num_fcs, add_residual=False) + x = torch.rand(batch_size, embed_dims) + out = module(x) + assert out.shape == (batch_size, embed_dims) + + +def test_transformer_encoder_layer(embed_dims=8, + num_heads=2, + feedforward_channels=8, + num_key=10, + batch_size=1): + x = torch.rand(num_key, batch_size, embed_dims) + # test invalid number of order + with pytest.raises(AssertionError): + order = ('norm', 'selfattn', 'norm', 'ffn', 'norm') + module = TransformerEncoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + + # test invalid value of order + with pytest.raises(AssertionError): + order = ('norm', 'selfattn', 'norm', 'unknown') + module = TransformerEncoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + + module = TransformerEncoderLayer(embed_dims, num_heads, + feedforward_channels) + + key_padding_mask = torch.rand(batch_size, num_key) > 0.5 + out = module(x, key_padding_mask=key_padding_mask) + assert not module.pre_norm + assert out.shape == (num_key, batch_size, embed_dims) + + # set pos + pos = torch.rand(num_key, batch_size, embed_dims) + out = module(x, pos, key_padding_mask=key_padding_mask) + assert out.shape == (num_key, batch_size, embed_dims) + + # set attn_mask + attn_mask = torch.rand(num_key, num_key) > 0.5 + out = module(x, pos, attn_mask, key_padding_mask) + assert out.shape == (num_key, batch_size, embed_dims) + + # set pre_norm + order = ('norm', 'selfattn', 'norm', 'ffn') + module = TransformerEncoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + assert module.pre_norm + out = module(x, pos, attn_mask, key_padding_mask) + assert out.shape == (num_key, batch_size, embed_dims) + + @patch('mmdet.models.utils.TransformerEncoderLayer.forward', + _encoder_layer_forward) + @patch('mmdet.models.utils.FFN.forward', _ffn_forward) + @patch('mmdet.models.utils.MultiheadAttention.forward', + _multihead_attention_forward) + def test_order(): + module = TransformerEncoderLayer(embed_dims, num_heads, + feedforward_channels) + out = module('input') + assert out == 'input_selfattn(residual=input)_norm0_ffn' \ + '(residual=norm0)_norm1' + + # pre_norm + order = ('norm', 'selfattn', 'norm', 'ffn') + module = TransformerEncoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + out = module('input') + assert out == 'input_norm0_selfattn(residual=input)_' \ + 'norm1_ffn(residual=selfattn)' + + test_order() + + +def test_transformer_decoder_layer(embed_dims=8, + num_heads=2, + feedforward_channels=8, + num_key=10, + num_query=5, + batch_size=1): + query = torch.rand(num_query, batch_size, embed_dims) + # test invalid number of order + with pytest.raises(AssertionError): + order = ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', + 'norm') + 
module = TransformerDecoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + + # test invalid value of order + with pytest.raises(AssertionError): + order = ('norm', 'selfattn', 'unknown', 'multiheadattn', 'norm', 'ffn') + module = TransformerDecoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + + module = TransformerDecoderLayer(embed_dims, num_heads, + feedforward_channels) + memory = torch.rand(num_key, batch_size, embed_dims) + assert not module.pre_norm + out = module(query, memory) + assert out.shape == (num_query, batch_size, embed_dims) + + # set query_pos + query_pos = torch.rand(num_query, batch_size, embed_dims) + out = module(query, memory, memory_pos=None, query_pos=query_pos) + assert out.shape == (num_query, batch_size, embed_dims) + + # set memory_pos + memory_pos = torch.rand(num_key, batch_size, embed_dims) + out = module(query, memory, memory_pos, query_pos) + assert out.shape == (num_query, batch_size, embed_dims) + + # set memory_key_padding_mask + memory_key_padding_mask = torch.rand(batch_size, num_key) > 0.5 + out = module( + query, + memory, + memory_pos, + query_pos, + memory_key_padding_mask=memory_key_padding_mask) + assert out.shape == (num_query, batch_size, embed_dims) + + # set target_key_padding_mask + target_key_padding_mask = torch.rand(batch_size, num_query) > 0.5 + out = module( + query, + memory, + memory_pos, + query_pos, + memory_key_padding_mask=memory_key_padding_mask, + target_key_padding_mask=target_key_padding_mask) + assert out.shape == (num_query, batch_size, embed_dims) + + # set memory_attn_mask + memory_attn_mask = torch.rand(num_query, num_key) + out = module( + query, + memory, + memory_pos, + query_pos, + memory_attn_mask, + memory_key_padding_mask=memory_key_padding_mask, + target_key_padding_mask=target_key_padding_mask) + assert out.shape == (num_query, batch_size, embed_dims) + + # set target_attn_mask + target_attn_mask = torch.rand(num_query, num_query) + out = module(query, memory, memory_pos, query_pos, memory_attn_mask, + target_attn_mask, memory_key_padding_mask, + target_key_padding_mask) + assert out.shape == (num_query, batch_size, embed_dims) + + # pre_norm + order = ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn') + module = TransformerDecoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + assert module.pre_norm + out = module( + query, + memory, + memory_pos, + query_pos, + memory_attn_mask, + memory_key_padding_mask=memory_key_padding_mask, + target_key_padding_mask=target_key_padding_mask) + assert out.shape == (num_query, batch_size, embed_dims) + + @patch('mmdet.models.utils.TransformerDecoderLayer.forward', + _decoder_layer_forward) + @patch('mmdet.models.utils.FFN.forward', _ffn_forward) + @patch('mmdet.models.utils.MultiheadAttention.forward', + _multihead_attention_forward) + def test_order(): + module = TransformerDecoderLayer(embed_dims, num_heads, + feedforward_channels) + out = module('input', 'memory') + assert out == 'input_selfattn(residual=input)_norm0_multiheadattn' \ + '(residual=norm0)_norm1_ffn(residual=norm1)_norm2' + + # pre_norm + order = ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn') + module = TransformerDecoderLayer( + embed_dims, num_heads, feedforward_channels, order=order) + out = module('input', 'memory') + assert out == 'input_norm0_selfattn(residual=input)_norm1_' \ + 'multiheadattn(residual=selfattn)_norm2_ffn(residual=' \ + 'multiheadattn)' + + test_order() + + +def 
test_transformer_encoder(num_layers=2, + embed_dims=8, + num_heads=2, + feedforward_channels=8, + num_key=10, + batch_size=1): + module = TransformerEncoder(num_layers, embed_dims, num_heads, + feedforward_channels) + assert not module.pre_norm + assert module.norm is None + x = torch.rand(num_key, batch_size, embed_dims) + out = module(x) + assert out.shape == (num_key, batch_size, embed_dims) + + # set pos + pos = torch.rand(num_key, batch_size, embed_dims) + out = module(x, pos) + assert out.shape == (num_key, batch_size, embed_dims) + + # set key_padding_mask + key_padding_mask = torch.rand(batch_size, num_key) > 0.5 + out = module(x, pos, None, key_padding_mask) + assert out.shape == (num_key, batch_size, embed_dims) + + # set attn_mask + attn_mask = torch.rand(num_key, num_key) > 0.5 + out = module(x, pos, attn_mask, key_padding_mask) + assert out.shape == (num_key, batch_size, embed_dims) + + # pre_norm + order = ('norm', 'selfattn', 'norm', 'ffn') + module = TransformerEncoder( + num_layers, embed_dims, num_heads, feedforward_channels, order=order) + assert module.pre_norm + assert module.norm is not None + out = module(x, pos, attn_mask, key_padding_mask) + assert out.shape == (num_key, batch_size, embed_dims) + + +def test_transformer_decoder(num_layers=2, + embed_dims=8, + num_heads=2, + feedforward_channels=8, + num_key=10, + num_query=5, + batch_size=1): + module = TransformerDecoder(num_layers, embed_dims, num_heads, + feedforward_channels) + query = torch.rand(num_query, batch_size, embed_dims) + memory = torch.rand(num_key, batch_size, embed_dims) + out = module(query, memory) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # set query_pos + query_pos = torch.rand(num_query, batch_size, embed_dims) + out = module(query, memory, query_pos=query_pos) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # set memory_pos + memory_pos = torch.rand(num_key, batch_size, embed_dims) + out = module(query, memory, memory_pos, query_pos) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # set memory_key_padding_mask + memory_key_padding_mask = torch.rand(batch_size, num_key) > 0.5 + out = module( + query, + memory, + memory_pos, + query_pos, + memory_key_padding_mask=memory_key_padding_mask) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # set target_key_padding_mask + target_key_padding_mask = torch.rand(batch_size, num_query) > 0.5 + out = module( + query, + memory, + memory_pos, + query_pos, + memory_key_padding_mask=memory_key_padding_mask, + target_key_padding_mask=target_key_padding_mask) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # set memory_attn_mask + memory_attn_mask = torch.rand(num_query, num_key) > 0.5 + out = module(query, memory, memory_pos, query_pos, memory_attn_mask, None, + memory_key_padding_mask, target_key_padding_mask) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # set target_attn_mask + target_attn_mask = torch.rand(num_query, num_query) > 0.5 + out = module(query, memory, memory_pos, query_pos, memory_attn_mask, + target_attn_mask, memory_key_padding_mask, + target_key_padding_mask) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # pre_norm + order = ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn') + module = TransformerDecoder( + num_layers, embed_dims, num_heads, feedforward_channels, order=order) + out = module(query, memory, memory_pos, query_pos, memory_attn_mask, + target_attn_mask, memory_key_padding_mask, + 
target_key_padding_mask) + assert out.shape == (1, num_query, batch_size, embed_dims) + + # return_intermediate + module = TransformerDecoder( + num_layers, + embed_dims, + num_heads, + feedforward_channels, + order=order, + return_intermediate=True) + out = module(query, memory, memory_pos, query_pos, memory_attn_mask, + target_attn_mask, memory_key_padding_mask, + target_key_padding_mask) + assert out.shape == (num_layers, num_query, batch_size, embed_dims) + + +def test_transformer(num_enc_layers=2, + num_dec_layers=2, + embed_dims=8, + num_heads=2, + num_query=5, + batch_size=1): + module = Transformer(embed_dims, num_heads, num_enc_layers, num_dec_layers) + height, width = 8, 6 + x = torch.rand(batch_size, embed_dims, height, width) + mask = torch.rand(batch_size, height, width) > 0.5 + query_embed = torch.rand(num_query, embed_dims) + pos_embed = torch.rand(batch_size, embed_dims, height, width) + hs, mem = module(x, mask, query_embed, pos_embed) + assert hs.shape == (1, batch_size, num_query, embed_dims) + assert mem.shape == (batch_size, embed_dims, height, width) + + # pre_norm + module = Transformer( + embed_dims, num_heads, num_enc_layers, num_dec_layers, pre_norm=True) + hs, mem = module(x, mask, query_embed, pos_embed) + assert hs.shape == (1, batch_size, num_query, embed_dims) + assert mem.shape == (batch_size, embed_dims, height, width) + + # return_intermediate + module = Transformer( + embed_dims, + num_heads, + num_enc_layers, + num_dec_layers, + return_intermediate_dec=True) + hs, mem = module(x, mask, query_embed, pos_embed) + assert hs.shape == (num_dec_layers, batch_size, num_query, embed_dims) + assert mem.shape == (batch_size, embed_dims, height, width) + + # pre_norm and return_intermediate + module = Transformer( + embed_dims, + num_heads, + num_enc_layers, + num_dec_layers, + pre_norm=True, + return_intermediate_dec=True) + hs, mem = module(x, mask, query_embed, pos_embed) + assert hs.shape == (num_dec_layers, batch_size, num_query, embed_dims) + assert mem.shape == (batch_size, embed_dims, height, width) + + # test init_weights + module.init_weights()