diff --git a/python/aistore/pytorch/base_iter_dataset.py b/python/aistore/pytorch/base_iter_dataset.py index 045d903122..6b7c5e9046 100644 --- a/python/aistore/pytorch/base_iter_dataset.py +++ b/python/aistore/pytorch/base_iter_dataset.py @@ -42,7 +42,6 @@ def __init__( ) self._prefix_map = prefix_map self._iterator = None - self._length = None def _get_sample_iter_from_source(self, source: AISSource, prefix: str) -> Iterable: """ @@ -68,14 +67,11 @@ def _create_samples_iter(self) -> Iterable: Returns: Iterable: Iterable over the samples of the dataset """ - length = 0 - for source in self._ais_source_list: # Add pytorch worker support to the internal request client source.client = WorkerRequestClient(source.client) if source not in self._prefix_map or self._prefix_map[source] is None: for sample in self._get_sample_iter_from_source(source, ""): - length += 1 yield sample else: prefixes = ( @@ -85,11 +81,8 @@ def _create_samples_iter(self) -> Iterable: ) for prefix in prefixes: for sample in self._get_sample_iter_from_source(source, prefix): - length += 1 yield sample - self._length = length - def _get_worker_iter_info(self) -> tuple[Iterator, str]: """ Depending on how many Torch workers are present or if they are even present at all, @@ -122,11 +115,20 @@ def __iter__(self) -> Iterator: def _reset_iterator(self): """Reset the iterator to start from the beginning.""" - self._length = 0 self._iterator = self._create_samples_iter() def __len__(self): - if self._length is None: - self._reset_iterator() - self._length = sum(1 for _ in self._iterator) - return self._length + """ + Returns the length of the dataset. Note that calling this + will iterate through the dataset, taking O(N) time. + + NOTE: If you want the length of the dataset after iterating through + it, use `for i, data in enumerate(dataset)` instead. + """ + self._reset_iterator() + sum = 0 + + for _ in self._iterator: + sum += 1 + + return sum diff --git a/python/aistore/pytorch/shard_reader.py b/python/aistore/pytorch/shard_reader.py index c5658854a1..9cc36f6de6 100644 --- a/python/aistore/pytorch/shard_reader.py +++ b/python/aistore/pytorch/shard_reader.py @@ -13,6 +13,7 @@ from alive_progress import alive_it from io import BytesIO from tarfile import open, TarError +from aistore.sdk.list_object_flag import ListObjectFlag class AISShardReader(AISBaseIterDataset): @@ -43,6 +44,28 @@ def __init__( self._show_progress = show_progress self._observed_keys = set() + def __len__(self): + """ + Returns the length of the dataset. Note that calling this + will iterate through the dataset, taking O(N) time. + + NOTE: If you want the length of the dataset after iterating through + it, use `for i, data in enumerate(dataset)` instead. + """ + self._reset_iterator() + length = 0 + + for shard in self._iterator: + + for _ in shard.bucket.list_objects_iter( + prefix=shard.name, props="name", flags=[ListObjectFlag.ARCH_DIR] + ): + length += 1 + + length -= 1 # Exclude the bucket (overcounted earlier) + + return length + class ZeroDict(dict): """ When `collate_fn` is called while using ShardReader with a dataloader, @@ -119,5 +142,4 @@ def __iter__(self) -> Iterator: disable=not self._show_progress, force_tty=worker_name == "", ): - self._length += 1 yield basename, self.ZeroDict(content_dict, self._observed_keys)