diff --git a/mmdet/datasets/cityscapes.py b/mmdet/datasets/cityscapes.py index fc6431b6686..499b1b1cce7 100644 --- a/mmdet/datasets/cityscapes.py +++ b/mmdet/datasets/cityscapes.py @@ -24,17 +24,29 @@ class CityscapesDataset(CocoDataset): def _filter_imgs(self, min_size=32): """Filter images too small or without ground truths.""" valid_inds = [] + # obtain images that contain annotation ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.coco.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_img_ids = [] for i, img_info in enumerate(self.data_infos): img_id = img_info['id'] ann_ids = self.coco.getAnnIds(imgIds=[img_id]) ann_info = self.coco.loadAnns(ann_ids) all_iscrowd = all([_['iscrowd'] for _ in ann_info]) - if self.filter_empty_gt and (self.img_ids[i] not in ids_with_ann + if self.filter_empty_gt and (self.img_ids[i] not in ids_in_cat or all_iscrowd): continue if min(img_info['width'], img_info['height']) >= min_size: valid_inds.append(i) + valid_img_ids.append(img_id) + self.img_ids = valid_img_ids return valid_inds def _parse_ann_info(self, img_info, ann_info): diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py index 692e826590b..703020b207d 100644 --- a/mmdet/datasets/coco.py +++ b/mmdet/datasets/coco.py @@ -96,39 +96,27 @@ def get_cat_ids(self, idx): def _filter_imgs(self, min_size=32): """Filter images too small or without ground truths.""" valid_inds = [] + # obtain images that contain annotation ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) + # obtain images that contain annotations of the required categories + ids_in_cat = set() + for i, class_id in enumerate(self.cat_ids): + ids_in_cat |= set(self.coco.cat_img_map[class_id]) + # merge the image id sets of the two conditions and use the merged set + # to filter out images if self.filter_empty_gt=True + ids_in_cat &= ids_with_ann + + valid_img_ids = [] for i, img_info in enumerate(self.data_infos): - if self.filter_empty_gt and self.img_ids[i] not in ids_with_ann: + img_id = self.img_ids[i] + if self.filter_empty_gt and img_id not in ids_in_cat: continue if min(img_info['width'], img_info['height']) >= min_size: valid_inds.append(i) + valid_img_ids.append(img_id) + self.img_ids = valid_img_ids return valid_inds - def get_subset_by_classes(self): - """Get img ids that contain any category in class_ids. - - Different from the coco.getImgIds(), this function returns the id if - the img contains one of the categories rather than all. - - Args: - class_ids (list[int]): list of category ids - - Return: - ids (list[int]): integer list of img ids - """ - - ids = set() - for i, class_id in enumerate(self.cat_ids): - ids |= set(self.coco.cat_img_map[class_id]) - self.img_ids = list(ids) - - data_infos = [] - for i in self.img_ids: - info = self.coco.load_imgs([i])[0] - info['filename'] = info['file_name'] - data_infos.append(info) - return data_infos - def _parse_ann_info(self, img_info, ann_info): """Parse bbox and mask annotation. diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index 85feba8f2ef..5c6a4699412 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -1,4 +1,5 @@ import os.path as osp +import warnings import mmcv import numpy as np @@ -42,7 +43,9 @@ class CustomDataset(Dataset): ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified. test_mode (bool, optional): If set True, annotation will not be loaded. filter_empty_gt (bool, optional): If set true, images without bounding - boxes will be filtered out. + boxes of the dataset's classes will be filtered out. This option + only works when `test_mode=False`, i.e., we never filter images + during tests. """ CLASSES = None @@ -80,23 +83,21 @@ def __init__(self, self.proposal_file) # load annotations (and proposals) self.data_infos = self.load_annotations(self.ann_file) - # filter data infos if classes are customized - if self.custom_classes: - self.data_infos = self.get_subset_by_classes() if self.proposal_file is not None: self.proposals = self.load_proposals(self.proposal_file) else: self.proposals = None - # filter images too small + + # filter images too small and containing no annotations if not test_mode: valid_inds = self._filter_imgs() self.data_infos = [self.data_infos[i] for i in valid_inds] if self.proposals is not None: self.proposals = [self.proposals[i] for i in valid_inds] - # set group flag for the sampler - if not self.test_mode: + # set group flag for the sampler self._set_group_flag() + # processing pipeline self.pipeline = Compose(pipeline) @@ -147,6 +148,9 @@ def pre_pipeline(self, results): def _filter_imgs(self, min_size=32): """Filter images too small.""" + if self.filter_empty_gt: + warnings.warn( + 'CustomDataset does not support filtering empty gt images.') valid_inds = [] for i, img_info in enumerate(self.data_infos): if min(img_info['width'], img_info['height']) >= min_size: @@ -237,12 +241,13 @@ def get_classes(cls, classes=None): string, take it as a file name. The file contains the name of classes where each line contains one class name. If classes is a tuple or list, override the CLASSES defined by the dataset. + + Returns: + tuple[str] or list[str]: Names of categories of the dataset. """ if classes is None: - cls.custom_classes = False return cls.CLASSES - cls.custom_classes = True if isinstance(classes, str): # take it as a file path class_names = mmcv.list_from_file(classes) @@ -253,9 +258,6 @@ def get_classes(cls, classes=None): return class_names - def get_subset_by_classes(self): - return self.data_infos - def format_results(self, results, **kwargs): """Place holder to format result to dataset specific output.""" pass diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py index e5b7cae6a44..b912de38d12 100644 --- a/mmdet/datasets/xml_style.py +++ b/mmdet/datasets/xml_style.py @@ -58,22 +58,26 @@ def load_annotations(self, ann_file): return data_infos - def get_subset_by_classes(self): - """Filter imgs by user-defined categories.""" - subset_data_infos = [] - for data_info in self.data_infos: - img_id = data_info['id'] - xml_path = osp.join(self.img_prefix, 'Annotations', - f'{img_id}.xml') - tree = ET.parse(xml_path) - root = tree.getroot() - for obj in root.findall('object'): - name = obj.find('name').text - if name in self.CLASSES: - subset_data_infos.append(data_info) - break - - return subset_data_infos + def _filter_imgs(self, min_size=32): + """Filter images too small or without annotation.""" + valid_inds = [] + for i, img_info in enumerate(self.data_infos): + if min(img_info['width'], img_info['height']) < min_size: + continue + if self.filter_empty_gt: + img_id = img_info['id'] + xml_path = osp.join(self.img_prefix, 'Annotations', + f'{img_id}.xml') + tree = ET.parse(xml_path) + root = tree.getroot() + for obj in root.findall('object'): + name = obj.find('name').text + if name in self.CLASSES: + valid_inds.append(i) + break + else: + valid_inds.append(i) + return valid_inds def get_ann_info(self, idx): """Get annotation from XML file by index. diff --git a/tests/data/coco_sample.json b/tests/data/coco_sample.json new file mode 100644 index 00000000000..b66cdf309e3 --- /dev/null +++ b/tests/data/coco_sample.json @@ -0,0 +1,77 @@ +{ + "images": [ + { + "file_name": "fake1.jpg", + "height": 800, + "width": 800, + "id": 0 + }, + { + "file_name": "fake2.jpg", + "height": 800, + "width": 800, + "id": 1 + }, + { + "file_name": "fake3.jpg", + "height": 800, + "width": 800, + "id": 2 + } + ], + "annotations": [ + { + "bbox": [ + 0, + 0, + 20, + 20 + ], + "area": 400.00, + "score": 1.0, + "category_id": 1, + "id": 1, + "image_id": 0 + }, + { + "bbox": [ + 0, + 0, + 20, + 20 + ], + "area": 400.00, + "score": 1.0, + "category_id": 2, + "id": 2, + "image_id": 0 + }, + { + "bbox": [ + 0, + 0, + 20, + 20 + ], + "area": 400.00, + "score": 1.0, + "category_id": 1, + "id": 3, + "image_id": 1 + } + ], + "categories": [ + { + "id": 1, + "name": "bus", + "supercategory": "none" + }, + { + "id": 2, + "name": "car", + "supercategory": "none" + } + ], + "licenses": [], + "info": null +} diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index a6a23e575eb..83de7125b55 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -204,11 +204,18 @@ def test_dataset_evaluation(): tmp_dir.cleanup() +@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock) +@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock) +@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock) +@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock) +@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock) +@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock) +@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock) +@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock) @pytest.mark.parametrize('dataset', ['CocoDataset', 'VOCDataset', 'CityscapesDataset']) def test_custom_classes_override_default(dataset): dataset_class = DATASETS.get(dataset) - dataset_class.load_annotations = MagicMock() if dataset in ['CocoDataset', 'CityscapesDataset']: dataset_class.coco = MagicMock() dataset_class.cat_ids = MagicMock() @@ -225,7 +232,6 @@ def test_custom_classes_override_default(dataset): assert custom_dataset.CLASSES != original_classes assert custom_dataset.CLASSES == ('bus', 'car') - assert custom_dataset.custom_classes # Test setting classes as a list custom_dataset = dataset_class( @@ -237,7 +243,6 @@ def test_custom_classes_override_default(dataset): assert custom_dataset.CLASSES != original_classes assert custom_dataset.CLASSES == ['bus', 'car'] - assert custom_dataset.custom_classes # Test overriding not a subset custom_dataset = dataset_class( @@ -249,7 +254,6 @@ def test_custom_classes_override_default(dataset): assert custom_dataset.CLASSES != original_classes assert custom_dataset.CLASSES == ['foo'] - assert custom_dataset.custom_classes # Test default behavior custom_dataset = dataset_class( @@ -260,7 +264,6 @@ def test_custom_classes_override_default(dataset): img_prefix='VOC2007' if dataset == 'VOCDataset' else '') assert custom_dataset.CLASSES == original_classes - assert not custom_dataset.custom_classes # Test sending file path import tempfile @@ -277,7 +280,6 @@ def test_custom_classes_override_default(dataset): assert custom_dataset.CLASSES != original_classes assert custom_dataset.CLASSES == ['bus', 'car'] - assert custom_dataset.custom_classes def test_dataset_wrapper(): @@ -460,3 +462,32 @@ def val_step(self, x, optimizer, **kwargs): runner = EpochBasedRunner( model=model, work_dir=tmp_dir, logger=logging.getLogger()) return runner + + +@pytest.mark.parametrize('classes, expected_length', [(['bus'], 2), + (['car'], 1), + (['bus', 'car'], 2)]) +def test_allow_empty_images(classes, expected_length): + dataset_class = DATASETS.get('CocoDataset') + # Filter empty images + filtered_dataset = dataset_class( + ann_file='tests/data/coco_sample.json', + img_prefix='tests/data', + pipeline=[], + classes=classes, + filter_empty_gt=True) + + # Get all + full_dataset = dataset_class( + ann_file='tests/data/coco_sample.json', + img_prefix='tests/data', + pipeline=[], + classes=classes, + filter_empty_gt=False) + + assert len(filtered_dataset) == expected_length + assert len(filtered_dataset.img_ids) == expected_length + assert len(full_dataset) == 3 + assert len(full_dataset.img_ids) == 3 + assert filtered_dataset.CLASSES == classes + assert full_dataset.CLASSES == classes