Dataset caching fixes #351

Merged 3 commits on Jul 13, 2021
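This patch set makes dataset operations lazier: Dataset.transform() no longer flushes the item cache when no local changes have accumulated, Dataset.subsets() delegates to the merged view, Extractor.get_subset() returns the extractor itself when it holds a single subset, and single-task COCO importers accept a direct path to an annotation file. New tests pin down how many source iterations each operation may trigger.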
datumaro/components/dataset.py (18 changes: 10 additions & 8 deletions)

@@ -266,7 +266,8 @@ def is_cache_initialized(self) -> bool:
 
     @property
     def _is_unchanged_wrapper(self) -> bool:
-        return self._source is not None and self._storage.is_empty()
+        return self._source is not None and self._storage.is_empty() and \
+            not self._transforms
 
     def init_cache(self):
         if not self.is_cache_initialized():

@@ -513,16 +514,17 @@ def get_subset(self, name):
         return self._merged().get_subset(name)
 
     def subsets(self):
-        subsets = {}
-        if not self.is_cache_initialized():
-            subsets.update(self._source.subsets())
-        subsets.update(self._storage.subsets())
-        return subsets
+        # TODO: check if this can be optimized in case of transforms
+        # and other cases
+        return self._merged().subsets()
 
     def transform(self, method: Transform, *args, **kwargs):
         # Flush accumulated changes
-        source = self._merged()
-        self._storage = DatasetItemStorage()
+        if not self._storage.is_empty():
+            source = self._merged()
+            self._storage = DatasetItemStorage()
+        else:
+            source = self._source
 
         if not self._transforms:
             # The stack of transforms only needs a single source
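For context, a minimal sketch of what the transform() change buys, modeled on the tests added below (SourceExtractor and ShiftIds are illustrative names, and the import paths assume the modules touched by this patch):

    # Illustrative sketch; SourceExtractor and ShiftIds are hypothetical
    # names, not part of the patch.
    from datumaro.components.dataset import Dataset
    from datumaro.components.extractor import (DatasetItem, Extractor,
        ItemTransform)

    class SourceExtractor(Extractor):
        def __iter__(self):
            yield from [DatasetItem(1), DatasetItem(2)]

    class ShiftIds(ItemTransform):
        def transform_item(self, item):
            return self.wrap_item(item, id=int(item.id) + 1)

    dataset = Dataset.from_extractors(SourceExtractor())
    dataset.transform(ShiftIds)  # storage is empty: no flush, no iteration
    dataset.transform(ShiftIds)  # stacked onto the same lazy source
    print(len(dataset))          # the source is iterated only now

Before this change, each transform() call invoked self._merged() unconditionally, initializing the source cache even when _storage held nothing.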
datumaro/components/extractor.py (3 changes: 3 additions & 0 deletions)

@@ -686,6 +686,9 @@ def get_subset(self, name):
         if self._subsets is None:
             self._init_cache()
         if name in self._subsets:
+            if len(self._subsets) == 1:
+                return self
+
             return self.select(lambda item: item.subset == name)
         else:
             raise Exception("Unknown subset '%s', available subsets: %s" % \
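A sketch of the new shortcut (TrainOnly is a hypothetical source): when an extractor holds exactly one subset, there is nothing to filter, so get_subset() can return the extractor itself instead of a select() wrapper that re-checks every item.

    from datumaro.components.extractor import DatasetItem, Extractor

    class TrainOnly(Extractor):  # hypothetical single-subset source
        def __iter__(self):
            yield from [
                DatasetItem(1, subset='train'),
                DatasetItem(2, subset='train'),
            ]

    source = TrainOnly()
    # Only one subset exists, so no filtering wrapper is created:
    assert source.get_subset('train') is source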
datumaro/plugins/coco_format/importer.py (8 changes: 6 additions & 2 deletions)

@@ -73,8 +73,12 @@ def __call__(self, path, **extra_params):
 
     @classmethod
     def find_sources(cls, path):
-        if path.endswith('.json') and osp.isfile(path):
-            subset_paths = [path]
+        if osp.isfile(path):
+            if len(cls._TASKS) == 1:
+                return {'': { next(iter(cls._TASKS)): path }}
+
+            if path.endswith('.json'):
+                subset_paths = [path]
         else:
             subset_paths = glob(osp.join(path, '**', '*_*.json'),
                 recursive=True)
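A sketch of the new fast path; the subclass, the task choice, and the mapping values here are assumptions for illustration, not part of the patch. An importer that registers a single task can now map a plain file path straight to that task, skipping the '*_*.json' glob:

    from datumaro.plugins.coco_format.format import CocoTask
    from datumaro.plugins.coco_format.importer import CocoImporter

    class KeypointsImporter(CocoImporter):  # hypothetical single-task subclass
        _TASKS = { CocoTask.person_keypoints: 'coco_person_keypoints' }

    # Assuming the file exists, a direct path maps to the registered task:
    print(KeypointsImporter.find_sources('/data/annotations.json'))
    # {'': {CocoTask.person_keypoints: '/data/annotations.json'}}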
tests/test_dataset.py (114 changes: 114 additions & 0 deletions)

@@ -867,6 +867,120 @@ def transform_item(self, item):
 
         self.assertEqual(iter_called, 1)
 
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_len_after_local_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(ItemTransform):
+            def transform_item(self, item):
+                return self.wrap_item(item, id=int(item.id) + 1)
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual(4, len(dataset))
+
+        self.assertEqual(iter_called, 1)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_len_after_nonlocal_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(Transform):
+            def __iter__(self):
+                for item in self._extractor:
+                    yield self.wrap_item(item, id=int(item.id) + 1)
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual(4, len(dataset))
+
+        self.assertEqual(iter_called, 2)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_subsets_after_local_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(ItemTransform):
+            def transform_item(self, item):
+                return self.wrap_item(item, id=int(item.id) + 1, subset='a')
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual({'a'}, set(dataset.subsets()))
+
+        self.assertEqual(iter_called, 1)
+
+    @mark_requirement(Requirements.DATUM_GENERAL_REQ)
+    def test_can_get_subsets_after_nonlocal_transforms(self):
+        iter_called = 0
+        class TestExtractor(Extractor):
+            def __iter__(self):
+                nonlocal iter_called
+                iter_called += 1
+                yield from [
+                    DatasetItem(1),
+                    DatasetItem(2),
+                    DatasetItem(3),
+                    DatasetItem(4),
+                ]
+        dataset = Dataset.from_extractors(TestExtractor())
+
+        class TestTransform(Transform):
+            def __iter__(self):
+                for item in self._extractor:
+                    yield self.wrap_item(item, id=int(item.id) + 1, subset='a')
+
+        dataset.transform(TestTransform)
+        dataset.transform(TestTransform)
+
+        self.assertEqual(iter_called, 0)
+
+        self.assertEqual({'a'}, set(dataset.subsets()))
+
+        self.assertEqual(iter_called, 2)
+
     @mark_requirement(Requirements.DATUM_GENERAL_REQ)
     def test_raises_when_repeated_items_in_source(self):
         dataset = Dataset.from_iterable([DatasetItem(0), DatasetItem(0)])
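A note on the assertion pattern in these tests: iter_called counts full passes over the source extractor. Two stacked ItemTransforms still cost a single pass (iter_called == 1), matching the comment in dataset.py that a stack of transforms only needs a single source, while the pair of non-local Transforms in the corresponding tests costs two passes (iter_called == 2).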