Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion mmdet/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,29 @@ class CityscapesDataset(CocoDataset):
def _filter_imgs(self, min_size=32):
    """Filter images too small or without ground truths.

    Args:
        min_size (int): Images whose ``min(width, height)`` is below this
            value are dropped.

    Returns:
        list[int]: Indices of the valid images in ``self.data_infos``.
            As a side effect, ``self.img_ids`` is rebuilt to contain only
            the ids of the kept images, so it stays aligned with
            ``self.data_infos`` after filtering.
    """
    valid_inds = []
    # obtain images that contain annotation
    ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
    # obtain images that contain annotations of the required categories
    ids_in_cat = set()
    for i, class_id in enumerate(self.cat_ids):
        ids_in_cat |= set(self.coco.cat_img_map[class_id])
    # merge the image id sets of the two conditions and use the merged set
    # to filter out images if self.filter_empty_gt=True
    ids_in_cat &= ids_with_ann

    valid_img_ids = []
    for i, img_info in enumerate(self.data_infos):
        img_id = img_info['id']
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        ann_info = self.coco.loadAnns(ann_ids)
        # an image whose instances are all marked iscrowd is treated as
        # having no usable ground truth
        all_iscrowd = all([_['iscrowd'] for _ in ann_info])
        if self.filter_empty_gt and (self.img_ids[i] not in ids_in_cat
                                     or all_iscrowd):
            continue
        if min(img_info['width'], img_info['height']) >= min_size:
            valid_inds.append(i)
            valid_img_ids.append(img_id)
    self.img_ids = valid_img_ids
    return valid_inds

def _parse_ann_info(self, img_info, ann_info):
Expand Down
40 changes: 14 additions & 26 deletions mmdet/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,39 +96,27 @@ def get_cat_ids(self, idx):
def _filter_imgs(self, min_size=32):
    """Filter images too small or without ground truths.

    Args:
        min_size (int): Images whose ``min(width, height)`` is below this
            value are dropped.

    Returns:
        list[int]: Indices of the valid images in ``self.data_infos``.
            As a side effect, ``self.img_ids`` is rebuilt to contain only
            the ids of the kept images, so it stays aligned with
            ``self.data_infos`` after filtering.
    """
    valid_inds = []
    # obtain images that contain annotation
    ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
    # obtain images that contain annotations of the required categories
    ids_in_cat = set()
    for i, class_id in enumerate(self.cat_ids):
        ids_in_cat |= set(self.coco.cat_img_map[class_id])
    # merge the image id sets of the two conditions and use the merged set
    # to filter out images if self.filter_empty_gt=True
    ids_in_cat &= ids_with_ann

    valid_img_ids = []
    for i, img_info in enumerate(self.data_infos):
        img_id = self.img_ids[i]
        if self.filter_empty_gt and img_id not in ids_in_cat:
            continue
        if min(img_info['width'], img_info['height']) >= min_size:
            valid_inds.append(i)
            valid_img_ids.append(img_id)
    self.img_ids = valid_img_ids
    return valid_inds

def get_subset_by_classes(self):
    """Collect image infos for images containing any required category.

    Unlike ``coco.getImgIds()``, an image qualifies when it contains ANY
    of the categories in ``self.cat_ids`` rather than all of them. As a
    side effect, ``self.img_ids`` is reset to the qualifying ids.

    Returns:
        list[dict]: Image info dicts for the qualifying images, each with
            a ``filename`` key copied from ``file_name``.
    """
    qualifying = set()
    for class_id in self.cat_ids:
        qualifying.update(self.coco.cat_img_map[class_id])
    self.img_ids = list(qualifying)

    data_infos = []
    for img_id in self.img_ids:
        info = self.coco.load_imgs([img_id])[0]
        info['filename'] = info['file_name']
        data_infos.append(info)
    return data_infos

def _parse_ann_info(self, img_info, ann_info):
"""Parse bbox and mask annotation.

Expand Down
26 changes: 14 additions & 12 deletions mmdet/datasets/custom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os.path as osp
import warnings

import mmcv
import numpy as np
Expand Down Expand Up @@ -42,7 +43,9 @@ class CustomDataset(Dataset):
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
test_mode (bool, optional): If set True, annotation will not be loaded.
filter_empty_gt (bool, optional): If set true, images without bounding
boxes will be filtered out.
boxes of the dataset's classes will be filtered out. This option
only works when `test_mode=False`, i.e., we never filter images
during tests.
"""

CLASSES = None
Expand Down Expand Up @@ -80,23 +83,21 @@ def __init__(self,
self.proposal_file)
# load annotations (and proposals)
self.data_infos = self.load_annotations(self.ann_file)
# filter data infos if classes are customized
if self.custom_classes:
self.data_infos = self.get_subset_by_classes()

if self.proposal_file is not None:
self.proposals = self.load_proposals(self.proposal_file)
else:
self.proposals = None
# filter images too small

# filter images too small and containing no annotations
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
if self.proposals is not None:
self.proposals = [self.proposals[i] for i in valid_inds]
# set group flag for the sampler
if not self.test_mode:
# set group flag for the sampler
self._set_group_flag()

# processing pipeline
self.pipeline = Compose(pipeline)

Expand Down Expand Up @@ -147,6 +148,9 @@ def pre_pipeline(self, results):

def _filter_imgs(self, min_size=32):
"""Filter images too small."""
if self.filter_empty_gt:
warnings.warn(
'CustomDataset does not support filtering empty gt images.')
valid_inds = []
for i, img_info in enumerate(self.data_infos):
if min(img_info['width'], img_info['height']) >= min_size:
Expand Down Expand Up @@ -237,12 +241,13 @@ def get_classes(cls, classes=None):
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.

Returns:
tuple[str] or list[str]: Names of categories of the dataset.
"""
if classes is None:
cls.custom_classes = False
return cls.CLASSES

cls.custom_classes = True
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
Expand All @@ -253,9 +258,6 @@ def get_classes(cls, classes=None):

return class_names

def get_subset_by_classes(self):
    """Return ``self.data_infos`` unchanged.

    Base-class no-op hook: subclasses override this to filter the data
    infos down to customized classes; the default applies no filtering.
    """
    return self.data_infos

def format_results(self, results, **kwargs):
    """Placeholder hook for converting raw results to dataset-specific
    output; the base implementation does nothing and returns ``None``."""
Expand Down
36 changes: 20 additions & 16 deletions mmdet/datasets/xml_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,26 @@ def load_annotations(self, ann_file):

return data_infos

def get_subset_by_classes(self):
"""Filter imgs by user-defined categories."""
subset_data_infos = []
for data_info in self.data_infos:
img_id = data_info['id']
xml_path = osp.join(self.img_prefix, 'Annotations',
f'{img_id}.xml')
tree = ET.parse(xml_path)
root = tree.getroot()
for obj in root.findall('object'):
name = obj.find('name').text
if name in self.CLASSES:
subset_data_infos.append(data_info)
break

return subset_data_infos
def _filter_imgs(self, min_size=32):
    """Filter images too small or without annotation.

    Args:
        min_size (int): Images whose ``min(width, height)`` is below this
            value are dropped.

    Returns:
        list[int]: Indices into ``self.data_infos`` of the kept images.
    """
    valid_inds = []
    for idx, img_info in enumerate(self.data_infos):
        if min(img_info['width'], img_info['height']) < min_size:
            continue
        if not self.filter_empty_gt:
            valid_inds.append(idx)
            continue
        # Keep the image only if its VOC-style annotation XML lists at
        # least one object whose class name belongs to this dataset.
        img_id = img_info['id']
        xml_path = osp.join(self.img_prefix, 'Annotations',
                            f'{img_id}.xml')
        root = ET.parse(xml_path).getroot()
        names = (obj.find('name').text for obj in root.findall('object'))
        if any(name in self.CLASSES for name in names):
            valid_inds.append(idx)
    return valid_inds

def get_ann_info(self, idx):
"""Get annotation from XML file by index.
Expand Down
77 changes: 77 additions & 0 deletions tests/data/coco_sample.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
{
"images": [
{
"file_name": "fake1.jpg",
"height": 800,
"width": 800,
"id": 0
},
{
"file_name": "fake2.jpg",
"height": 800,
"width": 800,
"id": 1
},
{
"file_name": "fake3.jpg",
"height": 800,
"width": 800,
"id": 2
}
],
"annotations": [
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 1,
"image_id": 0
},
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 2,
"id": 2,
"image_id": 0
},
{
"bbox": [
0,
0,
20,
20
],
"area": 400.00,
"score": 1.0,
"category_id": 1,
"id": 3,
"image_id": 1
}
],
"categories": [
{
"id": 1,
"name": "bus",
"supercategory": "none"
},
{
"id": 2,
"name": "car",
"supercategory": "none"
}
],
"licenses": [],
"info": null
}
43 changes: 37 additions & 6 deletions tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,18 @@ def test_dataset_evaluation():
tmp_dir.cleanup()


@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock)
@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock)
@pytest.mark.parametrize('dataset',
['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
dataset_class = DATASETS.get(dataset)
dataset_class.load_annotations = MagicMock()
if dataset in ['CocoDataset', 'CityscapesDataset']:
dataset_class.coco = MagicMock()
dataset_class.cat_ids = MagicMock()
Expand All @@ -225,7 +232,6 @@ def test_custom_classes_override_default(dataset):

assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ('bus', 'car')
assert custom_dataset.custom_classes

# Test setting classes as a list
custom_dataset = dataset_class(
Expand All @@ -237,7 +243,6 @@ def test_custom_classes_override_default(dataset):

assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ['bus', 'car']
assert custom_dataset.custom_classes

# Test overriding not a subset
custom_dataset = dataset_class(
Expand All @@ -249,7 +254,6 @@ def test_custom_classes_override_default(dataset):

assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ['foo']
assert custom_dataset.custom_classes

# Test default behavior
custom_dataset = dataset_class(
Expand All @@ -260,7 +264,6 @@ def test_custom_classes_override_default(dataset):
img_prefix='VOC2007' if dataset == 'VOCDataset' else '')

assert custom_dataset.CLASSES == original_classes
assert not custom_dataset.custom_classes

# Test sending file path
import tempfile
Expand All @@ -277,7 +280,6 @@ def test_custom_classes_override_default(dataset):

assert custom_dataset.CLASSES != original_classes
assert custom_dataset.CLASSES == ['bus', 'car']
assert custom_dataset.custom_classes


def test_dataset_wrapper():
Expand Down Expand Up @@ -460,3 +462,32 @@ def val_step(self, x, optimizer, **kwargs):
runner = EpochBasedRunner(
model=model, work_dir=tmp_dir, logger=logging.getLogger())
return runner


@pytest.mark.parametrize('classes, expected_length', [(['bus'], 2),
                                                      (['car'], 1),
                                                      (['bus', 'car'], 2)])
def test_allow_empty_images(classes, expected_length):
    """`filter_empty_gt=True` drops images lacking GT of the chosen classes,
    while `filter_empty_gt=False` keeps every image in the annotation file."""
    dataset_class = DATASETS.get('CocoDataset')
    common = dict(
        ann_file='tests/data/coco_sample.json',
        img_prefix='tests/data',
        pipeline=[],
        classes=classes)

    # With filtering: only images annotated with the requested classes remain.
    filtered_dataset = dataset_class(filter_empty_gt=True, **common)
    # Without filtering: all three sample images are kept.
    full_dataset = dataset_class(filter_empty_gt=False, **common)

    assert len(filtered_dataset) == expected_length
    assert len(filtered_dataset.img_ids) == expected_length
    assert len(full_dataset) == 3
    assert len(full_dataset.img_ids) == 3
    assert filtered_dataset.CLASSES == classes
    assert full_dataset.CLASSES == classes