
Commit

python/pytorch: Improve error handling for datasets and remove unused Client

Signed-off-by: Soham Manoli <[email protected]>
msoham123 committed Jul 2, 2024
1 parent cef15ae commit 8572b8e
Showing 7 changed files with 23 additions and 51 deletions.
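In effect, the commit drops client_url from every dataset constructor: the datasets no longer create their own Client internally, presumably because each AISSource or Bucket passed in already carries its own client binding. A minimal before/after sketch of a call site (endpoint and bucket are hypothetical):

    # before: dataset = AISMapDataset(client_url="http://localhost:8080", ais_source_list=[bucket])
    # after:
    dataset = AISMapDataset(ais_source_list=[bucket])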
8 changes: 3 additions & 5 deletions python/aistore/pytorch/base_iter_dataset.py
@@ -6,7 +6,6 @@

 from typing import List, Union, Iterable, Dict, Iterator
 from aistore.sdk.ais_source import AISSource
-from aistore.sdk import Client
 from torch.utils.data import IterableDataset
 from abc import ABC, abstractmethod

@@ -19,21 +18,20 @@ class AISBaseIterDataset(ABC, IterableDataset):
     to modify the behavior of loading samples from a source, override :meth:`_get_sample_iter_from_source`.

     Args:
-        client_url (str): AIS endpoint URL
         ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
         prefix_map (Dict(AISSource, List[str]), optional): Map of AISSource objects to list of prefixes that only allows
             objects with the specified prefixes to be used from each source
     """

     def __init__(
         self,
-        client_url: str,
         ais_source_list: Union[AISSource, List[AISSource]],
         prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
     ) -> None:
         if not ais_source_list:
-            raise ValueError("ais_source_list must be provided")
-        self._client = Client(client_url)
+            raise ValueError(
+                f"<{self.__class__.__name__}> ais_source_list must be provided"
+            )
         self._ais_source_list = (
             [ais_source_list]
             if isinstance(ais_source_list, AISSource)
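Note that the rewritten ValueError interpolates the concrete class name, so each subclass reports itself; a minimal sketch of the failure path (an empty list is falsy, hence rejected):

    from aistore.pytorch import AISIterDataset

    try:
        AISIterDataset(ais_source_list=[])  # no sources supplied
    except ValueError as err:
        print(err)  # <AISIterDataset> ais_source_list must be provided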
8 changes: 3 additions & 5 deletions python/aistore/pytorch/base_map_dataset.py
@@ -6,7 +6,6 @@

 from typing import List, Union, Dict
 from aistore.sdk.ais_source import AISSource
-from aistore.sdk import Client
 from aistore.sdk.object import Object
 from torch.utils.data import Dataset
 from abc import ABC, abstractmethod

@@ -20,21 +19,20 @@ class AISBaseMapDataset(ABC, Dataset):
     to modify the behavior of loading samples from a source, override :meth:`_get_sample_list_from_source`.

     Args:
-        client_url (str): AIS endpoint URL
         ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
         prefix_map (Dict(AISSource, List[str]), optional): Map of AISSource objects to list of prefixes that only allows
             objects with the specified prefixes to be used from each source
     """

     def __init__(
         self,
-        client_url: str,
         ais_source_list: Union[AISSource, List[AISSource]],
         prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
     ) -> None:
         if not ais_source_list:
-            raise ValueError("ais_source_list must be provided")
-        self._client = Client(client_url)
+            raise ValueError(
+                f"<{self.__class__.__name__}> ais_source_list must be provided"
+            )
         self._ais_source_list = (
             [ais_source_list]
             if isinstance(ais_source_list, AISSource)
4 changes: 1 addition & 3 deletions python/aistore/pytorch/iter_dataset.py
@@ -15,7 +15,6 @@ class AISIterDataset(AISBaseIterDataset):
     If `etl_name` is provided, that ETL must already exist on the AIStore cluster.

     Args:
-        client_url (str): AIS endpoint URL
         ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
         prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to list of prefixes that only allows
             objects with the specified prefixes to be used from each source

@@ -27,12 +26,11 @@

     def __init__(
         self,
-        client_url: str,
         ais_source_list: Union[AISSource, List[AISSource]],
         prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
         etl_name: str = None,
     ):
-        super().__init__(client_url, ais_source_list, prefix_map)
+        super().__init__(ais_source_list, prefix_map)
         self._etl_name = etl_name
         self._length = None
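AISIterDataset remains a torch IterableDataset underneath, so it still feeds a DataLoader directly; a usage sketch under the new signature (bucket is hypothetical; iterable datasets take no sampler or shuffle):

    from torch.utils.data import DataLoader
    from aistore.pytorch import AISIterDataset

    dataset = AISIterDataset(ais_source_list=bucket)  # etl_name defaults to None
    loader = DataLoader(dataset, batch_size=2)
    for names, contents in loader:
        pass  # each batch pairs object names with raw object bytes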
15 changes: 9 additions & 6 deletions python/aistore/pytorch/map_dataset.py
@@ -15,7 +15,6 @@ class AISMapDataset(AISBaseMapDataset):
     If `etl_name` is provided, that ETL must already exist on the AIStore cluster.

     Args:
-        client_url (str): AIS endpoint URL
         ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
         prefix_map (Dict(AISSource, List[str]), optional): Map of AISSource objects to list of prefixes that only allows
             objects with the specified prefixes to be used from each source

@@ -27,18 +26,22 @@

     def __init__(
         self,
-        client_url: str,
         ais_source_list: Union[AISSource, List[AISSource]] = [],
         prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
         etl_name: str = None,
     ):
-        super().__init__(client_url, ais_source_list, prefix_map)
+        super().__init__(ais_source_list, prefix_map)
         self._etl_name = etl_name

     def __len__(self):
         return len(self._samples)

     def __getitem__(self, index: int):
-        obj = self._samples[index]
-        content = obj.get(etl_name=self._etl_name).read_all()
-        return obj.name, content
+        try:
+            obj = self._samples[index]
+            content = obj.get(etl_name=self._etl_name).read_all()
+            return obj.name, content
+        except IndexError:
+            raise IndexError(
+                f"<{self.__class__.__name__}> index must be in bounds of dataset"
+            )
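The wrapper turns the raw list IndexError into one naming the dataset class, while in-bounds lookups behave as before; a sketch of both paths (dataset contents hypothetical):

    name, content = dataset[0]  # (object name, raw bytes), transformed when etl_name is set
    dataset[len(dataset)]       # raises IndexError: <AISMapDataset> index must be in bounds of dataset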
4 changes: 1 addition & 3 deletions python/aistore/pytorch/shard_reader.py
@@ -19,7 +19,6 @@ class AISShardReader(AISBaseIterDataset):
     An iterable-style dataset that iterates over objects stored as Webdataset shards.

     Args:
-        client_url (str): AIS endpoint URL
         bucket_list (Union[Bucket, List[Bucket]]): Single or list of Bucket objects to load data
         prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of Bucket objects to list of prefixes that only allows
             objects with the specified prefixes to be used from each source

@@ -32,12 +31,11 @@

     def __init__(
         self,
-        client_url: str,
         bucket_list: Union[Bucket, List[Bucket]],
         prefix_map: Dict[Bucket, Union[str, List[str]]] = {},
         etl_name: str = None,
     ):
-        super().__init__(client_url, bucket_list, prefix_map)
+        super().__init__(bucket_list, prefix_map)
         self._etl_name = etl_name
         self._length = None
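Construction now mirrors the other datasets; a sketch modeled on the integration test below (bucket and shard name hypothetical):

    reader = AISShardReader(bucket_list=[bucket], prefix_map={bucket: "shard_1.tar"})
    for basename, content_dict in reader:
        pass  # one sample per shard basename; content_dict presumably keys member files by extension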
13 changes: 3 additions & 10 deletions python/tests/integration/pytorch/test_pytorch_plugin.py
@@ -114,9 +114,7 @@ def test_ais_dataset(self):
             )
             content_dict[i] = content

-        ais_dataset = AISMapDataset(
-            client_url=CLUSTER_ENDPOINT, ais_source_list=[self.bck]
-        )
+        ais_dataset = AISMapDataset(ais_source_list=[self.bck])
         self.assertEqual(len(ais_dataset), num_objs)
         for i in range(num_objs):
             obj_name, content = ais_dataset[i]

@@ -132,9 +130,7 @@ def test_ais_iter_dataset(self):
             )
             content_dict[i] = content

-        ais_iter_dataset = AISIterDataset(
-            client_url=CLUSTER_ENDPOINT, ais_source_list=self.bck
-        )
+        ais_iter_dataset = AISIterDataset(ais_source_list=self.bck)
         self.assertEqual(len(ais_iter_dataset), num_objs)
         for i, (obj_name, content) in enumerate(ais_iter_dataset):
             self.assertEqual(obj_name, f"temp/obj{ i }")

@@ -247,7 +243,6 @@ def test_shard_reader(self):
         # Test shard_reader with prefixes

         url_shard_reader = AISShardReader(
-            client_url=CLUSTER_ENDPOINT,
             bucket_list=[bucket],
             prefix_map={bucket: "shard_1.tar"},
         )

@@ -257,9 +252,7 @@
             self.assertEqual(content_dict, expected_sample_dicts[i])

         # Test shard_reader with bucket_params
-        bck_shard_reader = AISShardReader(
-            client_url=CLUSTER_ENDPOINT, bucket_list=[bucket]
-        )
+        bck_shard_reader = AISShardReader(bucket_list=[bucket])

         for i, (basename, content_dict) in enumerate(bck_shard_reader):
             self.assertEqual(basename, sample_basenames[i])
22 changes: 3 additions & 19 deletions python/tests/unit/pytorch/test_datasets.py
@@ -32,39 +32,25 @@ def setUp(self) -> None:
             "aistore.pytorch.base_map_dataset.AISBaseMapDataset._create_samples_list",
             return_value=self.mock_objects,
         )
-        self.patcher_client_map = patch(
-            "aistore.pytorch.base_map_dataset.Client", return_value=self.mock_client
-        )
-        self.patcher_client_iter = patch(
-            "aistore.pytorch.base_iter_dataset.Client", return_value=self.mock_client
-        )
         self.patcher_get_objects_iterator.start()
         self.patcher_get_objects.start()
-        self.patcher_client_map.start()
-        self.patcher_client_iter.start()

     def tearDown(self) -> None:
         self.patcher_get_objects_iterator.stop()
         self.patcher_get_objects.stop()
-        self.patcher_client_map.stop()
-        self.patcher_client_iter.stop()

     def test_map_dataset(self):
         self.mock_bck.list_all_objects_iter.return_value = iter(self.mock_objects)

-        ais_dataset = AISMapDataset(
-            client_url="mock_client_url", ais_source_list=self.mock_bck
-        )
+        ais_dataset = AISMapDataset(ais_source_list=self.mock_bck)

         self.assertIsNone(ais_dataset._etl_name)

         self.assertEqual(len(ais_dataset), 2)
         self.assertEqual(ais_dataset[0][1], b"mock data")

     def test_iter_dataset(self):
-        ais_iter_dataset = AISIterDataset(
-            client_url="mock_client_url", ais_source_list=self.mock_bck
-        )
+        ais_iter_dataset = AISIterDataset(ais_source_list=self.mock_bck)
         self.assertIsNone(ais_iter_dataset._etl_name)

         self.assertEqual(len(ais_iter_dataset), 2)

@@ -115,9 +101,7 @@ def test_shard_reader(self):
         ]

         # Create shard reader and get results and compare
-        shard_reader = AISShardReader(
-            client_url="http://example.com", bucket_list=self.mock_bck
-        )
+        shard_reader = AISShardReader(bucket_list=self.mock_bck)

         result = list(shard_reader)
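With no Client constructed inside the datasets, the two Client patchers had nothing left to intercept; the surviving patchers stub only the object-listing helpers the constructors still call.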
