python/pytorch: Decode shards on client side in ShardReader and support non-uniform samples

Signed-off-by: Soham Manoli <[email protected]>
msoham123 committed Jul 30, 2024
1 parent 9b35132 commit 2120562
Showing 6 changed files with 167 additions and 92 deletions.
25 changes: 24 additions & 1 deletion python/aistore/pytorch/base_iter_dataset.py
@@ -9,6 +9,8 @@
from torch.utils.data import IterableDataset
from abc import ABC, abstractmethod
from aistore.pytorch.worker_request_client import WorkerRequestClient
import torch.utils.data as torch_utils
from itertools import islice


class AISBaseIterDataset(ABC, IterableDataset):
@@ -88,6 +90,26 @@ def _create_samples_iter(self) -> Iterable:

self._length = length

def _get_worker_iter_info(self) -> tuple[Iterator, str]:
"""
Depending on how many Torch workers are present (if any at all),
return an iterator for the current worker to access and a worker name.
Returns:
tuple[Iterator, str]: Iterator of objects and name of worker
"""
worker_info = torch_utils.get_worker_info()

if worker_info is None or worker_info.num_workers == 1:
return self._iterator, ""

worker_iter = islice(
self._iterator, worker_info.id, None, worker_info.num_workers
)
worker_name = f" (Worker {worker_info.id})"

return worker_iter, worker_name
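A minimal sketch of the `islice` striding used above, with an illustrative item list and worker count (not part of the SDK):

```python
from itertools import islice

# Illustrative data: six objects striped across three workers.
items = ["obj_0", "obj_1", "obj_2", "obj_3", "obj_4", "obj_5"]
num_workers = 3

for worker_id in range(num_workers):
    worker_slice = islice(iter(items), worker_id, None, num_workers)
    print(worker_id, list(worker_slice))
# 0 ['obj_0', 'obj_3']
# 1 ['obj_1', 'obj_4']
# 2 ['obj_2', 'obj_5']
```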

@abstractmethod
def __iter__(self) -> Iterator:
"""
@@ -100,10 +122,11 @@ def __iter__(self) -> Iterator:

def _reset_iterator(self):
"""Reset the iterator to start from the beginning."""
self._length = 0
self._iterator = self._create_samples_iter()

def __len__(self):
if self._length is None:
self._length = sum(1 for _ in self._iterator)
self._reset_iterator()
self._length = sum(1 for _ in self._iterator)
return self._length
32 changes: 12 additions & 20 deletions python/aistore/pytorch/iter_dataset.py
@@ -7,8 +7,6 @@
from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
from typing import List, Union, Dict
from aistore.sdk.ais_source import AISSource
from torch.utils.data import get_worker_info
from itertools import islice
from alive_progress import alive_it


@@ -38,23 +36,17 @@ def __init__(
super().__init__(ais_source_list, prefix_map)
self._etl_name = etl_name
self._show_progress = show_progress
self._reset_iterator()

def __iter__(self):
worker_info = get_worker_info()

if worker_info is None:
# If not using multiple workers, load directly
for obj in alive_it(
self._iterator, title="AISIterDataset", disable=not self._show_progress
):
yield obj.name, obj.get(etl_name=self._etl_name).read_all()
else:
# Slice iterator based on worker id as starting index (0, 1, 2, ..) and steps of total workers
for obj in alive_it(
islice(self._iterator, worker_info.id, None, worker_info.num_workers),
title=f"AISIterDataset (Worker {worker_info.id})",
disable=not self._show_progress,
force_tty=False,
):
yield obj.name, obj.get(etl_name=self._etl_name).read_all()
self._reset_iterator()
# Get iterator for current worker and name (if no workers, just entire iter)
worker_iter, worker_name = self._get_worker_iter_info()

# For object, yield name and content
for obj in alive_it(
worker_iter,
title="AISIterDataset" + worker_name,
disable=not self._show_progress,
force_tty=worker_name == "",
):
yield obj.name, obj.get(etl_name=self._etl_name).read_all()
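A hedged usage sketch of `AISIterDataset` with a multi-worker `DataLoader`; the endpoint and bucket name below are placeholders, not part of this change:

```python
from torch.utils.data import DataLoader

from aistore.sdk import Client
from aistore.pytorch.iter_dataset import AISIterDataset

# Placeholder cluster endpoint and bucket name -- adjust for your environment.
client = Client("http://localhost:8080")
dataset = AISIterDataset(ais_source_list=client.bucket("my-bucket"))

# With num_workers > 1, each worker streams a disjoint, strided slice of the objects.
loader = DataLoader(dataset, batch_size=32, num_workers=4)
for names, contents in loader:
    # Each batch carries the object names and raw bytes for batch_size objects.
    pass
```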
135 changes: 75 additions & 60 deletions python/aistore/pytorch/shard_reader.py
@@ -7,14 +7,12 @@
"""

from aistore.sdk.bucket import Bucket
from typing import Dict, Iterator, List, Union, Iterable
from aistore.sdk.list_object_flag import ListObjectFlag
from aistore.pytorch.utils import get_basename
from aistore.sdk.types import ArchiveSettings
from typing import Dict, Iterator, List, Union
from aistore.pytorch.utils import get_basename, get_extension
from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
from alive_progress import alive_it
from torch.utils.data import get_worker_info
from itertools import islice
from io import BytesIO
from tarfile import open, TarError


class AISShardReader(AISBaseIterDataset):
@@ -43,66 +41,83 @@ def __init__(
super().__init__(bucket_list, prefix_map)
self._etl_name = etl_name
self._show_progress = show_progress
self._observed_keys = set()

def _get_sample_iter_from_source(self, source: Bucket, prefix: str) -> Iterable:
class ZeroDict(dict):
"""
When `collate_fn` is called while using ShardReader with a dataloader,
the content dictionaries for each sample are merged into a single dictionary
with file extensions as keys and lists of contents as values. This means,
however, that each sample must have a value for that file extension in the batch
at iteration time or else collation will fail. To avoid forcing the user to
pass in a custom collation function, we workaround the default implementation
of collation.
As such, we define a dictionary that has a default value of `b""` (zero bytes)
for every key that we have seen so far. We cannot use None as collation
does not accept None. Initially, when we open a shard tar, we collect every file type
(pre-processing pass) from its members and cache those. Then, we read the shard files.
Lastly, before yielding the sample, we wrap its content dictionary with this custom dictionary
to insert any keys that it does not contain, hence ensuring consistent keys across
samples.
NOTE: For our use case, `defaultdict` does not work due to needing
a `lambda` which cannot be pickled in multithreaded contexts.
"""
Creates an iterable for all samples and contents over each shard from a bucket.

Args:
name (str): Name of shard object
source (Bucket): Bucket where the shard object is stored
prefix (str): Prefix of objects in bucket which are shards
def __init__(self, dict, keys):
super().__init__(dict)
for key in keys:
if key not in self:
self[key] = b""
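A small self-contained sketch, using made-up sample dictionaries, of why the default collation needs uniform keys and how a `ZeroDict`-style fill-in keeps it from failing:

```python
# Hypothetical samples from one batch; sample_a has no "jpg" member.
observed_keys = {"cls", "png", "jpg"}
sample_a = {"cls": b"0", "png": b"<png bytes>"}
sample_b = {"cls": b"1", "png": b"<png bytes>", "jpg": b"<jpg bytes>"}

def fill_missing(content, keys):
    # Same idea as ZeroDict: default every observed extension to b"".
    return {key: content.get(key, b"") for key in keys}

batch = [fill_missing(s, observed_keys) for s in (sample_a, sample_b)]
# All dicts now share identical keys, so the default collate_fn can merge them
# into a single dict of lists: {"cls": [...], "png": [...], "jpg": [...]}.
```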

Returns:
Iterable[Tuple[str, dict(str, bytes)]]: Iterator over all the WDS basenames and content (file extension, data)
in shards from the given shard
"""
for entry in source.list_all_objects_iter(prefix=prefix):

# get iterator of all objects in the shard
objects_iter = source.list_objects_iter(
prefix=entry.name, props="name", flags=[ListObjectFlag.ARCH_DIR]
)

# pool all files with the same basename into dictionary (basename, [file names])
samples_dict = {}
for obj in objects_iter:
basename = get_basename(obj.name)

# Original tar is included in basenames so only yield actual files
if basename != entry.name.split(".")[0]:
if basename not in samples_dict:
samples_dict[basename] = []
samples_dict[basename].append(obj.name)

# if workers are present, then slice iterator
worker_info = get_worker_info()
if worker_info is None or worker_info.num_workers == 1:
worker_items = samples_dict.items()
worker_name = ""
else:
worker_items = islice(
samples_dict.items(), worker_info.id, None, worker_info.num_workers
def _read_samples_from_shards(self, shard_content) -> Dict:
sample_dict = {}

file = BytesIO(shard_content)

try:
# Open the shard as a tarfile and read samples into dict
with open(fileobj=file, mode="r:") as tar:

# Preprocess every key in the archive to ensure consistency in batch collation
self._observed_keys.update(
[get_extension(name) for name in tar.getnames()]
)
worker_name = f" (Worker {worker_info.id})"

# for each basename, get the byte data for each file and yield in dictionary
shard = source.object(entry.name)
for basename, files in alive_it(
worker_items,
title=entry.name + worker_name,
disable=not self._show_progress,
force_tty=worker_info is None or worker_info.num_workers == 1,
):
content_dict = {}
for file_name in files:
file_prefix = file_name.split(".")[-1]
content_dict[file_prefix] = shard.get(
etl_name=self._etl_name,
archive_settings=ArchiveSettings(archpath=file_name),
).read_all()
yield basename, content_dict
for member in tar.getmembers():
if member.isfile():
file_basename = get_basename(member.name)
file_extension = get_extension(member.name)
if file_basename not in sample_dict:
sample_dict[file_basename] = {}
sample_dict[file_basename][file_extension] = tar.extractfile(
member
).read()
except TarError as e:
raise TarError(f"<{self.__class__.__name__}> Error opening tar file: {e}")

return sample_dict
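To make the returned structure concrete, here is a standalone sketch of the same grouping applied to a small in-memory tar; member names and payloads are illustrative:

```python
import tarfile
from io import BytesIO

# Build a tiny in-memory shard with two samples (illustrative data).
buffer = BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar:
    for member_name in ("sample_1.cls", "sample_1.png", "sample_2.cls"):
        data = b"payload"
        info = tarfile.TarInfo(name=member_name)
        info.size = len(data)
        tar.addfile(info, BytesIO(data))
buffer.seek(0)

# Group members by basename, keyed by extension -- the shape the reader yields.
samples = {}
with tarfile.open(fileobj=buffer, mode="r:") as tar:
    for member in tar.getmembers():
        if member.isfile():
            basename, extension = member.name.rsplit(".", 1)
            samples.setdefault(basename, {})[extension] = tar.extractfile(member).read()

print(samples)
# {'sample_1': {'cls': b'payload', 'png': b'payload'}, 'sample_2': {'cls': b'payload'}}
```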

def __iter__(self) -> Iterator:
self._reset_iterator()
yield from self._iterator

# Get iterator for current worker and name (if no workers, just entire iter)
worker_iter, worker_name = self._get_worker_iter_info()

# Read shard, get samples, and yield
for shard in worker_iter:
shard_content = shard.get(
etl_name=self._etl_name,
).read_all()

sample_dict = self._read_samples_from_shards(shard_content)

for basename, content_dict in alive_it(
sample_dict.items(),
title=shard.name + worker_name,
disable=not self._show_progress,
force_tty=worker_name == "",
):
self._length += 1
yield basename, self.ZeroDict(content_dict, self._observed_keys)
14 changes: 14 additions & 0 deletions python/aistore/pytorch/utils.py
@@ -67,6 +67,20 @@ def get_basename(name: str) -> str:
return name.split("/")[-1].split(".")[0]


def get_extension(name: str) -> str:
"""
Get the file extension of the object by stripping any basename or prefix.
Args:
name (str): Complete object name
Returns:
str: File extension of the object
"""

return name.split(".")[1]
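A quick usage sketch of the two helpers on a hypothetical member name:

```python
from aistore.pytorch.utils import get_basename, get_extension

# Hypothetical shard member name.
name = "samples/sample_1.png"

print(get_basename(name))   # -> "sample_1"
print(get_extension(name))  # -> "png"
```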


def parse_url(url: str) -> Tuple[str, str, str]:
"""
Wrapper of sdk/utils.py parse_url. Parse AIS URLs for bucket and object names.
7 changes: 4 additions & 3 deletions python/examples/aisio-pytorch/shard_reader_example.ipynb
@@ -6,7 +6,7 @@
"source": [
"# PyTorch: Using `ShardReader` to read WebDataset formatted Shards\n",
"\n",
"The `ShardReader` class can be used to read WebDataset formatted shards from buckets and objects through URLs or by passing in buckets directly. The `ShardReader` class will yield an iterator contain a tuple with the sample basename and a sample content dictionary. This dictionary is keyed by file extension (e.g \"png\") and has values containing the contents of the associated file in bytes. So, given a shard with a sample in it containing a \"cls\" and \"png\" file, you can read the shard using `ShardReader` and then access the sample and it's contents directly by iterating through the `ShardReader` instance. And all of this is done through the AIStore cluster, the `ShardReader` class must only send requests to fetch the data! So there is no tar decoding happening on the client side!"
"The `ShardReader` class can be used to read WebDataset formatted shards from buckets and objects through URLs or by passing in buckets directly. The `ShardReader` class will yield an iterator contain a tuple with the sample basename and a sample content dictionary. This dictionary is keyed by file extension (e.g \"png\") and has values containing the contents of the associated file in bytes. So, given a shard with a sample in it containing a \"cls\" and \"png\" file, you can read the shard using `ShardReader` and then access the sample and it's contents directly by iterating through the `ShardReader` instance."
]
},
{
@@ -106,9 +106,10 @@
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"loader = DataLoader(shard_reader, batch_size=60)\n",
"loader = DataLoader(shard_reader, batch_size=60, num_workers=4)\n",
"\n",
"for basename, content_dict in loader:\n",
"# basenames, content_dicts have size batch_size each\n",
"for basenames, content_dicts in loader:\n",
" print(basename, list(content_dict.keys()))"
]
}
46 changes: 38 additions & 8 deletions python/tests/unit/pytorch/test_datasets.py
@@ -11,6 +11,8 @@
from aistore.pytorch.multishard_dataset import AISMultiShardStream
from aistore.pytorch.shard_reader import AISShardReader
from aistore.sdk import Bucket
from tarfile import open, TarInfo
from io import BytesIO


class TestAISDataset(unittest.TestCase):
@@ -93,21 +95,49 @@ def test_shard_reader(self):
self.patcher = patch("aistore.pytorch.AISShardReader._create_samples_iter")
mock_create_samples_iter = self.patcher.start()

mock_create_samples_iter.return_value = [
("sample_1", {"cls": b"Content of class"}),
("sample_2", {"png": b"Content of class"}),
("sample_3", {"jpg": b"Content of class"}),
]
tar_buffer = BytesIO()
# Open the tar file in write mode
with open(fileobj=tar_buffer, mode="w") as tar:
# Create some dummy content
content = b"Content of class"

# Create a TarInfo object to create samples
tarinfo = TarInfo(name="sample_1.cls")
tarinfo.size = len(content)
tar.addfile(tarinfo, BytesIO(content))
tarinfo = TarInfo(name="sample_1.png")
tarinfo.size = len(content)
tar.addfile(tarinfo, BytesIO(content))
tarinfo = TarInfo(name="sample_1.jpg")
tarinfo.size = len(content)
tar.addfile(tarinfo, BytesIO(content))

tar_buffer.seek(0)

mock_shard = Mock()
mock_shard.name = "test_shard.tar"

mock_get = Mock()
mock_shard.get.return_value = mock_get

mock_get.read_all.return_value = tar_buffer.getvalue()

mock_create_samples_iter.return_value = [mock_shard]

# Create shard reader and get results and compare
shard_reader = AISShardReader(bucket_list=self.mock_bck)

result = list(shard_reader)

expected_result = [
("sample_1", {"cls": b"Content of class"}),
("sample_2", {"png": b"Content of class"}),
("sample_3", {"jpg": b"Content of class"}),
(
"sample_1",
{
"cls": b"Content of class",
"png": b"Content of class",
"jpg": b"Content of class",
},
),
]

self.assertEqual(result, expected_result)