python/pytorch: Implement WebDataset shard reader and tests
Signed-off-by: Soham Manoli <[email protected]>
Signed-off-by: Soumyendra Shrivastava <[email protected]>

Co-authored-by: Soham Manoli <[email protected]>
Co-authored-by: Soumyendra Shrivastava <[email protected]>
msoham123 and soumyendra98 committed Jun 14, 2024
1 parent 7896e26 commit 1be385b
Showing 5 changed files with 312 additions and 13 deletions.
1 change: 1 addition & 0 deletions python/aistore/pytorch/__init__.py
@@ -7,3 +7,4 @@
from aistore.pytorch.dataset import AISDataset
from aistore.pytorch.multishard_dataset import AISMultiShardStream
from aistore.pytorch.iter_dataset import AISIterDataset
from aistore.pytorch.shard_reader import AISShardReader
77 changes: 77 additions & 0 deletions python/aistore/pytorch/shard_reader.py
@@ -0,0 +1,77 @@
"""
AIS Shard Reader for PyTorch
PyTorch Dataset and DataLoader for AIS.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from aistore.sdk.bucket import Bucket
from torch.utils.data import IterableDataset
from typing import Iterator, List, Optional, Union
from aistore.pytorch.utils import list_wds_samples_iter
from aistore.sdk import Client


class AISShardReader(IterableDataset):
"""
An iterable-style dataset that iterates over objects stored as WebDataset shards.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): A single URL or list of URLs; each may point to a bucket and/or an object
bucket_list (Union[Bucket, List[Bucket]]): A single Bucket or list of Bucket objects to load data from
etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
Yields:
Tuple[str, Dict[str, bytes]]: Each item is a tuple where the first element is the basename of the sample
and the second element is a dictionary mapping each file's extension to its contents as bytes.
"""

def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]] = [],
bucket_list: Union[Bucket, List[Bucket]] = [],
etl_name: Optional[str] = None,
):
if not urls_list and not bucket_list:
raise ValueError(
"At least one of urls_list or bucket_list must be provided"
)

self.client = Client(client_url)
self.urls_list = [urls_list] if isinstance(urls_list, str) else urls_list
self.bucket_list = (
[bucket_list] if isinstance(bucket_list, Bucket) else bucket_list
)
self.etl_name = etl_name
self.length = None
self._reset_iterator()

def __iter__(self) -> Iterator:
self._reset_iterator()
self.length = 0
for basename, content_dict in self._samples_iter:
self.length += 1
yield basename, content_dict

def _reset_iterator(self):
"""
Reset the iterator to start from the beginning
"""
self._samples_iter = list_wds_samples_iter(
client=self.client,
urls_list=self.urls_list,
bucket_list=self.bucket_list,
etl_name=self.etl_name,
)

def __len__(self):
if self.length is None:
self._reset_iterator()
self.length = self._calculate_len()
return self.length

def _calculate_len(self):
return sum(1 for _ in self._samples_iter)
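
For illustration (not part of this commit), a minimal sketch of how the new reader plugs into a PyTorch DataLoader; the endpoint and shard URL below are placeholders:

from torch.utils.data import DataLoader
from aistore.pytorch import AISShardReader

# Hypothetical endpoint and shard URL -- substitute your own cluster/bucket.
shard_reader = AISShardReader(
    client_url="http://localhost:8080",
    urls_list="ais://my-bucket/shard_1.tar",
)

# batch_size=None streams one (basename, content_dict) sample at a time.
loader = DataLoader(shard_reader, batch_size=None)

for basename, content_dict in loader:
    image_bytes = content_dict["jpg"]  # raw bytes, keyed by file extension
    label_bytes = content_dict["cls"]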
104 changes: 96 additions & 8 deletions python/aistore/pytorch/utils.py
@@ -4,7 +4,7 @@
Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
"""

from typing import List, Iterable
from typing import Dict, Iterable, List, Tuple
from urllib.parse import urlunparse
from aistore.sdk import Client
from aistore.sdk.ais_source import AISSource
@@ -82,7 +82,7 @@ def list_objects_iterator(

def list_shard_objects_iterator(
bucket: Bucket, prefix: str = "", etl_name: str = ""
) -> Iterable[Object]:
) -> Iterable[bytes]:
"""
Create an iterable over all the objects in the given shards.
@@ -104,10 +104,98 @@ def list_shard_objects_iterator(
)

for obj in objects_iter:
if obj.name == path:
continue
obj_name = obj.name.replace(f"{path}/", "", 1)
yield bucket.object(path).get(
etl_name=etl_name,
archive_settings=ArchiveSettings(archpath=obj_name),
if obj.name != path:
obj_name = obj.name.replace(f"{path}/", "", 1)
yield bucket.object(path).get(
etl_name=etl_name,
archive_settings=ArchiveSettings(archpath=obj_name),
).read_all()
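
# Sketch (mine, not the diff's): after this change the iterator yields raw
# bytes via read_all() instead of Object handles, so callers consume file
# contents directly. The endpoint and bucket are placeholders, and I assume
# `prefix` names the shard to read.
from aistore.sdk import Client
from aistore.pytorch.utils import list_shard_objects_iterator

bucket = Client("http://localhost:8080").bucket("my-bucket")
for file_bytes in list_shard_objects_iterator(bucket, prefix="shard_1.tar"):
    print(len(file_bytes))  # each item is one archived file's full contents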


def get_basename(name: str) -> str:
"""
Get the basename of the object name by stripping any directory information and suffix.
Args:
name (str): Complete object name
Returns:
str: Basename of the object
"""

return name.split("/")[-1].split(".")[0]
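
# Illustration (not part of the commit): get_basename strips directories and
# everything after the FIRST dot, so multi-dot names collapse to the leading
# segment.
from aistore.pytorch.utils import get_basename

assert get_basename("shards/train/sample_1.jpg") == "sample_1"
assert get_basename("sample_1.aug.jpg") == "sample_1"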


def __samples_from_bck_iter(shard_name: str, bucket: Bucket, etl_name: str):
"""
Helper function to create an iterator over all samples and their contents in the given shard.
Args:
shard_name (str): Name of the shard object
bucket (Bucket): Bucket where the shard object is stored
etl_name (str): Name of ETL (Extract, Transform, Load) to apply to each object
Returns:
Iterable[Tuple[str, Dict[str, bytes]]]: Iterator over all the WDS sample basenames and contents
(file extension -> data) in the given shard
"""
# get iterator of all objects in the shard
objects_iter = bucket.list_objects_iter(
prefix=shard_name, props="name", flags=[ListObjectFlag.ARCH_DIR]
)

# pool all files with the same basename into dictionary (basename, [file names])
samples_dict = {}
for obj in objects_iter:
basename = get_basename(obj.name)

# Original tar is included in basenames so only yield actual files
if basename != shard_name.split(".")[0]:
if basename not in samples_dict:
samples_dict[basename] = []
samples_dict[basename].append(obj.name)

# for each basename, get the byte data for each file and yield in dictionary
shard = bucket.object(shard_name)
for basename, files in samples_dict.items():
content_dict = {}
for file_name in files:
file_extension = file_name.split(".")[-1]
content_dict[file_extension] = shard.get(
etl_name=etl_name, archive_settings=ArchiveSettings(archpath=file_name)
).read_all()
yield basename, content_dict
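
# Illustration (not part of the commit): for a shard "shard_1.tar" holding
# sample_1.cls, sample_1.jpg, and sample_1.png, one iteration of
# __samples_from_bck_iter yields:
#   ("sample_1", {"cls": b"...", "jpg": b"...", "png": b"..."})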


def list_wds_samples_iter(
client: Client,
urls_list: List[str],
bucket_list: List[Bucket],
etl_name: str,
) -> Iterable[Tuple[str, Dict[str, bytes]]]:
"""
Create an iterator over all of the shard sample basenames and sample contents.
Args:
client (Client): AIStore Client for accessing buckets and objects
urls_list (List[str]): List of URLs; each may point to a bucket and/or an object
bucket_list (List[Bucket]): List of Bucket objects containing the shards to load data from
etl_name (str): Name of ETL (Extract, Transform, Load) to apply to each object
Returns:
Iterable[Tuple[str, Dict[str, bytes]]]: Iterator over all the WDS sample basenames and contents
(file extension -> data) in shards from the given URLs and buckets
"""

for item in urls_list:
provider, bck_name, path = parse_url(item)
bucket = client.bucket(bck_name=bck_name, provider=provider)
if not path:
for shard in bucket.list_objects_iter():
yield from __samples_from_bck_iter(shard.name, bucket, etl_name)
else:
yield from __samples_from_bck_iter(path, bucket, etl_name)

for bucket in bucket_list:
for shard in bucket.list_objects_iter():
yield from __samples_from_bck_iter(shard.name, bucket, etl_name)
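
For completeness, a direct-use sketch of list_wds_samples_iter (not from the commit; the endpoint and bucket name are placeholders) showing its two input modes:

from aistore.sdk import Client
from aistore.pytorch.utils import list_wds_samples_iter

client = Client("http://localhost:8080")
bucket = client.bucket("my-bucket")  # bucket of WebDataset shards

# Bucket mode: iterate every shard in each listed bucket; URL mode would
# instead pass e.g. urls_list=["ais://my-bucket/shard_1.tar"].
for basename, content_dict in list_wds_samples_iter(
    client=client, urls_list=[], bucket_list=[bucket], etl_name=None
):
    print(basename, sorted(content_dict))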
86 changes: 83 additions & 3 deletions python/tests/integration/pytorch/test_pytorch_plugin.py
@@ -1,13 +1,12 @@
"""
Test class for AIStore PyTorch Plugin
Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
"""

import unittest
from pathlib import Path
import torchdata.datapipes.iter as torch_pipes

from aistore.sdk import Client
from aistore.sdk import Client, Bucket
from aistore.sdk.errors import AISError, ErrBckNotFound
from aistore.sdk.dataset.data_shard import DataShard
from aistore.pytorch import (
@@ -17,6 +16,7 @@
AISIterDataset,
AISMultiShardStream,
)
from aistore.pytorch.shard_reader import AISShardReader
from tests.integration import CLUSTER_ENDPOINT
from tests.utils import (
create_and_put_object,
@@ -183,6 +183,86 @@ def test_multishard_stream(self):
for i, content in enumerate(dataset):
self.assertEqual(content, combined_content[i])

def test_shard_reader(self):

self.local_test_files.mkdir()

bucket: Bucket = self.client.bucket(self.bck_name)

shard_one_dict = {
"sample_1.cls": b"Class content of sample one",
"sample_1.jpg": b"Jpg content of sample one",
"sample_1.png": b"Png content of sample one",
"sample_2.cls": b"Class content of sample two",
"sample_2.jpg": b"Jpg content of sample two",
"sample_2.png": b"Png content of sample two",
}
shard_one_archive_name = "shard_1.tar"
shard_one_archive_path = self.local_test_files.joinpath(shard_one_archive_name)
create_archive(shard_one_archive_path, shard_one_dict)
shard_one_obj = bucket.object(obj_name=shard_one_archive_name)
shard_one_obj.put_file(shard_one_archive_path)

shard_two_dict = {
"sample_3.cls": b"Class content of sample three",
"sample_3.jpg": b"Jpg content of sample three",
"sample_3.png": b"Png content of sample three",
"sample_4.cls": b"Class content of sample four",
"sample_4.jpg": b"Jpg content of sample four",
"sample_4.png": b"Png content of sample four",
}
shard_two_archive_name = "shard_2.tar"
shard_two_archive_path = self.local_test_files.joinpath(shard_two_archive_name)
create_archive(shard_two_archive_path, shard_two_dict)
shard_two_obj = bucket.object(obj_name=shard_two_archive_name)
shard_two_obj.put_file(shard_two_archive_path)

# Expected output from the reader
expected_sample_dicts = [
{
"cls": b"Class content of sample one",
"jpg": b"Jpg content of sample one",
"png": b"Png content of sample one",
},
{
"cls": b"Class content of sample two",
"jpg": b"Jpg content of sample two",
"png": b"Png content of sample two",
},
{
"cls": b"Class content of sample three",
"jpg": b"Jpg content of sample three",
"png": b"Png content of sample three",
},
{
"cls": b"Class content of sample four",
"jpg": b"Jpg content of sample four",
"png": b"Png content of sample four",
},
]

sample_basenames = ["sample_1", "sample_2", "sample_3", "sample_4"]

# Test shard_reader with url params
url_one = f"{bucket.provider}://{bucket.name}/{shard_one_obj.name}"
url_two = f"{bucket.provider}://{bucket.name}/{shard_two_obj.name}"
url_shard_reader = AISShardReader(
client_url=CLUSTER_ENDPOINT, urls_list=[url_one, url_two]
)

for i, (basename, content_dict) in enumerate(url_shard_reader):
self.assertEqual(basename, sample_basenames[i])
self.assertEqual(content_dict, expected_sample_dicts[i])

# Test shard_reader with bucket_params
bck_shard_reader = AISShardReader(
client_url=CLUSTER_ENDPOINT, bucket_list=[bucket]
)

for i, (basename, content_dict) in enumerate(bck_shard_reader):
self.assertEqual(basename, sample_basenames[i])
self.assertEqual(content_dict, expected_sample_dicts[i])


if __name__ == "__main__":
unittest.main()
57 changes: 55 additions & 2 deletions python/tests/unit/pytorch/test_datasets.py
@@ -1,14 +1,21 @@
"""
Test class for AIStore PyTorch Plugin
Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
"""

import unittest
from unittest.mock import patch, Mock, MagicMock
from aistore.pytorch.dataset import AISDataset
from aistore.pytorch.multishard_dataset import AISMultiShardStream
from aistore.pytorch.iter_dataset import AISIterDataset
from aistore.pytorch.multishard_dataset import AISMultiShardStream
from aistore.pytorch.shard_reader import AISShardReader
from aistore.pytorch.utils import list_wds_samples_iter
from aistore.sdk.list_object_flag import ListObjectFlag


class TestAISDataset(unittest.TestCase):
def setUp(self) -> None:
self.mock_client = Mock()
self.mock_bucket = Mock()
mock_obj = Mock()
mock_obj.get.return_value.read_all.return_value = b"mock data"
self.mock_objects = [
@@ -81,3 +88,49 @@ def test_multi_shard_stream(self):
results = list(iter(stream))

self.assertEqual(results, expected_results)

@patch("aistore.sdk.bucket")
def test_list_wds_sample_iter(self, mock_bucket):
# Mock the list_objects_iter method
mock_shard = Mock()
mock_shard.name = "shard.tar"
mock_shard.list_objects_iter.return_value = iter([])
mock_bucket.list_objects_iter.return_value = [mock_shard]

# Call the function under test
list(list_wds_samples_iter(None, [], [mock_bucket], None))

# Assert that list_objects_iter was called with no arguments to list all shards
mock_bucket.list_objects_iter.assert_any_call()

# Assert that the most recent call to list_objects_iter passed the archive listing params
mock_bucket.list_objects_iter.assert_called_with(
prefix=mock_shard.name, props="name", flags=[ListObjectFlag.ARCH_DIR]
)

@patch("aistore.pytorch.shard_reader.list_wds_samples_iter")
def test_shard_reader(self, mock_list_wds_samples_iter):
# Mock list_wds_samples_iter
mock_list_wds_samples_iter.return_value = [
("sample_1", {"cls": b"Content of class"}),
("sample_2", {"png": b"Content of class"}),
("sample_3", {"jpg": b"Content of class"}),
]

# Create shard reader and get results and compare
shard_reader = AISShardReader(
client_url="http://example.com", urls_list="http://example.com/data"
)

result = list(shard_reader)

expected_result = [
("sample_1", {"cls": b"Content of class"}),
("sample_2", {"png": b"Content of class"}),
("sample_3", {"jpg": b"Content of class"}),
]

self.assertEqual(result, expected_result)

# Ensure list_wds_samples_iter was invoked
mock_list_wds_samples_iter.assert_called()
