From 9ab438051fd486710be299c8f0371f0587d79ed7 Mon Sep 17 00:00:00 2001
From: Maxim Zhiltsov
Date: Tue, 13 Jul 2021 14:56:02 +0300
Subject: [PATCH 1/3] Fix importing arbitrary file names in COCO subformats

---
 datumaro/plugins/coco_format/importer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/datumaro/plugins/coco_format/importer.py b/datumaro/plugins/coco_format/importer.py
index ca76dc55db..44a4c73621 100644
--- a/datumaro/plugins/coco_format/importer.py
+++ b/datumaro/plugins/coco_format/importer.py
@@ -73,8 +73,12 @@ def __call__(self, path, **extra_params):
 
     @classmethod
     def find_sources(cls, path):
-        if path.endswith('.json') and osp.isfile(path):
-            subset_paths = [path]
+        if osp.isfile(path):
+            if len(cls._TASKS) == 1:
+                return {'': { next(iter(cls._TASKS)): path }}
+
+            if path.endswith('.json'):
+                subset_paths = [path]
         else:
             subset_paths = glob(osp.join(path, '**', '*_*.json'),
                 recursive=True)

From 1d4176558cd5cdf887b8d4a88f3cb4e72cfefe93 Mon Sep 17 00:00:00 2001
From: Maxim Zhiltsov
Date: Tue, 13 Jul 2021 15:35:23 +0300
Subject: [PATCH 2/3] Optimize subset iteration in a simple scenario

---
 datumaro/components/extractor.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py
index 26506b3002..17220d04af 100644
--- a/datumaro/components/extractor.py
+++ b/datumaro/components/extractor.py
@@ -686,6 +686,9 @@ def get_subset(self, name):
         if self._subsets is None:
             self._init_cache()
         if name in self._subsets:
+            if len(self._subsets) == 1:
+                return self
+
             return self.select(lambda item: item.subset == name)
         else:
             raise Exception("Unknown subset '%s', available subsets: %s" % \

From 00a5e757bd61a1128480447ea16b8d9d0077e856 Mon Sep 17 00:00:00 2001
From: Maxim Zhiltsov
Date: Tue, 13 Jul 2021 15:36:09 +0300
Subject: [PATCH 3/3] Fix subset iteration in dataset with transforms

---
 datumaro/components/dataset.py |  18 +++---
 tests/test_dataset.py          | 114 +++++++++++++++++++++++++++++++++
 2 files changed, 124 insertions(+), 8 deletions(-)

diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py
index 3be34a4a43..7072150f2d 100644
--- a/datumaro/components/dataset.py
+++ b/datumaro/components/dataset.py
@@ -266,7 +266,8 @@ def is_cache_initialized(self) -> bool:
 
     @property
     def _is_unchanged_wrapper(self) -> bool:
-        return self._source is not None and self._storage.is_empty()
+        return self._source is not None and self._storage.is_empty() and \
+            not self._transforms
 
     def init_cache(self):
         if not self.is_cache_initialized():
@@ -513,16 +514,17 @@ def get_subset(self, name):
             return self._merged().get_subset(name)
 
     def subsets(self):
-        subsets = {}
-        if not self.is_cache_initialized():
-            subsets.update(self._source.subsets())
-        subsets.update(self._storage.subsets())
-        return subsets
+        # TODO: check if this can be optimized in case of transforms
+        # and other cases
+        return self._merged().subsets()
 
     def transform(self, method: Transform, *args, **kwargs):
         # Flush accumulated changes
-        source = self._merged()
-        self._storage = DatasetItemStorage()
+        if not self._storage.is_empty():
+            source = self._merged()
+            self._storage = DatasetItemStorage()
+        else:
+            source = self._source
 
         if not self._transforms:
             # The stack of transforms only needs a single source
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index d141d00d88..e594a87f26 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -867,6 +867,120 @@ def transform_item(self, item):
 
         self.assertEqual(iter_called, 1)
 
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_len_after_local_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(ItemTransform):
+            def transform_item(self, item):
+                return self.wrap_item(item, id=int(item.id) + 1)
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual(4, len(dataset))
+
+        self.assertEqual(iter_called, 1)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_len_after_nonlocal_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(Transform):
+            def __iter__(self):
+                for item in self._extractor:
+                    yield self.wrap_item(item, id=int(item.id) + 1)
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual(4, len(dataset))
+
+        self.assertEqual(iter_called, 2)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_subsets_after_local_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(ItemTransform):
+            def transform_item(self, item):
+                return self.wrap_item(item, id=int(item.id) + 1, subset='a')
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual({'a'}, set(dataset.subsets()))
+
+        self.assertEqual(iter_called, 1)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_subsets_after_nonlocal_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(Transform):
+            def __iter__(self):
+                for item in self._extractor:
+                    yield self.wrap_item(item, id=int(item.id) + 1, subset='a')
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual({'a'}, set(dataset.subsets()))
+
+        self.assertEqual(iter_called, 2)
+
     @mark_requirement(Requirements.DATUM_GENERAL_REQ)
     def test_raises_when_repeated_items_in_source(self):
         dataset = Dataset.from_iterable([DatasetItem(0), DatasetItem(0)])
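
Usage sketch (not part of the patches): a rough illustration of the behavior this series enables, based on the tests above. The 'coco_instances' format name, the annotation file name, and the import paths are assumptions for illustration, not verified against this exact revision.

    from datumaro.components.dataset import Dataset
    from datumaro.components.extractor import ItemTransform

    class MoveToTrain(ItemTransform):
        # Hypothetical item-wise ("local") transform, mirroring the
        # TestTransform classes used in the new tests.
        def transform_item(self, item):
            return self.wrap_item(item, subset='train')

    # Patch 1: a task-specific COCO importer with a single task now accepts an
    # annotation file with an arbitrary name (previously the path had to end
    # in '.json').
    dataset = Dataset.import_from('my_annotations.json', 'coco_instances')

    # Patches 2-3: stacking item-wise transforms no longer flushes the item
    # cache on every transform() call, so len() and subsets() cost a single
    # pass over the source (the new tests assert iter_called == 1 here).
    dataset.transform(MoveToTrain)
    dataset.transform(MoveToTrain)
    print(len(dataset), set(dataset.subsets()))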