[Feature] Support eval concate dataset and add tool to show dataset #833
Commits: 6fdd5ed, 8e75af9, bc781f1, 17cf91a, 5bc4e58, e6501a2, bf48690, 980e3bf, 26570de, 51ac8af, 5784eb0, 9352d60, 28e3bd2
File: mmseg/datasets/builder.py

@@ -30,6 +30,7 @@ def _concat_dataset(cfg, default_args=None):
     img_dir = cfg['img_dir']
     ann_dir = cfg.get('ann_dir', None)
     split = cfg.get('split', None)
+    separate_eval = cfg.get('separate_eval', True)
     num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
     if ann_dir is not None:
         num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
@@ -49,6 +50,9 @@ def _concat_dataset(cfg, default_args=None):
     datasets = []
     for i in range(num_dset):
         data_cfg = copy.deepcopy(cfg)
+        # pop 'separate_eval' since it is not a valid key for common datasets.
+        if 'separate_eval' in data_cfg:
+            data_cfg.pop('separate_eval')
         if isinstance(img_dir, (list, tuple)):
             data_cfg['img_dir'] = img_dir[i]
         if isinstance(ann_dir, (list, tuple)):

Review thread on the `separate_eval` pop:
- Reviewer: Can we use […]?
- Author: Sure, I think it will be better.
- Author: updated 28e3bd2
@@ -57,7 +61,7 @@ def _concat_dataset(cfg, default_args=None):
             data_cfg['split'] = split[i]
         datasets.append(build_dataset(data_cfg, default_args))

-    return ConcatDataset(datasets)
+    return ConcatDataset(datasets, separate_eval)


 def build_dataset(cfg, default_args=None):
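For orientation, here is a hedged config sketch (the dataset type is real, the paths are placeholders) of how the new flag is consumed: passing lists for `img_dir`/`ann_dir` routes the config through `_concat_dataset`, which reads `separate_eval` once and pops it before building each sub-dataset.

```python
# Hypothetical validation config: two splits concatenated into one dataset.
val = dict(
    type='CustomDataset',
    data_root='data/my_dataset',                 # placeholder path
    img_dir=['images/setA', 'images/setB'],      # list -> _concat_dataset
    ann_dir=['annotations/setA', 'annotations/setB'],
    separate_eval=False,   # joint evaluation; popped before sub-dataset build
    pipeline=[dict(type='LoadImageFromFile'),
              dict(type='LoadAnnotations')])
```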
File: mmseg/datasets/custom.py
@@ -99,6 +99,9 @@ def __init__(self,
         self.label_map = None
         self.CLASSES, self.PALETTE = self.get_classes_and_palette(
             classes, palette)
+        if test_mode:
+            assert self.CLASSES is not None, \
+                '`cls.CLASSES` or `classes` should be specified when testing'

         # join paths if data_root is specified
         if self.data_root is not None:

Review thread on the new assertion:
- Reviewer: It seems that this modification leads to a failed GitHub CI run (checked).
- Author: OK, I will find time to fix the issue above 😂
@@ -339,7 +342,7 @@ def get_palette_for_custom_classes(self, class_names, palette=None):

         return palette

-    def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
+    def evaluate(self, results, metric='mIoU', logger=None, gt_seg_maps=None, **kwargs):
         """Evaluate the dataset.

         Args:

@@ -350,6 +353,8 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs):
                 'mDice' and 'mFscore' are supported.
             logger (logging.Logger | None | str): Logger used for printing
                 related information during evaluation. Default: None.
+            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
+                used in ConcatDataset.

         Returns:
             dict[str, float]: Default metrics.
@@ -364,14 +369,9 @@ def evaluate(self, results, metric='mIoU', logger=None, gt_seg_maps=None, **kwargs):
         # test a list of files
         if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
                 results, str):
-            gt_seg_maps = self.get_gt_seg_maps()
-            if self.CLASSES is None:
-                num_classes = len(
-                    reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
-            else:
-                num_classes = len(self.CLASSES)
-            # reset generator
-            gt_seg_maps = self.get_gt_seg_maps()
+            if gt_seg_maps is None:
+                gt_seg_maps = self.get_gt_seg_maps()
+            num_classes = len(self.CLASSES)
             ret_metrics = eval_metrics(
                 results,
                 gt_seg_maps,
File: mmseg/datasets/dataset_wrappers.py
@@ -1,7 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import bisect
+from itertools import chain
+
+from mmcv.utils import print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

 from .builder import DATASETS
+from .cityscapes import CityscapesDataset


 @DATASETS.register_module()
@@ -15,10 +20,107 @@ class ConcatDataset(_ConcatDataset):
         datasets (list[:obj:`Dataset`]): A list of datasets.
     """

-    def __init__(self, datasets):
+    def __init__(self, datasets, separate_eval=True):
         super(ConcatDataset, self).__init__(datasets)
         self.CLASSES = datasets[0].CLASSES
         self.PALETTE = datasets[0].PALETTE
+        self.separate_eval = separate_eval
+        if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
+            raise NotImplementedError(
+                'Evaluating CityscapesDataset within ConcatDataset is not '
+                'supported!')

Review thread on the new `__init__` signature:
- Reviewer: Docstring for […]?
- Author: Hi, I have fixed the issue and added a unit test for it. Do I need to submit a new PR or not?
+    def evaluate(self, results, logger=None, **kwargs):
+        """Evaluate the results.
+
+        Args:
+            results (list[list | tuple]): Testing results of the dataset.
+            logger (logging.Logger | str | None): Logger used for printing
+                related information during evaluation. Default: None.
+
+        Returns:
+            dict[str: float]: Evaluation results of the whole dataset, or of
+                each separate dataset if `self.separate_eval=True`.
+        """
+        assert len(results) == self.cumulative_sizes[-1], \
+            ('Dataset and results have different sizes: '
+             f'{self.cumulative_sizes[-1]} vs. {len(results)}')
+
+        # Check whether all the datasets support evaluation
+        for dataset in self.datasets:
+            assert hasattr(dataset, 'evaluate'), \
+                f'{type(dataset)} does not implement evaluate function'
+
+        if self.separate_eval:
+            dataset_idx = -1
+            total_eval_results = dict()
+            for size, dataset in zip(self.cumulative_sizes, self.datasets):
+                start_idx = 0 if dataset_idx == -1 else \
+                    self.cumulative_sizes[dataset_idx]
+                end_idx = self.cumulative_sizes[dataset_idx + 1]
+
+                results_per_dataset = results[start_idx:end_idx]
+                print_log(
+                    f'\nEvaluating {dataset.img_dir} with '
+                    f'{len(results_per_dataset)} images now',
+                    logger=logger)
+
+                eval_results_per_dataset = dataset.evaluate(
+                    results_per_dataset, logger=logger, **kwargs)
+                dataset_idx += 1
+                for k, v in eval_results_per_dataset.items():
+                    total_eval_results.update({f'{dataset_idx}_{k}': v})
+
+            return total_eval_results
+
+        if len(set([type(ds) for ds in self.datasets])) != 1:
+            raise NotImplementedError(
+                'All the datasets should have same types when '
+                'self.separate_eval=False')
+        else:
+            gt_seg_maps = chain(
+                *[dataset.get_gt_seg_maps() for dataset in self.datasets])
+            eval_results = self.datasets[0].evaluate(
+                results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
+            return eval_results

Review threads on the joint-evaluation branch:
- Reviewer: If the results are […]
- Author: Yes, but if […]
- Reviewer: I mean if the results are […]
- Author: Got it 😯
- Reviewer: (suggested a change to the `gt_seg_maps` construction) Does this work?
- Author: OK, I will.
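A short sketch of the two return shapes (metric values are invented): with `separate_eval=True`, each sub-dataset's metrics come back prefixed with its index; with `separate_eval=False`, a single joint metric dict is returned.

```python
# Hypothetical output for a ConcatDataset with two sub-datasets:
# separate_eval=True  -> {'0_mIoU': 0.72, '0_mAcc': 0.81,
#                         '1_mIoU': 0.68, '1_mAcc': 0.79}
# separate_eval=False -> {'mIoU': 0.70, 'mAcc': 0.80}
```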
+    def get_dataset_idx_and_sample_idx(self, indice):
+        """Return the dataset index and sample index for an index of
+        ConcatDataset.
+
+        Args:
+            indice (int): index of a sample in ConcatDataset
+
+        Returns:
+            int: the index of the sub-dataset the sample belongs to
+            int: the index of the sample within its sub-dataset
+        """
+        if indice < 0:
+            if -indice > len(self):
+                raise ValueError(
+                    'absolute value of index should not exceed dataset length')
+            indice = len(self) + indice
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
+        if dataset_idx == 0:
+            sample_idx = indice
+        else:
+            sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
+        return dataset_idx, sample_idx
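To make the `bisect_right` arithmetic concrete, a tiny self-contained example (sizes invented):

```python
import bisect

# Two sub-datasets of lengths 3 and 5 -> cumulative_sizes == [3, 8].
cumulative_sizes = [3, 8]
indice = 4  # global index into the ConcatDataset

dataset_idx = bisect.bisect_right(cumulative_sizes, indice)  # -> 1
sample_idx = indice - cumulative_sizes[dataset_idx - 1]      # -> 4 - 3 = 1
print(dataset_idx, sample_idx)  # second dataset, its second sample
```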
+    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
+        """Format results for every sample of ConcatDataset."""
+        ret_res = []
+        for i, indice in enumerate(indices):
+            dataset_idx, sample_idx = \
+                self.get_dataset_idx_and_sample_idx(indice)
+            res = self.datasets[dataset_idx].format_results(
+                [results[i]], imgfile_prefix, indices=[sample_idx], **kwargs)
+            ret_res.append(res)
+        return sum(ret_res, [])

Review thread on the formatting loop:
- Reviewer: How about […]?
- Author: You are right, I will fix it.
+    def pre_eval(self, preds, indices):
+        """Do pre-eval for every sample of ConcatDataset."""
+        ret_res = []
+        for i, indice in enumerate(indices):
+            dataset_idx, sample_idx = \
+                self.get_dataset_idx_and_sample_idx(indice)
+            res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
+            ret_res.append(res)
+        return sum(ret_res, [])
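Both `format_results` and `pre_eval` route each sample to the sub-dataset that owns it. A hedged sketch (the dataset and prediction names are invented):

```python
# With sub-dataset lengths [3, 5] (cumulative_sizes == [3, 8]), global
# indices 2 and 5 map to (dataset 0, sample 2) and (dataset 1, sample 2),
# so each prediction is pre-evaluated by its own sub-dataset.
pre_eval_results = concat_dataset.pre_eval(
    preds=[pred_a, pred_b],  # one prediction per entry in `indices`
    indices=[2, 5])
```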
 @DATASETS.register_module()

File: tools/browse_dataset.py (new file)
@@ -0,0 +1,167 @@
import argparse
import os
import warnings
from pathlib import Path

import mmcv
import numpy as np
from mmcv import Config

from mmseg.datasets.builder import build_dataset


def parse_args():
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--show-origin',
        default=False,
        action='store_true',
        help='if True, omit all augmentation in pipeline,'
        ' show origin image and seg map')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipelines; if `show-origin` is true, '
        'all pipelines except `Load` will be skipped')
    parser.add_argument(
        '--output-dir',
        default='./output',
        type=str,
        help='If there is no display interface, you can save it')
    parser.add_argument('--show', default=False, action='store_true')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=999,
        help='the interval of show (ms)')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='the opacity of the semantic map')
    args = parser.parse_args()
    return args
def imshow_semantic(img,
                    seg,
                    class_names,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
    """Draw `seg` over `img`.

    Args:
        img (str or Tensor): The image to be displayed.
        seg (Tensor): The semantic segmentation results to draw over
            `img`.
        class_names (list[str]): Names of each class.
        palette (list[list[int]] | np.ndarray | None): The palette of the
            segmentation map. If None is given, a random palette will be
            generated. Default: None
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
            Default: 0.
        show (bool): Whether to show the image.
            Default: False.
        out_file (str or None): The filename to write the image.
            Default: None.
        opacity (float): Opacity of painted segmentation map.
            Must be in the (0, 1] range. Default: 0.5.

    Returns:
        img (Tensor): Only if not `show` or `out_file`.
    """
    img = mmcv.imread(img)
    img = img.copy()
    if palette is None:
        palette = np.random.randint(0, 255, size=(len(class_names), 3))
    palette = np.array(palette)
    assert palette.shape[0] == len(class_names)
    assert palette.shape[1] == 3
    assert len(palette.shape) == 2
    assert 0 < opacity <= 1.0
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    img = img.astype(np.uint8)
    # if out_file specified, do not show image in window
    if out_file is not None:
        show = False

    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    if not (show or out_file):
        warnings.warn('show==False and out_file is not specified, only '
                      'result image will be returned')
    return img
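A hedged usage sketch of `imshow_semantic` (the paths and the prediction array are placeholders): overlay a segmentation map on an image and save the blend to disk.

```python
# Hypothetical call: `pred_seg` is an (H, W) integer label map matching the
# image size; passing out_file suppresses the interactive window.
vis = imshow_semantic(
    'demo/demo.png',
    pred_seg,
    class_names=dataset.CLASSES,
    palette=dataset.PALETTE,
    out_file='output/demo_overlay.png',
    opacity=0.5)
```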
def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
    if show_origin is True:
        # only keep the pipeline steps that load data and annotations
        _data_cfg['pipeline'] = [
            x for x in _data_cfg.pipeline if 'Load' in x['type']
        ]
    else:
        _data_cfg['pipeline'] = [
            x for x in _data_cfg.pipeline if x['type'] not in skip_type
        ]


def retrieve_data_cfg(config_path, skip_type, show_origin=False):
    cfg = Config.fromfile(config_path)
    train_data_cfg = cfg.data.train
    if isinstance(train_data_cfg, list):
        for _data_cfg in train_data_cfg:
            if 'pipeline' in _data_cfg:
                _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
            elif 'dataset' in _data_cfg:
                _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
                                   show_origin)
            else:
                raise ValueError
    elif 'dataset' in train_data_cfg:
        _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
    else:
        _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
    return cfg
def main():
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
    dataset = build_dataset(cfg.data.train)
    progress_bar = mmcv.ProgressBar(len(dataset))
    for item in dataset:
        filename = os.path.join(args.output_dir,
                                Path(item['filename']).name
                                ) if args.output_dir is not None else None
        imshow_semantic(
            item['img'],
            item['gt_semantic_seg'],
            dataset.CLASSES,
            dataset.PALETTE,
            show=args.show,
            wait_time=args.show_interval,
            out_file=filename,
            opacity=args.opacity,
        )
        progress_bar.update()


if __name__ == '__main__':
    main()
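An example invocation, assuming the script is saved as tools/browse_dataset.py (the config path is a placeholder):

    python tools/browse_dataset.py configs/my_experiment/config.py --show-origin --output-dir ./vis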
Review thread on `_concat_dataset` (mmseg/datasets/builder.py):
- Reviewer: We may pop `separate_eval` here?
- Author: Let me check 😢