python/pytorch: Fix length for iterable datasets
Signed-off-by: Soham Manoli <[email protected]>
msoham123 committed Jul 30, 2024
1 parent 9155d28 commit c06c0a5
Showing 2 changed files with 37 additions and 13 deletions.
python/aistore/pytorch/base_iter_dataset.py (14 additions, 12 deletions)
@@ -42,7 +42,6 @@ def __init__(
         )
         self._prefix_map = prefix_map
         self._iterator = None
-        self._length = None
 
     def _get_sample_iter_from_source(self, source: AISSource, prefix: str) -> Iterable:
         """
@@ -68,14 +67,11 @@ def _create_samples_iter(self) -> Iterable:
         Returns:
             Iterable: Iterable over the samples of the dataset
         """
-        length = 0
-
         for source in self._ais_source_list:
             # Add pytorch worker support to the internal request client
             source.client = WorkerRequestClient(source.client)
             if source not in self._prefix_map or self._prefix_map[source] is None:
                 for sample in self._get_sample_iter_from_source(source, ""):
-                    length += 1
                     yield sample
             else:
                 prefixes = (
@@ -85,11 +81,8 @@
                 )
                 for prefix in prefixes:
                     for sample in self._get_sample_iter_from_source(source, prefix):
-                        length += 1
                         yield sample
 
-        self._length = length
-
     def _get_worker_iter_info(self) -> tuple[Iterator, str]:
         """
         Depending on how many Torch workers are present or if they are even present at all,
@@ -122,11 +115,20 @@ def __iter__(self) -> Iterator:
 
     def _reset_iterator(self):
         """Reset the iterator to start from the beginning."""
-        self._length = 0
         self._iterator = self._create_samples_iter()
 
     def __len__(self):
-        if self._length is None:
-            self._reset_iterator()
-            self._length = sum(1 for _ in self._iterator)
-        return self._length
+        """
+        Returns the length of the dataset. Note that calling this
+        will iterate through the entire dataset, taking O(N) time.
+
+        NOTE: If you want the length of the dataset after iterating
+        through it, count samples with `for i, data in enumerate(dataset)`
+        instead.
+        """
+        self._reset_iterator()
+        count = 0
+
+        for _ in self._iterator:
+            count += 1
+
+        return count
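
For context on the behavior change: `len()` on an iterable dataset now always performs a full pass over the samples instead of returning a cached value. A minimal usage sketch follows; the endpoint, bucket name, and the `AISIterDataset` subclass used here are illustrative assumptions, not part of this commit.

from aistore.pytorch import AISIterDataset
from aistore.sdk import Client

# Assumed cluster endpoint and bucket; adjust for your deployment.
client = Client("http://localhost:8080")
dataset = AISIterDataset(ais_source_list=[client.bucket("my-bucket")])

# __len__ re-creates the sample iterator and counts it: a full O(N) pass.
print(len(dataset))

# Cheaper when a pass is happening anyway: derive the count while
# iterating, as the new docstring recommends.
count = 0
for i, sample in enumerate(dataset):
    count = i + 1
print(count)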
python/aistore/pytorch/shard_reader.py (23 additions, 1 deletion)
@@ -13,6 +13,7 @@
 from alive_progress import alive_it
 from io import BytesIO
 from tarfile import open, TarError
+from aistore.sdk.list_object_flag import ListObjectFlag
 
 
 class AISShardReader(AISBaseIterDataset):
@@ -43,6 +44,28 @@ def __init__(
         self._show_progress = show_progress
         self._observed_keys = set()
 
+    def __len__(self):
+        """
+        Returns the length of the dataset. Note that calling this
+        will iterate through the entire dataset, taking O(N) time.
+
+        NOTE: If you want the length of the dataset after iterating
+        through it, count samples with `for i, data in enumerate(dataset)`
+        instead.
+        """
+        self._reset_iterator()
+        length = 0
+
+        for shard in self._iterator:
+            for _ in shard.bucket.list_objects_iter(
+                prefix=shard.name, props="name", flags=[ListObjectFlag.ARCH_DIR]
+            ):
+                length += 1
+
+            length -= 1  # the ARCH_DIR listing includes the shard object itself
+
+        return length
+
     class ZeroDict(dict):
         """
         When `collate_fn` is called while using ShardReader with a dataloader,
@@ -119,5 +142,4 @@ def __iter__(self) -> Iterator:
             disable=not self._show_progress,
             force_tty=worker_name == "",
         ):
-            self._length += 1
             yield basename, self.ZeroDict(content_dict, self._observed_keys)
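
The shard reader cannot reuse the base-class count, since each sample is a member file inside a tar shard rather than a standalone object. Listing with `ListObjectFlag.ARCH_DIR` descends into each archive and returns the shard object itself plus one entry per member, hence the `length -= 1` per shard. Below is a standalone sketch of that counting logic; the endpoint, bucket, and shard names are illustrative assumptions.

from aistore.sdk import Client
from aistore.sdk.list_object_flag import ListObjectFlag

# Assumed endpoint and bucket holding tar shards.
bucket = Client("http://localhost:8080").bucket("my-shards")

def count_shard_members(bck, shard_name):
    # ARCH_DIR makes the listing descend into the archive: the result
    # contains the shard object itself plus one entry per member file.
    entries = bck.list_objects_iter(
        prefix=shard_name, props="name", flags=[ListObjectFlag.ARCH_DIR]
    )
    return sum(1 for _ in entries) - 1  # exclude the shard entry itself

print(count_shard_members(bucket, "shard-000000.tar"))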
