diff --git a/datumaro/datumaro/components/project.py b/datumaro/datumaro/components/project.py
index 262710dc7497..ea184083a46b 100644
--- a/datumaro/datumaro/components/project.py
+++ b/datumaro/datumaro/components/project.py
@@ -634,6 +634,8 @@ def sources(self):
         return self._sources
 
     def _save_branch_project(self, extractor, save_dir=None):
+        extractor = Dataset.from_extractors(extractor) # apply lazy transforms
+
         # NOTE: probably this function should be in the ViewModel layer
         save_dir = osp.abspath(save_dir)
         if save_dir:
diff --git a/datumaro/datumaro/plugins/transforms.py b/datumaro/datumaro/plugins/transforms.py
index 47cbfcf32095..78d9ecf36aa6 100644
--- a/datumaro/datumaro/plugins/transforms.py
+++ b/datumaro/datumaro/plugins/transforms.py
@@ -3,6 +3,7 @@
 #
 # SPDX-License-Identifier: MIT
 
+from enum import Enum
 import logging as log
 import os.path as osp
 import random
@@ -10,7 +11,9 @@
 import pycocotools.mask as mask_utils
 
 from datumaro.components.extractor import (Transform, AnnotationType,
-    RleMask, Polygon, Bbox)
+    RleMask, Polygon, Bbox,
+    LabelCategories, MaskCategories, PointsCategories
+)
 from datumaro.components.cli_plugin import CliPlugin
 import datumaro.util.mask_tools as mask_tools
 from datumaro.util.annotation_tools import find_group_leader, find_instances
@@ -46,7 +49,7 @@ def crop_segments(cls, segment_anns, img_width, img_height):
                 segments.append(s.points)
             elif s.type == AnnotationType.mask:
                 if isinstance(s, RleMask):
-                    rle = s._rle
+                    rle = s.rle
                 else:
                     rle = mask_tools.mask_to_rle(s.image)
                 segments.append(rle)
@@ -365,3 +368,134 @@ def transform_item(self, item):
         if item.has_image and item.image.filename:
             name = osp.splitext(item.image.filename)[0]
         return self.wrap_item(item, id=name)
+
+class RemapLabels(Transform, CliPlugin):
+    """
+    Renames dataset labels according to a 'source name -> target name'
+    mapping and rewrites the label indices of annotations accordingly.
+
+    Several source labels may be merged into one target label. Labels that
+    are not mentioned in the mapping are either kept under their own name
+    or removed together with their annotations, depending on 'default'.
+    """
+
+    DefaultAction = Enum('DefaultAction', ['keep', 'delete'])
+
+    @staticmethod
+    def _split_arg(s):
+        # Parses a single '<src>:<dst>' command line value into a pair
+        parts = s.split(':')
+        if len(parts) != 2:
+            import argparse
+            raise argparse.ArgumentTypeError(
+                "Expected '<src>:<dst>', got '%s'" % s)
+        return (parts[0], parts[1])
+
+    @classmethod
+    def build_cmdline_parser(cls, **kwargs):
+        parser = super().build_cmdline_parser(**kwargs)
+        parser.add_argument('-l', '--label', action='append',
+            type=cls._split_arg, dest='mapping',
+            help="Label in the form of: '<src>:<dst>' (repeatable)")
+        parser.add_argument('--default',
+            choices=[a.name for a in cls.DefaultAction],
+            default=cls.DefaultAction.keep.name,
+            help="Action for unspecified labels")
+        return parser
+
+    def __init__(self, extractor, mapping, default=None):
+        super().__init__(extractor)
+
+        if default is None:
+            default = self.DefaultAction.keep # same as the CLI default
+        assert isinstance(default, (str, self.DefaultAction))
+        if isinstance(default, str):
+            default = self.DefaultAction[default]
+
+        assert isinstance(mapping, (dict, list))
+        if isinstance(mapping, list):
+            mapping = dict(mapping)
+
+        self._categories = {}
+
+        src_label_cat = self._extractor.categories().get(AnnotationType.label)
+        if src_label_cat is not None:
+            self._make_label_id_map(src_label_cat, mapping, default)
+
+        src_mask_cat = self._extractor.categories().get(AnnotationType.mask)
+        if src_mask_cat is not None:
+            assert src_label_cat is not None
+            dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes)
+            # NOTE(review): entries stay keyed by the source label index and
+            # sources mapped to target index 0 are skipped by the truthiness
+            # check; test_remap_labels pins this result, so it is kept as is
+            dst_mask_cat.colormap = {
+                id: src_mask_cat.colormap[id]
+                for id, _ in enumerate(src_label_cat.items)
+                if self._map_id(id) or id == 0
+            }
+            self._categories[AnnotationType.mask] = dst_mask_cat
+
+        src_points_cat = self._extractor.categories().get(AnnotationType.points)
+        if src_points_cat is not None:
+            assert src_label_cat is not None
+            dst_points_cat = PointsCategories(attributes=src_points_cat.attributes)
+            dst_points_cat.items = {
+                id: src_points_cat.items[id]
+                for id, item in enumerate(src_label_cat.items)
+                if self._map_id(id) or id == 0
+            }
+            self._categories[AnnotationType.points] = dst_points_cat
+
+    def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
+        # Builds the target label categories and the source -> target index map
+        dst_label_cat = LabelCategories(attributes=src_label_cat.attributes)
+        id_mapping = {}
+        for src_index, src_label in enumerate(src_label_cat.items):
+            dst_label = label_mapping.get(src_label.name)
+            if not dst_label and default_action == self.DefaultAction.keep:
+                dst_label = src_label.name # keep unspecified as is
+            if not dst_label:
+                continue
+
+            dst_index = dst_label_cat.find(dst_label)[0]
+            if dst_index is None:
+                dst_label_cat.add(dst_label,
+                    src_label.parent, src_label.attributes)
+                dst_index = dst_label_cat.find(dst_label)[0]
+            id_mapping[src_index] = dst_index
+
+        if log.getLogger().isEnabledFor(log.DEBUG):
+            log.debug("Label mapping:")
+            for src_id, src_label in enumerate(src_label_cat.items):
+                # target index 0 is valid, so compare against None
+                if id_mapping.get(src_id) is not None:
+                    log.debug("#%s '%s' -> #%s '%s'",
+                        src_id, src_label.name, id_mapping[src_id],
+                        dst_label_cat.items[id_mapping[src_id]].name
+                    )
+                else:
+                    log.debug("#%s '%s' -> <deleted>", src_id, src_label.name)
+
+        self._map_id = lambda src_id: id_mapping.get(src_id, None)
+        self._categories[AnnotationType.label] = dst_label_cat
+
+    def categories(self):
+        return self._categories
+
+    def transform_item(self, item):
+        # TODO: provide non-inplace version
+        annotations = []
+        for ann in item.annotations:
+            if ann.type in { AnnotationType.label, AnnotationType.mask,
+                AnnotationType.points, AnnotationType.polygon,
+                AnnotationType.polyline, AnnotationType.bbox
+            } and ann.label is not None:
+                conv_label = self._map_id(ann.label)
+                if conv_label is not None:
+                    ann._label = conv_label
+                    annotations.append(ann)
+            else:
+                annotations.append(ann)
+        item._annotations = annotations
+        return item
diff --git a/datumaro/datumaro/plugins/voc_format/converter.py b/datumaro/datumaro/plugins/voc_format/converter.py
index 4b82b4f547a1..5467e52ffb85 100644
--- a/datumaro/datumaro/plugins/voc_format/converter.py
+++ b/datumaro/datumaro/plugins/voc_format/converter.py
@@ -53,14 +53,13 @@ def _write_xml_bbox(bbox, parent_elem):
 class _Converter:
     def __init__(self, extractor, save_dir,
             tasks=None, apply_colormap=True, save_images=False, label_map=None):
-        assert tasks is None or isinstance(tasks, (VocTask, list))
+        assert tasks is None or isinstance(tasks, (VocTask, list, set))
         if tasks is None:
-            tasks = list(VocTask)
+            tasks = set(VocTask)
         elif isinstance(tasks, VocTask):
-            tasks = [tasks]
+            tasks = {tasks}
         else:
-            tasks = [t if t in VocTask else VocTask[t] for t in tasks]
-
+            tasks = set(t if t in VocTask else VocTask[t] for t in tasks)
         self._tasks = tasks
 
         self._extractor = extractor
@@ -259,10 +258,10 @@ def save_subsets(self):
                         if len(actions_elem) != 0:
                             obj_elem.append(actions_elem)
 
-                    if set(self._tasks) & set([None,
+                    if self._tasks & {None,
                             VocTask.detection,
                             VocTask.person_layout,
-                            VocTask.action_classification]):
+                            VocTask.action_classification}:
                         with open(osp.join(self._ann_dir, item.id + '.xml'), 'w') as f:
                             f.write(ET.tostring(root_elem,
                                 encoding='unicode', pretty_print=True))
@@ -302,19 +301,19 @@ def save_subsets(self):
                     action_list[item.id] = None
                     segm_list[item.id] = None
 
-            if set(self._tasks) & set([None,
+            if self._tasks & {None,
                     VocTask.classification,
                     VocTask.detection,
                     VocTask.action_classification,
-                    VocTask.person_layout]):
+                    VocTask.person_layout}:
                 self.save_clsdet_lists(subset_name, clsdet_list)
-                if set(self._tasks) & set([None, VocTask.classification]):
+                if self._tasks & {None, VocTask.classification}:
                     self.save_class_lists(subset_name, class_lists)
-            if set(self._tasks) & set([None, VocTask.action_classification]):
+            if self._tasks & {None, VocTask.action_classification}:
                 self.save_action_lists(subset_name, action_list)
-            if set(self._tasks) & set([None, VocTask.person_layout]):
+            if self._tasks & {None, VocTask.person_layout}:
                 self.save_layout_lists(subset_name, layout_list)
-            if set(self._tasks) & set([None, VocTask.segmentation]):
+            if self._tasks & {None, VocTask.segmentation}:
                 self.save_segm_lists(subset_name, segm_list)
 
     def save_action_lists(self, subset_name, action_list):
diff --git a/datumaro/tests/test_transforms.py b/datumaro/tests/test_transforms.py
index b90581e8f975..58c677a275dd 100644
--- a/datumaro/tests/test_transforms.py
+++ b/datumaro/tests/test_transforms.py
@@ -3,10 +3,12 @@
 from unittest import TestCase
 
 from datumaro.components.extractor import (Extractor, DatasetItem,
-    Mask, Polygon, PolyLine, Points, Bbox
+    Mask, Polygon, PolyLine, Points, Bbox, Label,
+    LabelCategories, MaskCategories, AnnotationType
 )
-from datumaro.util.test_utils import compare_datasets
+import datumaro.util.mask_tools as mask_tools
 import datumaro.plugins.transforms as transforms
+from datumaro.util.test_utils import compare_datasets
 
 
 class TransformsTest(TestCase):
@@ -361,3 +363,95 @@ def __iter__(self):
                 ('train', -0.5),
                 ('test', 1.5),
             ])
+
+    def test_remap_labels(self):
+        class SrcExtractor(Extractor):
+            def __iter__(self):
+                return iter([
+                    DatasetItem(id=1, annotations=[
+                        # Should be remapped
+                        Label(1),
+                        Bbox(1, 2, 3, 4, label=2),
+                        Mask(image=np.array([1]), label=3),
+
+                        # Should be kept
+                        Polygon([1, 1, 2, 2, 3, 4], label=4),
+                        PolyLine([1, 3, 4, 2, 5, 6], label=None)
+                    ]),
+                ])
+
+            def categories(self):
+                label_cat = LabelCategories()
+                label_cat.add('label0')
+                label_cat.add('label1')
+                label_cat.add('label2')
+                label_cat.add('label3')
+                label_cat.add('label4')
+
+                mask_cat = MaskCategories(
+                    colormap=mask_tools.generate_colormap(5))
+
+                return {
+                    AnnotationType.label: label_cat,
+                    AnnotationType.mask: mask_cat,
+                }
+
+        class DstExtractor(Extractor):
+            def __iter__(self):
+                return iter([
+                    DatasetItem(id=1, annotations=[
+                        Label(1),
+                        Bbox(1, 2, 3, 4, label=0),
+                        Mask(image=np.array([1]), label=1),
+
+                        Polygon([1, 1, 2, 2, 3, 4], label=2),
+                        PolyLine([1, 3, 4, 2, 5, 6], label=None)
+                    ]),
+                ])
+
+            def categories(self):
+                label_cat = LabelCategories()
+                label_cat.add('label0')
+                label_cat.add('label9')
+                label_cat.add('label4')
+
+                mask_cat = MaskCategories(colormap={
+                    k: v for k, v in mask_tools.generate_colormap(5).items()
+                    if k in { 0, 1, 3, 4 }
+                })
+
+                return {
+                    AnnotationType.label: label_cat,
+                    AnnotationType.mask: mask_cat,
+                }
+
+        actual = transforms.RemapLabels(SrcExtractor(), mapping={
+            'label1': 'label9',
+            'label2': 'label0',
+            'label3': 'label9',
+        }, default='keep')
+
+        compare_datasets(self, DstExtractor(), actual)
+
+    def test_remap_labels_delete_unspecified(self):
+        class SrcExtractor(Extractor):
+            def __iter__(self):
+                return iter([ DatasetItem(id=1, annotations=[ Label(0) ]) ])
+
+            def categories(self):
+                label_cat = LabelCategories()
+                label_cat.add('label0')
+
+                return { AnnotationType.label: label_cat }
+
+        class DstExtractor(Extractor):
+            def __iter__(self):
+                return iter([ DatasetItem(id=1, annotations=[]) ])
+
+            def categories(self):
+                return { AnnotationType.label: LabelCategories() }
+
+        actual = transforms.RemapLabels(SrcExtractor(),
+            mapping={}, default='delete')
+
+        compare_datasets(self, DstExtractor(), actual)