From 107054e87623b1e577894c921447c90733ad2c06 Mon Sep 17 00:00:00 2001
From: gaotongxiao
Date: Thu, 27 Jan 2022 19:29:35 +0800
Subject: [PATCH 1/3] [Fix] Remove the requirement of PALETTE in ConcatDataset
 for downstream compatibility

---
 mmdet/datasets/dataset_wrappers.py                    | 3 ++-
 tests/test_data/test_datasets/test_dataset_wrapper.py | 8 ++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py
index b41ec369943..78c8a922942 100644
--- a/mmdet/datasets/dataset_wrappers.py
+++ b/mmdet/datasets/dataset_wrappers.py
@@ -30,7 +30,8 @@ class ConcatDataset(_ConcatDataset):
     def __init__(self, datasets, separate_eval=True):
         super(ConcatDataset, self).__init__(datasets)
         self.CLASSES = datasets[0].CLASSES
-        self.PALETTE = datasets[0].PALETTE
+        self.PALETTE = datasets[0].PALETTE if hasattr(datasets[0],
+                                                      'PALETTE') else None
         self.separate_eval = separate_eval
         if not separate_eval:
             if any([isinstance(ds, CocoDataset) for ds in datasets]):
diff --git a/tests/test_data/test_datasets/test_dataset_wrapper.py b/tests/test_data/test_datasets/test_dataset_wrapper.py
index 0e43c31ebc0..a1103ff083a 100644
--- a/tests/test_data/test_datasets/test_dataset_wrapper.py
+++ b/tests/test_data/test_datasets/test_dataset_wrapper.py
@@ -76,6 +76,14 @@ def test_dataset_wrapper():
     assert concat_dataset.get_ann_info(25) == ann_info_list_b[15]
     assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
 
+    # Test if ConcatDataset allows dataset classes without the PALETTE
+    # attribute
+    palette_backup = CustomDataset.PALETTE
+    delattr(CustomDataset, 'PALETTE')
+    concat_dataset = ConcatDataset([dataset_a, dataset_b])
+    assert concat_dataset.PALETTE is None
+    CustomDataset.PALETTE = palette_backup
+
     repeat_dataset = RepeatDataset(dataset_a, 10)
     assert repeat_dataset[5] == 5
     assert repeat_dataset[15] == 5

From 9538425d36c4c9e6b50a77cf80e6e1113a5f7d2e Mon Sep 17 00:00:00 2001
From: gaotongxiao
Date: Thu, 27 Jan 2022 19:42:39 +0800
Subject: [PATCH 2/3] Extend to all dataset wrappers

---
 mmdet/datasets/dataset_wrappers.py            |  6 +++---
 .../test_datasets/test_dataset_wrapper.py     | 20 +++++++++++++++++++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py
index 78c8a922942..251327473ea 100644
--- a/mmdet/datasets/dataset_wrappers.py
+++ b/mmdet/datasets/dataset_wrappers.py
@@ -169,7 +169,7 @@ def __init__(self, dataset, times):
         self.dataset = dataset
         self.times = times
         self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE
+        self.PALETTE = dataset.PALETTE if hasattr(dataset, 'PALETTE') else None
         if hasattr(self.dataset, 'flag'):
             self.flag = np.tile(self.dataset.flag, times)
 
@@ -250,7 +250,7 @@ def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
         self.oversample_thr = oversample_thr
         self.filter_empty_gt = filter_empty_gt
         self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE
+        self.PALETTE = dataset.PALETTE if hasattr(dataset, 'PALETTE') else None
 
         repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
         repeat_indices = []
@@ -385,7 +385,7 @@ def __init__(self,
 
         self.dataset = dataset
         self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE
+        self.PALETTE = dataset.PALETTE if hasattr(dataset, 'PALETTE') else None
         if hasattr(self.dataset, 'flag'):
             self.flag = dataset.flag
         self.num_samples = len(dataset)
diff --git a/tests/test_data/test_datasets/test_dataset_wrapper.py b/tests/test_data/test_datasets/test_dataset_wrapper.py
index a1103ff083a..ad29678590e 100644
--- a/tests/test_data/test_datasets/test_dataset_wrapper.py
+++ b/tests/test_data/test_datasets/test_dataset_wrapper.py
@@ -96,6 +96,13 @@ def test_dataset_wrapper():
     assert repeat_dataset.get_ann_info(27) == ann_info_list_a[7]
     assert len(repeat_dataset) == 10 * len(dataset_a)
 
+    # Test if RepeatDataset allows dataset classes without the PALETTE
+    # attribute
+    delattr(CustomDataset, 'PALETTE')
+    repeat_dataset = RepeatDataset(dataset_a, 10)
+    assert repeat_dataset.PALETTE is None
+    CustomDataset.PALETTE = palette_backup
+
     category_freq = defaultdict(int)
     for cat_ids in cat_ids_list_a:
         cat_ids = set(cat_ids)
@@ -125,6 +132,12 @@ def test_dataset_wrapper():
             repeat_factors_cumsum, idx)
         assert repeat_factor_dataset.get_ann_info(idx) == ann_info_list_a[
             bisect.bisect_right(repeat_factors_cumsum, idx)]
+    # Test if ClassBalancedDataset allows dataset classes without the PALETTE
+    # attribute
+    delattr(CustomDataset, 'PALETTE')
+    repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
+    assert repeat_factor_dataset.PALETTE is None
+    CustomDataset.PALETTE = palette_backup
 
     img_scale = (60, 60)
     pipeline = [
@@ -187,3 +200,10 @@ def test_dataset_wrapper():
     for idx in range(len_a):
         results_ = multi_image_mix_dataset[idx]
         assert results_['img'].shape == (img_scale[0], img_scale[1], 3)
+
+    # Test if MultiImageMixDataset allows dataset classes without the PALETTE
+    # attribute
+    delattr(CustomDataset, 'PALETTE')
+    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
+    assert multi_image_mix_dataset.PALETTE is None
+    CustomDataset.PALETTE = palette_backup

From b00e9fec99de504f268b21500ba47697ed4e264d Mon Sep 17 00:00:00 2001
From: gaotongxiao
Date: Fri, 28 Jan 2022 09:39:04 +0800
Subject: [PATCH 3/3] use getattr

---
 mmdet/datasets/dataset_wrappers.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py
index 251327473ea..1274f58e667 100644
--- a/mmdet/datasets/dataset_wrappers.py
+++ b/mmdet/datasets/dataset_wrappers.py
@@ -30,8 +30,7 @@ class ConcatDataset(_ConcatDataset):
     def __init__(self, datasets, separate_eval=True):
         super(ConcatDataset, self).__init__(datasets)
         self.CLASSES = datasets[0].CLASSES
-        self.PALETTE = datasets[0].PALETTE if hasattr(datasets[0],
-                                                      'PALETTE') else None
+        self.PALETTE = getattr(datasets[0], 'PALETTE', None)
         self.separate_eval = separate_eval
         if not separate_eval:
             if any([isinstance(ds, CocoDataset) for ds in datasets]):
@@ -169,7 +168,7 @@ def __init__(self, dataset, times):
         self.dataset = dataset
         self.times = times
         self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE if hasattr(dataset, 'PALETTE') else None
+        self.PALETTE = getattr(dataset, 'PALETTE', None)
         if hasattr(self.dataset, 'flag'):
             self.flag = np.tile(self.dataset.flag, times)
 
@@ -250,7 +249,7 @@ def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
         self.oversample_thr = oversample_thr
         self.filter_empty_gt = filter_empty_gt
         self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE if hasattr(dataset, 'PALETTE') else None
+        self.PALETTE = getattr(dataset, 'PALETTE', None)
 
         repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
         repeat_indices = []
@@ -385,7 +384,7 @@ def __init__(self,
 
         self.dataset = dataset
         self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE if hasattr(dataset, 'PALETTE') else None
+        self.PALETTE = getattr(dataset, 'PALETTE', None)
         if hasattr(self.dataset, 'flag'):
             self.flag = dataset.flag
         self.num_samples = len(dataset)
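
The pattern the series converges on is the getattr fallback: getattr(dataset, 'PALETTE', None) returns the palette when the wrapped dataset defines one and falls back to None otherwise, instead of raising AttributeError the way direct attribute access does. Below is a minimal, self-contained sketch of that behavior; ToyDataset and ToyWrapper are hypothetical stand-ins for illustration only and are not part of the patch or of mmdetection.

# Illustrative sketch only: ToyDataset and ToyWrapper are hypothetical
# stand-ins, not mmdetection classes.
class ToyDataset:
    CLASSES = ('cat', 'dog')
    # Deliberately defines no PALETTE attribute, like a downstream dataset
    # class written before PALETTE was introduced.


class ToyWrapper:
    def __init__(self, dataset):
        self.CLASSES = dataset.CLASSES
        # getattr falls back to None when the wrapped dataset has no PALETTE,
        # instead of raising AttributeError.
        self.PALETTE = getattr(dataset, 'PALETTE', None)


wrapper = ToyWrapper(ToyDataset())
assert wrapper.PALETTE is None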