From 6c3e63e48b8f94a64ac5b1d88cab6fca005ae269 Mon Sep 17 00:00:00 2001
From: Kyungmin Lee <30465912+lkm2835@users.noreply.github.com>
Date: Tue, 11 Jan 2022 17:18:24 +0900
Subject: [PATCH] [Feature] Add MultiImageMixDataset (#1105)

* Fix typo in usage example

* original MultiImageMixDataset code in mmdet

* Add MultiImageMixDataset unittests in test_dataset_wrapper

* fix lint error

* fix value name ann_file to ann_dir

* modify retrieve_data_cfg (#1)

* remove dynamic_scale & add palette

* modify retrieve_data_cfg method

* modify retrieve_data_cfg func

* fix error

* improve the unittests coverage

* fix unittests error

* Dataset (#2)

* add cfg-options

* Add unittest in test_build_dataset

* add blank line

* add blank line

* add a blank line

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: Younghoon-Lee <72462227+Younghoon-Lee@users.noreply.github.com>
Co-authored-by: MeowZheng
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
---
 mmseg/datasets/__init__.py              |  5 +-
 mmseg/datasets/builder.py               |  8 ++-
 mmseg/datasets/dataset_wrappers.py      | 91 ++++++++++++++++++++++++-
 tests/test_data/test_dataset.py         | 63 ++++++++++++++++-
 tests/test_data/test_dataset_builder.py |  9 ++-
 tools/browse_dataset.py                 | 30 +++++---
 6 files changed, 190 insertions(+), 16 deletions(-)

diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py
index c115ab796f..a9f80a9204 100644
--- a/mmseg/datasets/__init__.py
+++ b/mmseg/datasets/__init__.py
@@ -6,7 +6,8 @@
 from .coco_stuff import COCOStuffDataset
 from .custom import CustomDataset
 from .dark_zurich import DarkZurichDataset
-from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
+                               RepeatDataset)
 from .drive import DRIVEDataset
 from .hrf import HRFDataset
 from .loveda import LoveDADataset
@@ -21,5 +22,5 @@
     'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
     'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
     'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
-    'COCOStuffDataset', 'LoveDADataset'
+    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset'
 ]
diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py
index 7ab645958d..4a12ec8adb 100644
--- a/mmseg/datasets/builder.py
+++ b/mmseg/datasets/builder.py
@@ -64,12 +64,18 @@ def _concat_dataset(cfg, default_args=None):
 
 def build_dataset(cfg, default_args=None):
     """Build datasets."""
-    from .dataset_wrappers import ConcatDataset, RepeatDataset
+    from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
+                                   RepeatDataset)
     if isinstance(cfg, (list, tuple)):
         dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
     elif cfg['type'] == 'RepeatDataset':
         dataset = RepeatDataset(
             build_dataset(cfg['dataset'], default_args), cfg['times'])
+    elif cfg['type'] == 'MultiImageMixDataset':
+        cp_cfg = copy.deepcopy(cfg)
+        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
+        cp_cfg.pop('type')
+        dataset = MultiImageMixDataset(**cp_cfg)
     elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
             cfg.get('split', None), (list, tuple)):
         dataset = _concat_dataset(cfg, default_args)
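For context, the new `build_dataset` branch above expects a config shaped like the fragment below; this sketch is not part of the patch, and the dataset type, paths, and transforms are illustrative only:

# Hypothetical cfg.data.train fragment. build_dataset() first builds the
# wrapped dataset from the inner 'dataset' key, then pops 'type' and passes
# the remaining keys (dataset, pipeline) to MultiImageMixDataset.
train = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        type='CityscapesDataset',
        data_root='data/cityscapes',  # assumed local data root
        img_dir='leftImg8bit/train',
        ann_dir='gtFine/train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
        ]),
    pipeline=[
        dict(type='RandomFlip', prob=0.5),
        dict(type='Resize', img_scale=(1024, 512), keep_ratio=True),
    ])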
diff --git a/mmseg/datasets/dataset_wrappers.py b/mmseg/datasets/dataset_wrappers.py
index 0349332eeb..1fb089f9f2 100644
--- a/mmseg/datasets/dataset_wrappers.py
+++ b/mmseg/datasets/dataset_wrappers.py
@@ -1,13 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import bisect
+import collections
+import copy
 from itertools import chain
 
 import mmcv
 import numpy as np
-from mmcv.utils import print_log
+from mmcv.utils import build_from_cfg, print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
 
-from .builder import DATASETS
+from .builder import DATASETS, PIPELINES
 from .cityscapes import CityscapesDataset
 
 
@@ -188,3 +190,88 @@ def __getitem__(self, idx):
     def __len__(self):
         """The length is multiplied by ``times``"""
         return self.times * self._ori_len
+
+
+@DATASETS.register_module()
+class MultiImageMixDataset:
+    """A dataset wrapper for multi-image mixed data augmentation.
+
+    Suitable for training with augmentations that mix several images, such
+    as Mosaic and MixUp. A mixing transform in the augmentation pipeline
+    must provide a `get_indexes` method to obtain the indexes of the extra
+    images, and `skip_type_keys` can be set to skip selected transforms
+    while the pipeline is running.
+
+
+    Args:
+        dataset (:obj:`CustomDataset`): The dataset to be mixed.
+        pipeline (Sequence[dict]): Sequence of transform objects or
+            config dicts to be composed.
+        skip_type_keys (list[str], optional): Sequence of type strings of
+            transforms to be skipped in the pipeline. Defaults to None.
+    """
+
+    def __init__(self, dataset, pipeline, skip_type_keys=None):
+        assert isinstance(pipeline, collections.abc.Sequence)
+        if skip_type_keys is not None:
+            assert all([
+                isinstance(skip_type_key, str)
+                for skip_type_key in skip_type_keys
+            ])
+        self._skip_type_keys = skip_type_keys
+
+        self.pipeline = []
+        self.pipeline_types = []
+        for transform in pipeline:
+            if isinstance(transform, dict):
+                self.pipeline_types.append(transform['type'])
+                transform = build_from_cfg(transform, PIPELINES)
+                self.pipeline.append(transform)
+            else:
+                raise TypeError('each transform in pipeline must be a dict')
+
+        self.dataset = dataset
+        self.CLASSES = dataset.CLASSES
+        self.PALETTE = dataset.PALETTE
+        self.num_samples = len(dataset)
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        results = copy.deepcopy(self.dataset[idx])
+        for (transform, transform_type) in zip(self.pipeline,
+                                               self.pipeline_types):
+            if self._skip_type_keys is not None and \
+                    transform_type in self._skip_type_keys:
+                continue
+
+            if hasattr(transform, 'get_indexes'):
+                indexes = transform.get_indexes(self.dataset)
+                if not isinstance(indexes, collections.abc.Sequence):
+                    indexes = [indexes]
+                mix_results = [
+                    copy.deepcopy(self.dataset[index]) for index in indexes
+                ]
+                results['mix_results'] = mix_results
+
+            results = transform(results)
+
+            if 'mix_results' in results:
+                results.pop('mix_results')
+
+        return results
+
+    def update_skip_type_keys(self, skip_type_keys):
+        """Update skip_type_keys.
+
+        It is called by an external hook.
+
+        Args:
+            skip_type_keys (list[str], optional): Sequence of type strings
+                of transforms to be skipped in the pipeline.
+        """
+        assert all([
+            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
+        ])
+        self._skip_type_keys = skip_type_keys
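For reviewers, a minimal usage sketch of the wrapper above; it is not part of the patch and assumes `dataset` is an already-constructed `CustomDataset` with loadable images and annotations:

from mmseg.datasets import MultiImageMixDataset

# Wrap an existing dataset; each config dict in `pipeline` is built via
# build_from_cfg() from the PIPELINES registry.
mix_dataset = MultiImageMixDataset(
    dataset=dataset,
    pipeline=[
        dict(type='RandomFlip', prob=0.5),
        dict(type='Resize', img_scale=(512, 512), keep_ratio=False),
    ])

results = mix_dataset[0]  # deep-copies the sample, then runs the pipeline

# Transforms can be disabled by type name at runtime:
mix_dataset.update_skip_type_keys(('RandomFlip', ))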
+ """ + assert all([ + isinstance(skip_type_key, str) for skip_type_key in skip_type_keys + ]) + self._skip_type_keys = skip_type_keys diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index 455e82dbdf..58c7275ab5 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -14,7 +14,8 @@ from mmseg.core.evaluation import get_classes, get_palette from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, ConcatDataset, CustomDataset, LoveDADataset, - PascalVOCDataset, RepeatDataset, build_dataset) + MultiImageMixDataset, PascalVOCDataset, + RepeatDataset, build_dataset) def test_classes(): @@ -95,6 +96,66 @@ def test_dataset_wrapper(): assert repeat_dataset[27] == 7 assert len(repeat_dataset) == 10 * len(dataset_a) + img_scale = (60, 60) + pipeline = [ + # dict(type='Mosaic', img_scale=img_scale, pad_val=255), + # need to merge mosaic + dict(type='RandomFlip', prob=0.5), + dict(type='Resize', img_scale=img_scale, keep_ratio=False), + ] + + CustomDataset.load_annotations = MagicMock() + results = [] + for _ in range(2): + height = np.random.randint(10, 30) + weight = np.random.randint(10, 30) + img = np.ones((height, weight, 3)) + gt_semantic_seg = np.random.randint(5, size=(height, weight)) + results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img)) + + classes = ['0', '1', '2', '3', '4'] + palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)] + CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx]) + dataset_a = CustomDataset( + img_dir=MagicMock(), + pipeline=[], + test_mode=True, + classes=classes, + palette=palette) + len_a = 2 + cat_ids_list_a = [ + np.random.randint(0, 80, num).tolist() + for num in np.random.randint(1, 20, len_a) + ] + dataset_a.data_infos = MagicMock() + dataset_a.data_infos.__len__.return_value = len_a + dataset_a.get_cat_ids = MagicMock( + side_effect=lambda idx: cat_ids_list_a[idx]) + + multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline) + assert len(multi_image_mix_dataset) == len(dataset_a) + + for idx in range(len_a): + results_ = multi_image_mix_dataset[idx] + + # test skip_type_keys + multi_image_mix_dataset = MultiImageMixDataset( + dataset_a, pipeline, skip_type_keys=('RandomFlip')) + for idx in range(len_a): + results_ = multi_image_mix_dataset[idx] + assert results_['img'].shape == (img_scale[0], img_scale[1], 3) + + skip_type_keys = ('RandomFlip', 'Resize') + multi_image_mix_dataset.update_skip_type_keys(skip_type_keys) + for idx in range(len_a): + results_ = multi_image_mix_dataset[idx] + assert results_['img'].shape[:2] != img_scale + + # test pipeline + with pytest.raises(TypeError): + pipeline = [['Resize']] + multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline) + def test_custom_dataset(): img_norm_cfg = dict( diff --git a/tests/test_data/test_dataset_builder.py b/tests/test_data/test_dataset_builder.py index edb82efb93..30910b09bd 100644 --- a/tests/test_data/test_dataset_builder.py +++ b/tests/test_data/test_dataset_builder.py @@ -6,8 +6,8 @@ from torch.utils.data import (DistributedSampler, RandomSampler, SequentialSampler) -from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader, - build_dataset) +from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset, + build_dataloader, build_dataset) @DATASETS.register_module() @@ -48,6 +48,11 @@ def test_build_dataset(): assert isinstance(dataset, ConcatDataset) assert len(dataset) == 10 + cfg = dict(type='MultiImageMixDataset', 
diff --git a/tests/test_data/test_dataset_builder.py b/tests/test_data/test_dataset_builder.py
index edb82efb93..30910b09bd 100644
--- a/tests/test_data/test_dataset_builder.py
+++ b/tests/test_data/test_dataset_builder.py
@@ -6,8 +6,8 @@
 from torch.utils.data import (DistributedSampler, RandomSampler,
                               SequentialSampler)
 
-from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader,
-                            build_dataset)
+from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset,
+                            build_dataloader, build_dataset)
 
 
 @DATASETS.register_module()
@@ -48,6 +48,11 @@ def test_build_dataset():
     assert isinstance(dataset, ConcatDataset)
     assert len(dataset) == 10
 
+    cfg = dict(type='MultiImageMixDataset',
+               dataset=cfg, pipeline=[])
+    dataset = build_dataset(cfg)
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
     # with ann_dir, split
     cfg = dict(
         type='CustomDataset',
diff --git a/tools/browse_dataset.py b/tools/browse_dataset.py
index 2ec414280a..d46487bf22 100644
--- a/tools/browse_dataset.py
+++ b/tools/browse_dataset.py
@@ -5,7 +5,7 @@
 import mmcv
 import numpy as np
-from mmcv import Config
+from mmcv import Config, DictAction
 
 from mmseg.datasets.builder import build_dataset
 
@@ -42,6 +42,16 @@ def parse_args():
         type=float,
         default=0.5,
         help='the opacity of semantic map')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config; the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be like key="[a,b]" or '
+        'key=a,b. It also allows nested list/tuple values, e.g. '
+        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
+        'and that no white space is allowed.')
     args = parser.parse_args()
     return args
@@ -122,28 +132,32 @@ def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
     ]
 
 
-def retrieve_data_cfg(config_path, skip_type, show_origin=False):
+def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
     cfg = Config.fromfile(config_path)
+    if cfg_options is not None:
+        cfg.merge_from_dict(cfg_options)
     train_data_cfg = cfg.data.train
     if isinstance(train_data_cfg, list):
         for _data_cfg in train_data_cfg:
+            while 'dataset' in _data_cfg and _data_cfg[
+                    'type'] != 'MultiImageMixDataset':
+                _data_cfg = _data_cfg['dataset']
             if 'pipeline' in _data_cfg:
                 _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
-            elif 'dataset' in _data_cfg:
-                _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
-                                   show_origin)
             else:
                 raise ValueError
-    elif 'dataset' in train_data_cfg:
-        _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
     else:
+        while 'dataset' in train_data_cfg and train_data_cfg[
+                'type'] != 'MultiImageMixDataset':
+            train_data_cfg = train_data_cfg['dataset']
         _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
 
     return cfg
 
 
 def main():
     args = parse_args()
-    cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
+    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
+                            args.show_origin)
     dataset = build_dataset(cfg.data.train)
     progress_bar = mmcv.ProgressBar(len(dataset))
     for item in dataset:
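Finally, a sketch of the updated retrieve_data_cfg in use; the config path and option keys below are placeholders, not values from this patch:

# retrieve_data_cfg() now merges --cfg-options into the config, then descends
# through nested 'dataset' keys (e.g. RepeatDataset) until it reaches either a
# plain dataset or a MultiImageMixDataset, whose own pipeline is then browsed.
cfg = retrieve_data_cfg(
    'configs/_base_/datasets/cityscapes.py',  # hypothetical config path
    skip_type=['DefaultFormatBundle', 'Normalize', 'Collect'],
    cfg_options={'data.workers_per_gpu': 0})  # dotted keys, as DictAction emits
print(cfg.data.train)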