[Feature] Support evaluation on concatenated datasets and add a tool to browse the dataset #833

Merged: 13 commits, Sep 9, 2021
6 changes: 5 additions & 1 deletion mmseg/datasets/builder.py
@@ -30,6 +30,7 @@ def _concat_dataset(cfg, default_args=None):
img_dir = cfg['img_dir']
ann_dir = cfg.get('ann_dir', None)
split = cfg.get('split', None)
separate_eval = cfg.get('separate_eval', True)
Collaborator: We may pop separate_eval here?

Contributor Author (@FreyWang, Sep 3, 2021): let me check 😢

num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
if ann_dir is not None:
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@@ -49,6 +50,9 @@ def _concat_dataset(cfg, default_args=None):
datasets = []
for i in range(num_dset):
data_cfg = copy.deepcopy(cfg)
# pop 'separate_eval' since it is not a valid key for common datasets.
if 'separate_eval' in data_cfg:
data_cfg.pop('separate_eval')
Contributor Author: separate_eval has been popped here for every subset. @xvjiarui

Collaborator: Can we use separate_eval = cfg.pop('separate_eval', True) in L33?

Contributor Author: Sure, I think it will be better.

Contributor Author: updated 28e3bd2
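For reference, the agreed simplification reads and removes the flag in one step, making the explicit per-dataset pop shown above unnecessary. A minimal sketch of the relevant lines only:

def _concat_dataset(cfg, default_args=None):
    img_dir = cfg['img_dir']
    ann_dir = cfg.get('ann_dir', None)
    split = cfg.get('split', None)
    # pop() both reads the flag and removes it from cfg, so the deep-copied
    # per-dataset configs no longer carry an unexpected 'separate_eval' key.
    separate_eval = cfg.pop('separate_eval', True)
    ...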

if isinstance(img_dir, (list, tuple)):
data_cfg['img_dir'] = img_dir[i]
if isinstance(ann_dir, (list, tuple)):
@@ -57,7 +61,7 @@ def _concat_dataset(cfg, default_args=None):
data_cfg['split'] = split[i]
datasets.append(build_dataset(data_cfg, default_args))

return ConcatDataset(datasets)
return ConcatDataset(datasets, separate_eval)


def build_dataset(cfg, default_args=None):
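With this change, a user can concatenate several image/annotation directories in one config entry and control how they are evaluated. A hedged sketch of such a validation config (the dataset type, paths and pipeline are illustrative, not taken from this PR):

test_pipeline = []  # placeholder; a real config defines the full test pipeline
data = dict(
    val=dict(
        type='CustomDataset',                     # illustrative dataset type
        data_root='data/my_dataset',              # illustrative paths
        img_dir=['images/set1', 'images/set2'],   # lists trigger _concat_dataset
        ann_dir=['annotations/set1', 'annotations/set2'],
        pipeline=test_pipeline,
        separate_eval=False))                     # evaluate both sets as one dataset
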
18 changes: 9 additions & 9 deletions mmseg/datasets/custom.py
@@ -99,6 +99,9 @@ def __init__(self,
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)
if test_mode:
assert self.CLASSES is not None, \
'`cls.CLASSES` or `classes` should be specified when testing'
Collaborator: It seems that this modification leads to a failed GitHub CI (checked).
Could you please add some unit tests and fix the failing ones?

Contributor Author: OK, I will find time to fix the issue above 😂
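A minimal unit test for the new assertion could look like the following (a sketch, assuming pytest and that CustomDataset.CLASSES is unset by default):

import pytest

from mmseg.datasets import CustomDataset


def test_classes_required_in_test_mode():
    # Building a test-mode dataset without `classes` (and without CLASSES
    # set on the class) should trip the new assertion.
    with pytest.raises(AssertionError):
        CustomDataset(pipeline=[], img_dir='./', test_mode=True)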


# join paths if data_root is specified
if self.data_root is not None:
@@ -339,7 +342,7 @@ def get_palette_for_custom_classes(self, class_names, palette=None):

return palette

def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
def evaluate(self, results, metric='mIoU', logger=None, gt_seg_maps=None, **kwargs):
"""Evaluate the dataset.

Args:
Expand All @@ -350,6 +353,8 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
used in ConcatDataset. Default: None.

Returns:
dict[str, float]: Default metrics.
@@ -364,14 +369,9 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
# test a list of files
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
results, str):
gt_seg_maps = self.get_gt_seg_maps()
if self.CLASSES is None:
num_classes = len(
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
else:
num_classes = len(self.CLASSES)
# reset generator
gt_seg_maps = self.get_gt_seg_maps()
if gt_seg_maps is None:
gt_seg_maps = self.get_gt_seg_maps()
num_classes = len(self.CLASSES)
ret_metrics = eval_metrics(
results,
gt_seg_maps,
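For illustration, the new gt_seg_maps argument lets a caller hand in its own ground-truth generator, which is what the ConcatDataset wrapper in dataset_wrappers.py below does. A hedged sketch of that pattern as a standalone helper (the helper name is hypothetical):

from itertools import chain


def evaluate_concatenated(datasets, results, **kwargs):
    """Evaluate results spanning several sub-datasets against their chained
    ground truths, mirroring ConcatDataset.evaluate with separate_eval=False."""
    gt_seg_maps = chain(*[ds.get_gt_seg_maps() for ds in datasets])
    return datasets[0].evaluate(results, gt_seg_maps=gt_seg_maps, **kwargs)
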
104 changes: 103 additions & 1 deletion mmseg/datasets/dataset_wrappers.py
@@ -1,7 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
from itertools import chain

from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
from .cityscapes import CityscapesDataset


@DATASETS.register_module()
@@ -15,10 +20,107 @@ class ConcatDataset(_ConcatDataset):
datasets (list[:obj:`Dataset`]): A list of datasets.
"""

def __init__(self, datasets):
def __init__(self, datasets, separate_eval=True):
Collaborator: Docstring for separate_eval.

Contributor Author: Hi, I have fixed the issue and added a unit test for it. Do I need to submit a new PR or not?
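One possible wording for the requested docstring addition (a sketch; the final text is up to the author):

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results of the
            concatenated sub-datasets separately or as a single merged
            dataset. Default: True.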

super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
self.PALETTE = datasets[0].PALETTE
self.separate_eval = separate_eval
if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
raise NotImplementedError(
'Evaluating CityscapesDataset within ConcatDataset is not supported!')

def evaluate(self, results, logger=None, **kwargs):
"""Evaluate the results.

Args:
results (list[list | tuple]): Testing results of the dataset.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.

Returns:
dict[str, float]: Evaluation results of the whole dataset, or of each
separate dataset if `self.separate_eval=True`.
"""
assert len(results) == self.cumulative_sizes[-1], \
('Dataset and results have different sizes: '
f'{self.cumulative_sizes[-1]} v.s. {len(results)}')

# Check whether all the datasets support evaluation
for dataset in self.datasets:
assert hasattr(dataset, 'evaluate'), \
f'{type(dataset)} does not implement evaluate function'

if self.separate_eval:
dataset_idx = -1
total_eval_results = dict()
for size, dataset in zip(self.cumulative_sizes, self.datasets):
start_idx = 0 if dataset_idx == -1 else \
self.cumulative_sizes[dataset_idx]
end_idx = self.cumulative_sizes[dataset_idx + 1]

results_per_dataset = results[start_idx:end_idx]
print_log(
f'\nEvaluating {dataset.img_dir} with '
f'{len(results_per_dataset)} images now',
logger=logger)

eval_results_per_dataset = dataset.evaluate(
results_per_dataset, logger=logger, **kwargs)
dataset_idx += 1
for k, v in eval_results_per_dataset.items():
total_eval_results.update({f'{dataset_idx}_{k}': v})

return total_eval_results

if len(set([type(ds) for ds in self.datasets])) != 1:
raise NotImplementedError(
'All the datasets should have the same types when self.separate_eval=False')
else:
gt_seg_maps = chain(*[dataset.get_gt_seg_maps()
Collaborator: If the results are pre_eval results, we do not need gt_seg_maps.

Contributor Author: Yes, but if pre_eval=False when training, it may cause an error. The pre_eval flag comes from the evaluation config, e.g.:

evaluation = dict(interval=2000, metric='mIoU', pre_eval=True)

Collaborator: I mean if the results are pre_eval results, we do not need gt_seg_maps and set gt_seg_maps=None. We only need to collect gt_seg_maps when the results are eval results.

Contributor Author: got it 😯

for dataset in self.datasets])
Collaborator:

Suggested change: replace

gt_seg_maps = chain(*[dataset.get_gt_seg_maps()
                      for dataset in self.datasets])

with

if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
        results, str):
    gt_seg_maps = chain(*[dataset.get_gt_seg_maps()
                          for dataset in self.datasets])
else:
    gt_seg_maps = None

Does this work? Please have a check.

Contributor Author: OK, I will.

eval_results = self.datasets[0].evaluate(
results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
return eval_results

def get_dataset_idx_and_sample_idx(self, indice):
"""Return dataset and sample index when given an indice of ConcatDataset

Args:
indice (int): indice of sample in ConcatDataset

Returns:
int: the index of the sub-dataset the sample belongs to
int: the index of the sample within that sub-dataset
"""
if indice < 0:
if -indice > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
indice = len(self) + indice
dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
if dataset_idx == 0:
sample_idx = indice
else:
sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx

def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""format result for every sample of ConcatDataset """
ret_res = []
for i, indice in enumerate(indices):
Collaborator: How about indices=None? Maybe we need to handle this case.

Contributor Author: You are right, I will fix it.

dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(indice)
res = self.datasets[dataset_idx].format_results(
[results[i]], imgfile_prefix, indices=[sample_idx], **kwargs)
ret_res.append(res)
return sum(ret_res, [])

def pre_eval(self, preds, indices):
"""do pre eval for every sample of ConcatDataset"""
ret_res = []
for i, indice in enumerate(indices):
dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(indice)
res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
ret_res.append(res)
return sum(ret_res, [])


@DATASETS.register_module()
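To make the wrapper's behaviour concrete, here is a hedged sketch of the bisect-based index mapping and the metric keys it produces (the two sub-dataset sizes are hypothetical):

import bisect

# Two sub-datasets with 3 and 5 samples -> cumulative_sizes == [3, 8].
cumulative_sizes = [3, 8]


def map_index(indice):
    # Same logic as get_dataset_idx_and_sample_idx above.
    dataset_idx = bisect.bisect_right(cumulative_sizes, indice)
    if dataset_idx == 0:
        sample_idx = indice
    else:
        sample_idx = indice - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx


assert map_index(2) == (0, 2)   # last sample of the first sub-dataset
assert map_index(3) == (1, 0)   # first sample of the second sub-dataset
assert map_index(7) == (1, 4)

# With separate_eval=True, evaluate() returns per-dataset metrics whose keys
# are prefixed by the sub-dataset index, e.g. {'0_mIoU': ..., '1_mIoU': ...};
# with separate_eval=False, a single merged set of metrics is returned.
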
167 changes: 167 additions & 0 deletions tools/browse_dataset.py
@@ -0,0 +1,167 @@
import argparse
import os
import warnings
from pathlib import Path

import mmcv
import numpy as np
from mmcv import Config

from mmseg.datasets.builder import build_dataset


def parse_args():
parser = argparse.ArgumentParser(description='Browse a dataset')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--show-origin',
default=False,
action='store_true',
help='if True, omit all augmentations in the pipeline and '
'show the original image and seg map')
parser.add_argument(
'--skip-type',
type=str,
nargs='+',
default=['DefaultFormatBundle', 'Normalize', 'Collect'],
help='skip some useless pipeline steps; if `show-origin` is true, '
'all pipeline steps except `Load` will be skipped')
parser.add_argument(
'--output-dir',
default='./output',
type=str,
help='If there is no display interface, you can save it')
parser.add_argument('--show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=int,
default=999,
help='the interval of show (ms)')
parser.add_argument(
'--opacity',
type=float,
default=0.5,
help='the opacity of semantic map')
args = parser.parse_args()
return args


def imshow_semantic(img,
seg,
class_names,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.

Args:
img (str or Tensor): The image to be displayed.
seg (Tensor): The semantic segmentation results to draw over
`img`.
class_names (list[str]): Names of each class.
palette (list[list[int]] | np.ndarray | None): The palette of the
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity (float): Opacity of the painted segmentation map.
Must be in the (0, 1] range. Default: 0.5.
Returns:
ndarray: The blended image. Only returned if neither `show` nor
`out_file` is set.
"""
img = mmcv.imread(img)
img = img.copy()
if palette is None:
palette = np.random.randint(0, 255, size=(len(class_names), 3))
palette = np.array(palette)
assert palette.shape[0] == len(class_names)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]

img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False

if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img


def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
if show_origin is True:
# only keep pipeline of Loading data and ann
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if 'Load' in x['type']
]
else:
_data_cfg['pipeline'] = [
x for x in _data_cfg.pipeline if x['type'] not in skip_type
]


def retrieve_data_cfg(config_path, skip_type, show_origin=False):
cfg = Config.fromfile(config_path)
train_data_cfg = cfg.data.train
if isinstance(train_data_cfg, list):
for _data_cfg in train_data_cfg:
if 'pipeline' in _data_cfg:
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
elif 'dataset' in _data_cfg:
_retrieve_data_cfg(_data_cfg['dataset'], skip_type,
show_origin)
else:
raise ValueError
elif 'dataset' in train_data_cfg:
_retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
else:
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
return cfg


def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
dataset = build_dataset(cfg.data.train)
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
filename = os.path.join(args.output_dir,
Path(item['filename']).name
) if args.output_dir is not None else None
imshow_semantic(
item['img'],
item['gt_semantic_seg'],
dataset.CLASSES,
dataset.PALETTE,
show=args.show,
wait_time=args.show_interval,
out_file=filename,
opacity=args.opacity,
)
progress_bar.update()


if __name__ == '__main__':
main()
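The tool is meant to be run as, for example, python tools/browse_dataset.py <your config> --show-origin --output-dir ./output, but imshow_semantic can also be exercised on its own. A hedged sketch with toy inputs (assumes the repository root is on sys.path so the script is importable as a module):

import numpy as np

from tools.browse_dataset import imshow_semantic

# Toy 2-class example: blend a random image with a random segmentation map.
img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
seg = np.random.randint(0, 2, (64, 64), dtype=np.uint8)

# With neither `show` nor `out_file` set, the blended array is returned
# (a warning is emitted), which is convenient in notebooks or tests.
blended = imshow_semantic(img, seg, class_names=['background', 'foreground'])
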
3 changes: 2 additions & 1 deletion tools/test.py
@@ -202,7 +202,8 @@ def main():
print(f'\nwriting results to {args.out}')
mmcv.dump(results, args.out)
if args.eval:
dataset.evaluate(results, args.eval, **eval_kwargs)
eval_kwargs.update(metric=args.eval)
dataset.evaluate(results, **eval_kwargs)
if tmpdir is not None and eval_on_format_results:
# remove tmp dir when cityscapes evaluation
shutil.rmtree(tmpdir)