Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Relax the requirement of PALETTE in dataset wrappers #7085

Merged
merged 3 commits into from
Jan 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmdet/datasets/dataset_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ConcatDataset(_ConcatDataset):
def __init__(self, datasets, separate_eval=True):
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.PALETTE = getattr(datasets[0], 'PALETTE', None)
self.separate_eval = separate_eval
if not separate_eval:
if any([isinstance(ds, CocoDataset) for ds in datasets]):
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self.CLASSES = dataset.CLASSES
self.PALETTE = dataset.PALETTE
self.PALETTE = getattr(dataset, 'PALETTE', None)
if hasattr(self.dataset, 'flag'):
self.flag = np.tile(self.dataset.flag, times)

Expand Down Expand Up @@ -249,7 +249,7 @@ def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
self.oversample_thr = oversample_thr
self.filter_empty_gt = filter_empty_gt
self.CLASSES = dataset.CLASSES
self.PALETTE = dataset.PALETTE
self.PALETTE = getattr(dataset, 'PALETTE', None)

repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
repeat_indices = []
Expand Down Expand Up @@ -384,7 +384,7 @@ def __init__(self,

self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.PALETTE = dataset.PALETTE
self.PALETTE = getattr(dataset, 'PALETTE', None)
if hasattr(self.dataset, 'flag'):
self.flag = dataset.flag
self.num_samples = len(dataset)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_data/test_datasets/test_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def test_dataset_wrapper():
assert concat_dataset.get_ann_info(25) == ann_info_list_b[15]
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

# Test if ConcatDataset allows dataset classes without the PALETTE
# attribute
palette_backup = CustomDataset.PALETTE
delattr(CustomDataset, 'PALETTE')
concat_dataset = ConcatDataset([dataset_a, dataset_b])
assert concat_dataset.PALETTE is None
CustomDataset.PALETTE = palette_backup

repeat_dataset = RepeatDataset(dataset_a, 10)
assert repeat_dataset[5] == 5
assert repeat_dataset[15] == 5
Expand All @@ -88,6 +96,13 @@ def test_dataset_wrapper():
assert repeat_dataset.get_ann_info(27) == ann_info_list_a[7]
assert len(repeat_dataset) == 10 * len(dataset_a)

# Test if RepeatDataset allows dataset classes without the PALETTE
# attribute
delattr(CustomDataset, 'PALETTE')
repeat_dataset = RepeatDataset(dataset_a, 10)
assert repeat_dataset.PALETTE is None
CustomDataset.PALETTE = palette_backup

category_freq = defaultdict(int)
for cat_ids in cat_ids_list_a:
cat_ids = set(cat_ids)
Expand Down Expand Up @@ -117,6 +132,12 @@ def test_dataset_wrapper():
repeat_factors_cumsum, idx)
assert repeat_factor_dataset.get_ann_info(idx) == ann_info_list_a[
bisect.bisect_right(repeat_factors_cumsum, idx)]
# Test if ClassBalancedDataset allows dataset classes without the PALETTE
# attribute
delattr(CustomDataset, 'PALETTE')
repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
assert repeat_factor_dataset.PALETTE is None
CustomDataset.PALETTE = palette_backup

img_scale = (60, 60)
pipeline = [
Expand Down Expand Up @@ -179,3 +200,10 @@ def test_dataset_wrapper():
for idx in range(len_a):
results_ = multi_image_mix_dataset[idx]
assert results_['img'].shape == (img_scale[0], img_scale[1], 3)

# Test if MultiImageMixDataset allows dataset classes without the PALETTE
# attribute
delattr(CustomDataset, 'PALETTE')
multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
assert multi_image_mix_dataset.PALETTE is None
CustomDataset.PALETTE = palette_backup