From cef15ae7cefa7421f1f10e068b50b0cac0cf8b4e Mon Sep 17 00:00:00 2001 From: Soham Manoli Date: Tue, 2 Jul 2024 09:59:07 -0700 Subject: [PATCH] python/pytorch: Refactor datasets and utils Signed-off-by: Soham Manoli --- python/aistore/pytorch/__init__.py | 4 +- python/aistore/pytorch/base_dataset.py | 66 ------- python/aistore/pytorch/base_iter_dataset.py | 94 ++++++++++ python/aistore/pytorch/base_map_dataset.py | 88 ++++++++++ python/aistore/pytorch/iter_dataset.py | 40 ++--- .../pytorch/{dataset.py => map_dataset.py} | 26 ++- python/aistore/pytorch/multishard_dataset.py | 45 ++++- python/aistore/pytorch/shard_reader.py | 99 ++++++----- python/aistore/pytorch/utils.py | 165 +----------------- .../pyaisloader/pytorch_benchmark.py | 5 +- .../pytorch/test_pytorch_plugin.py | 22 +-- python/tests/unit/pytorch/test_datasets.py | 91 +++++----- 12 files changed, 370 insertions(+), 375 deletions(-) delete mode 100644 python/aistore/pytorch/base_dataset.py create mode 100644 python/aistore/pytorch/base_iter_dataset.py create mode 100644 python/aistore/pytorch/base_map_dataset.py rename python/aistore/pytorch/{dataset.py => map_dataset.py} (56%) diff --git a/python/aistore/pytorch/__init__.py b/python/aistore/pytorch/__init__.py index dbc18a3d99..8c4b80634a 100644 --- a/python/aistore/pytorch/__init__.py +++ b/python/aistore/pytorch/__init__.py @@ -4,7 +4,9 @@ AISSourceLister, ) -from aistore.pytorch.dataset import AISDataset +from aistore.pytorch.map_dataset import AISMapDataset from aistore.pytorch.multishard_dataset import AISMultiShardStream from aistore.pytorch.iter_dataset import AISIterDataset from aistore.pytorch.shard_reader import AISShardReader +from aistore.pytorch.base_map_dataset import AISBaseMapDataset +from aistore.pytorch.base_iter_dataset import AISBaseIterDataset diff --git a/python/aistore/pytorch/base_dataset.py b/python/aistore/pytorch/base_dataset.py deleted file mode 100644 index 5821119f17..0000000000 --- a/python/aistore/pytorch/base_dataset.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Base classes for AIS Datasets and Iterable Datasets - -Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -""" - -from typing import List, Union -from aistore.sdk.ais_source import AISSource -from aistore.pytorch.utils import list_objects, list_objects_iterator -from aistore.sdk import Client - - -class AISBaseClass: - """ - A base class for creating AIS Datasets for PyTorch. - - Args: - client_url (str): AIS endpoint URL - urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data - ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data - """ - - def __init__( - self, - client_url: str, - urls_list: Union[str, List[str]], - ais_source_list: Union[AISSource, List[AISSource]], - ) -> None: - self.client = Client(client_url) - if isinstance(urls_list, str): - urls_list = [urls_list] - if isinstance(ais_source_list, AISSource): - ais_source_list = [ais_source_list] - self._objects = list_objects(self.client, urls_list, ais_source_list) - - -class AISBaseClassIter: - """ - A base class for creating AIS Iterable Datasets for PyTorch. 
-
-    Args:
-        client_url (str): AIS endpoint URL
-        urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
-        ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
-    """
-
-    def __init__(
-        self,
-        client_url: str,
-        urls_list: Union[str, List[str]],
-        ais_source_list: Union[AISSource, List[AISSource]],
-    ) -> None:
-        self.client = Client(client_url)
-        if isinstance(urls_list, str):
-            urls_list = [urls_list]
-        if isinstance(ais_source_list, AISSource):
-            ais_source_list = [ais_source_list]
-        self.urls_list = urls_list
-        self.ais_source_list = ais_source_list
-        self._reset_iterator()
-
-    def _reset_iterator(self):
-        """Reset the object iterator to start from the beginning"""
-        self._object_iter = list_objects_iterator(
-            self.client, self.urls_list, self.ais_source_list
-        )
diff --git a/python/aistore/pytorch/base_iter_dataset.py b/python/aistore/pytorch/base_iter_dataset.py
new file mode 100644
index 0000000000..4317ad7a90
--- /dev/null
+++ b/python/aistore/pytorch/base_iter_dataset.py
@@ -0,0 +1,94 @@
+"""
+Base class for AIS Iterable Style Datasets
+
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+"""
+
+from typing import List, Union, Iterable, Dict, Iterator
+from aistore.sdk.ais_source import AISSource
+from aistore.sdk import Client
+from torch.utils.data import IterableDataset
+from abc import ABC, abstractmethod
+
+
+class AISBaseIterDataset(ABC, IterableDataset):
+    """
+    A base class for creating AIS Iterable Datasets. Should not be instantiated directly. Subclasses
+    should implement :meth:`__iter__` which returns the samples from the dataset and can optionally
+    override other methods from torch IterableDataset such as :meth:`__len__`. Additionally,
+    to modify the behavior of loading samples from a source, override :meth:`_get_sample_iter_from_source`.
+
+    Args:
+        client_url (str): AIS endpoint URL
+        ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
+        prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to a prefix or
+            list of prefixes; only objects whose names begin with one of the prefixes are used from that source
+    """
+
+    def __init__(
+        self,
+        client_url: str,
+        ais_source_list: Union[AISSource, List[AISSource]],
+        prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
+    ) -> None:
+        if not ais_source_list:
+            raise ValueError("ais_source_list must be provided")
+        self._client = Client(client_url)
+        self._ais_source_list = (
+            [ais_source_list]
+            if isinstance(ais_source_list, AISSource)
+            else ais_source_list
+        )
+        self._prefix_map = prefix_map
+        self._iterator = None
+        self._reset_iterator()
+
+    def _get_sample_iter_from_source(self, source: AISSource, prefix: str) -> Iterable:
+        """
+        Creates an iterable of samples from the AISSource and the objects stored within. Must be able to handle
+        prefixes as well. The default implementation returns an iterable of Objects. This method can be overridden
+        to provide other functionality (such as reading the data and creating usable samples for different
+        file types).
+
+        Args:
+            source (AISSource): AISSource (:class:`aistore.sdk.ais_source.AISSource`) provides an interface for
+                accessing a list of AIS objects or their URLs
+            prefix (str): Prefix to dictate what objects should be included
+
+        Returns:
+            Iterable: Iterable over the content of the dataset
+        """
+        yield from source.list_all_objects_iter(prefix=prefix)
+
+    def _create_samples_iter(self) -> Iterable:
+        """
+        Create an iterable given the AIS sources and associated prefixes.
+
+        Returns:
+            Iterable: Iterable over the samples of the dataset
+        """
+        for source in self._ais_source_list:
+            if source not in self._prefix_map or self._prefix_map[source] is None:
+                yield from self._get_sample_iter_from_source(source, "")
+            else:
+                prefixes = (
+                    [self._prefix_map[source]]
+                    if isinstance(self._prefix_map[source], str)
+                    else self._prefix_map[source]
+                )
+                for prefix in prefixes:
+                    yield from self._get_sample_iter_from_source(source, prefix)
+
+    @abstractmethod
+    def __iter__(self) -> Iterator:
+        """
+        Return iterator with samples in this dataset.
+
+        Returns:
+            Iterator: Iterator of samples
+        """
+        pass
+
+    def _reset_iterator(self):
+        """Reset the iterator to start from the beginning"""
+        self._iterator = self._create_samples_iter()
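Subclasses of this new base only need to supply `__iter__`; client setup, listing, and prefix handling come from the base class. A minimal sketch under assumed conditions (the endpoint `http://localhost:8080` and bucket name `example-data` are hypothetical):

```python
from aistore.pytorch import AISBaseIterDataset
from aistore.sdk import Client


class NamesOnlyIterDataset(AISBaseIterDataset):
    """Example subclass that yields object names only."""

    def __iter__(self):
        self._reset_iterator()
        for obj in self._iterator:  # Objects from _get_sample_iter_from_source
            yield obj.name


AIS_ENDPOINT = "http://localhost:8080"  # hypothetical endpoint
bck = Client(AIS_ENDPOINT).bucket("example-data")  # hypothetical bucket
for name in NamesOnlyIterDataset(AIS_ENDPOINT, bck):
    print(name)
```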
diff --git a/python/aistore/pytorch/base_map_dataset.py b/python/aistore/pytorch/base_map_dataset.py
new file mode 100644
index 0000000000..1a77d358a8
--- /dev/null
+++ b/python/aistore/pytorch/base_map_dataset.py
@@ -0,0 +1,88 @@
+"""
+Base class for AIS Map Style Datasets
+
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+"""
+
+from typing import List, Union, Dict
+from aistore.sdk.ais_source import AISSource
+from aistore.sdk import Client
+from aistore.sdk.object import Object
+from torch.utils.data import Dataset
+from abc import ABC, abstractmethod
+
+
+class AISBaseMapDataset(ABC, Dataset):
+    """
+    A base class for creating map-style AIS Datasets. Should not be instantiated directly. Subclasses
+    should implement :meth:`__getitem__` which fetches a sample given a key from the dataset and can optionally
+    override other methods from torch Dataset such as :meth:`__len__`. Additionally,
+    to modify the behavior of loading samples from a source, override :meth:`_get_sample_list_from_source`.
+
+    Args:
+        client_url (str): AIS endpoint URL
+        ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
+        prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to a prefix or
+            list of prefixes; only objects whose names begin with one of the prefixes are used from that source
+    """
+
+    def __init__(
+        self,
+        client_url: str,
+        ais_source_list: Union[AISSource, List[AISSource]],
+        prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
+    ) -> None:
+        if not ais_source_list:
+            raise ValueError("ais_source_list must be provided")
+        self._client = Client(client_url)
+        self._ais_source_list = (
+            [ais_source_list]
+            if isinstance(ais_source_list, AISSource)
+            else ais_source_list
+        )
+        self._prefix_map = prefix_map
+        self._samples = self._create_samples_list()
+
+    def _get_sample_list_from_source(self, source: AISSource, prefix: str) -> List:
+        """
+        Creates a list of samples from the AISSource and the objects stored within. Must be able to handle
+        prefixes as well. The default implementation returns a list of Objects. This method can be overridden
+        to provide other functionality (such as reading the data and creating usable samples for different
+        file types).
+
+        Args:
+            source (AISSource): AISSource (:class:`aistore.sdk.ais_source.AISSource`) provides an interface for
+                accessing a list of AIS objects or their URLs
+            prefix (str): Prefix to dictate what objects should be included
+
+        Returns:
+            List: List containing the content of the dataset
+        """
+        return [obj for obj in source.list_all_objects_iter(prefix=prefix)]
+
+    def _create_samples_list(self) -> List[Object]:
+        """
+        Create a list of all the objects in the given AIS sources.
+
+        Returns:
+            List[Object]: List of all the objects in the given AIS sources
+        """
+        samples = []
+
+        for source in self._ais_source_list:
+            if source not in self._prefix_map or self._prefix_map[source] is None:
+                samples.extend(self._get_sample_list_from_source(source, ""))
+            else:
+                prefixes = (
+                    [self._prefix_map[source]]
+                    if isinstance(self._prefix_map[source], str)
+                    else self._prefix_map[source]
+                )
+                for prefix in prefixes:
+                    samples.extend(self._get_sample_list_from_source(source, prefix))
+
+        return samples
+
+    @abstractmethod
+    def __getitem__(self, index):
+        pass
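The map-style base mirrors the iterable one, except that `_samples` is materialized once in the constructor, so `__getitem__` and `__len__` are cheap afterwards. A hedged subclass sketch (endpoint and bucket names are assumptions):

```python
from aistore.pytorch import AISBaseMapDataset
from aistore.sdk import Client


class SizedMapDataset(AISBaseMapDataset):
    """Example subclass returning (name, content-length) pairs."""

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, index: int):
        obj = self._samples[index]
        return obj.name, len(obj.get().read_all())


AIS_ENDPOINT = "http://localhost:8080"  # hypothetical endpoint
bck = Client(AIS_ENDPOINT).bucket("example-data")  # hypothetical bucket
ds = SizedMapDataset(AIS_ENDPOINT, [bck])
print(len(ds), ds[0])
```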
diff --git a/python/aistore/pytorch/iter_dataset.py b/python/aistore/pytorch/iter_dataset.py
index c49447d9a4..8867693aff 100644
--- a/python/aistore/pytorch/iter_dataset.py
+++ b/python/aistore/pytorch/iter_dataset.py
@@ -4,21 +4,21 @@
 Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 """
 
-from aistore.pytorch.base_dataset import AISBaseClassIter
-from torch.utils.data import IterableDataset
-from typing import List, Union
+from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
+from typing import List, Union, Dict
 from aistore.sdk.ais_source import AISSource
 
 
-class AISIterDataset(AISBaseClassIter, IterableDataset):
+class AISIterDataset(AISBaseIterDataset):
     """
     An iterable-style dataset that iterates over objects in AIS.
     If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
 
     Args:
         client_url (str): AIS endpoint URL
-        urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
         ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
+        prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to a prefix or
+            list of prefixes; only objects whose names begin with one of the prefixes are used from that source
         etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
 
     Note:
@@ -28,30 +28,23 @@ class AISIterDataset(AISBaseClassIter, IterableDataset):
     def __init__(
         self,
         client_url: str,
-        urls_list: Union[str, List[str]] = [],
-        ais_source_list: Union[AISSource, List[AISSource]] = [],
+        ais_source_list: Union[AISSource, List[AISSource]],
+        prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
        etl_name: str = None,
     ):
-        if not urls_list and not ais_source_list:
-            raise ValueError(
-                "At least one of urls_list or ais_source_list must be provided."
-            )
-        super().__init__(client_url, urls_list, ais_source_list)
-        self.etl_name = etl_name
-        self.length = None
+        super().__init__(client_url, ais_source_list, prefix_map)
+        self._etl_name = etl_name
+        self._length = None
 
     def __iter__(self):
         self._reset_iterator()
-        self.length = 0
-        for obj in self._object_iter:
-            self.length += 1
-            yield obj.name, obj.get(etl_name=self.etl_name).read_all()
+        self._length = 0
+        for obj in self._iterator:
+            self._length += 1
+            yield obj.name, obj.get(etl_name=self._etl_name).read_all()
 
     def __len__(self):
-        if self.length is None:
+        if self._length is None:
             self._reset_iterator()
-            self.length = self._calculate_len()
-        return self.length
-
-    def _calculate_len(self):
-        return sum(1 for _ in self._object_iter)
+            self._length = sum(1 for _ in self._iterator)
+        return self._length
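The refactor keeps the public iteration contract of `(object_name, object_content)` tuples, so direct consumption could look like this sketch (endpoint and bucket are hypothetical):

```python
from aistore.pytorch import AISIterDataset
from aistore.sdk import Client

AIS_ENDPOINT = "http://localhost:8080"  # hypothetical endpoint
bck = Client(AIS_ENDPOINT).bucket("example-data")  # hypothetical bucket

dataset = AISIterDataset(AIS_ENDPOINT, bck)
for name, data in dataset:
    # each sample is (object name: str, raw object bytes)
    print(name, len(data))
```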
diff --git a/python/aistore/pytorch/dataset.py b/python/aistore/pytorch/map_dataset.py
similarity index 56%
rename from python/aistore/pytorch/dataset.py
rename to python/aistore/pytorch/map_dataset.py
index 843c3454fb..169bca4aad 100644
--- a/python/aistore/pytorch/dataset.py
+++ b/python/aistore/pytorch/map_dataset.py
@@ -4,21 +4,21 @@
 Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 """
 
-from typing import List, Union
-from torch.utils.data import Dataset
+from typing import List, Union, Dict
 from aistore.sdk.ais_source import AISSource
-from aistore.pytorch.base_dataset import AISBaseClass
+from aistore.pytorch.base_map_dataset import AISBaseMapDataset
 
 
-class AISDataset(AISBaseClass, Dataset):
+class AISMapDataset(AISBaseMapDataset):
     """
     A map-style dataset for objects in AIS.
     If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
 
     Args:
         client_url (str): AIS endpoint URL
-        urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
         ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
+        prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to a prefix or
+            list of prefixes; only objects whose names begin with one of the prefixes are used from that source
         etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
 
     Note:
@@ -28,21 +28,17 @@ class AISDataset(AISBaseClass, Dataset):
     def __init__(
         self,
         client_url: str,
-        urls_list: Union[str, List[str]] = [],
         ais_source_list: Union[AISSource, List[AISSource]] = [],
+        prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
         etl_name: str = None,
     ):
-        if not urls_list and not ais_source_list:
-            raise ValueError(
-                "At least one of urls_list or ais_source_list must be provided"
-            )
-        super().__init__(client_url, urls_list, ais_source_list)
-        self.etl_name = etl_name
+        super().__init__(client_url, ais_source_list, prefix_map)
+        self._etl_name = etl_name
 
     def __len__(self):
-        return len(self._objects)
+        return len(self._samples)
 
     def __getitem__(self, index: int):
-        obj = self._objects[index]
-        content = obj.get(etl_name=self.etl_name).read_all()
+        obj = self._samples[index]
+        content = obj.get(etl_name=self._etl_name).read_all()
         return obj.name, content
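A possible call pattern for the new `prefix_map` parameter; sources absent from the map contribute all of their objects (endpoint and bucket names below are hypothetical):

```python
from aistore.pytorch import AISMapDataset
from aistore.sdk import Client

AIS_ENDPOINT = "http://localhost:8080"  # hypothetical endpoint
client = Client(AIS_ENDPOINT)
train_bck = client.bucket("train-data")  # hypothetical buckets
extra_bck = client.bucket("extra-data")

# Only objects under images/ and labels/ are taken from train_bck;
# extra_bck is not in the map, so all of its objects are included.
dataset = AISMapDataset(
    client_url=AIS_ENDPOINT,
    ais_source_list=[train_bck, extra_bck],
    prefix_map={train_bck: ["images/", "labels/"]},
)
name, content = dataset[0]
```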
""" -from torch.utils.data import IterableDataset from aistore.sdk.dataset.data_shard import DataShard -from typing import Iterator, List -from aistore.pytorch.utils import list_shard_objects_iterator +from aistore.sdk import Bucket +from typing import Iterator, List, Iterable +from aistore.sdk.list_object_flag import ListObjectFlag +from aistore.sdk.types import ArchiveSettings +from torch.utils.data import IterableDataset class AISMultiShardStream(IterableDataset): @@ -23,11 +25,42 @@ class AISMultiShardStream(IterableDataset): """ def __init__(self, data_sources: List[DataShard]): - self.data_sources = data_sources + self._data_sources = data_sources def __iter__(self) -> Iterator: data_iterators = ( - list_shard_objects_iterator(ds.bucket, ds.prefix, ds.etl_name) - for ds in self.data_sources + self._get_shard_objects_iterator(ds.bucket, ds.prefix, ds.etl_name) + for ds in self._data_sources ) return zip(*data_iterators) + + def _get_shard_objects_iterator( + self, bucket: Bucket, prefix: str = "", etl_name: str = "" + ) -> Iterable[bytes]: + """ + Create an iterable over all the objects in the given shards. + + Args: + bucket (Bucket): Bucket containing the shards + prefix (str): Prefix of the object names + etl_name (str): ETL name to apply on each object + + Returns: + Iterable[Object]: Iterable over all the objects in the given shards, + with each iteration returning a combined sample + """ + shards_iter = bucket.list_objects_iter(prefix=prefix, props="name") + + for shard in shards_iter: + path = shard.name + objects_iter = bucket.list_objects_iter( + prefix=path, props="name", flags=[ListObjectFlag.ARCH_DIR] + ) + + for obj in objects_iter: + if obj.name != path: + obj_name = obj.name.replace(f"{path}/", "", 1) + yield bucket.object(path).get( + etl_name=etl_name, + archive_settings=ArchiveSettings(archpath=obj_name), + ).read_all() diff --git a/python/aistore/pytorch/shard_reader.py b/python/aistore/pytorch/shard_reader.py index 39d98d5c4d..8980c9f3dc 100644 --- a/python/aistore/pytorch/shard_reader.py +++ b/python/aistore/pytorch/shard_reader.py @@ -7,20 +7,22 @@ """ from aistore.sdk.bucket import Bucket -from torch.utils.data import IterableDataset -from typing import Iterator, List, Union -from aistore.pytorch.utils import list_wds_samples_iter -from aistore.sdk import Client +from typing import Dict, Iterator, List, Union, Iterable +from aistore.sdk.list_object_flag import ListObjectFlag +from aistore.pytorch.utils import get_basename +from aistore.sdk.types import ArchiveSettings +from aistore.pytorch.base_iter_dataset import AISBaseIterDataset -class AISShardReader(IterableDataset): +class AISShardReader(AISBaseIterDataset): """ An iterable-style dataset that iterates over objects stored as Webdataset shards. 
    Args:
         client_url (str): AIS endpoint URL
-        urls_list (Union[str, List[str]]): Single or list of URLs, can be URLS for buckets and/or objects
         bucket_list (Union[Bucket, List[Bucket]]): Single or list of Bucket objects to load data
+        prefix_map (Dict(Bucket, Union[str, List[str]]), optional): Map of Bucket objects to a prefix or
+            list of prefixes; only objects whose names begin with one of the prefixes are used from each bucket
         etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
 
     Yields:
@@ -31,47 +33,64 @@
     def __init__(
         self,
         client_url: str,
-        urls_list: Union[str, List[str]] = [],
-        bucket_list: Union[Bucket, List[Bucket]] = [],
+        bucket_list: Union[Bucket, List[Bucket]],
+        prefix_map: Dict[Bucket, Union[str, List[str]]] = {},
         etl_name: str = None,
     ):
-        if not urls_list and not bucket_list:
-            raise ValueError(
-                "At least one of urls_list or bucket_list must be provided"
+        super().__init__(client_url, bucket_list, prefix_map)
+        self._etl_name = etl_name
+        self._length = None
+
+    def _get_sample_iter_from_source(self, source: Bucket, prefix: str) -> Iterable:
+        """
+        Creates an iterable for all samples and contents over each shard from a bucket.
+
+        Args:
+            source (Bucket): Bucket where the shard objects are stored
+            prefix (str): Prefix of objects in the bucket which are shards
+
+        Returns:
+            Iterable[Tuple[str, dict(str, bytes)]]: Iterator over all the WDS basenames and content (file extension, data)
+                in shards from the given source
+        """
+        for entry in source.list_objects_iter(prefix=prefix):
+            # get iterator of all objects in the shard
+            objects_iter = source.list_objects_iter(
+                prefix=entry.name, props="name", flags=[ListObjectFlag.ARCH_DIR]
             )
-        self.client = Client(client_url)
-        self.urls_list = [urls_list] if isinstance(urls_list, str) else urls_list
-        self.bucket_list = (
-            [bucket_list] if isinstance(bucket_list, Bucket) else bucket_list
-        )
-        self.etl_name = etl_name
-        self.length = None
-        self._reset_iterator()
+            # pool all files with the same basename into dictionary (basename, [file names])
+            samples_dict = {}
+            for obj in objects_iter:
+                basename = get_basename(obj.name)
+
+                # Original tar is included in basenames so only yield actual files
+                if basename != entry.name.split(".")[0]:
+                    if basename not in samples_dict:
+                        samples_dict[basename] = []
+                    samples_dict[basename].append(obj.name)
+
+            # for each basename, get the byte data for each file and yield in dictionary
+            shard = source.object(entry.name)
+            for basename, files in samples_dict.items():
+                content_dict = {}
+                for file_name in files:
+                    file_extension = file_name.split(".")[-1]
+                    content_dict[file_extension] = shard.get(
+                        etl_name=self._etl_name,
+                        archive_settings=ArchiveSettings(archpath=file_name),
+                    ).read_all()
+                self._length += 1
+                yield basename, content_dict
 
     def __iter__(self) -> Iterator:
         self._reset_iterator()
-        self.length = 0
-        for basename, content_dict in self._samples_iter:
-            self.length += 1
-            yield basename, content_dict
-
-    def _reset_iterator(self):
-        """
-        Reset the iterator to start from the beginning
-        """
-        self._samples_iter = list_wds_samples_iter(
-            client=self.client,
-            urls_list=self.urls_list,
-            bucket_list=self.bucket_list,
-            etl_name=self.etl_name,
-        )
+        self._length = 0
+        yield from self._iterator
 
     def __len__(self):
-        if self.length is None:
+        if self._length is None:
+            self._length = 0
             self._reset_iterator()
-            self.length = self._calculate_len()
-        return self.length
-
-    def _calculate_len(self):
-        return sum(1 for _ in self._samples_iter)
+            self._length = sum(1 for _ in self._iterator)
+        return self._length
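A usage sketch for the reader above; `prefix_map` narrows which shards are read from each bucket (endpoint, bucket, and prefix values are hypothetical):

```python
from aistore.pytorch import AISShardReader
from aistore.sdk import Client

AIS_ENDPOINT = "http://localhost:8080"  # hypothetical endpoint
bck = Client(AIS_ENDPOINT).bucket("wds-data")  # hypothetical bucket

# Read only shards whose object names start with "train-"
reader = AISShardReader(
    client_url=AIS_ENDPOINT,
    bucket_list=bck,
    prefix_map={bck: "train-"},
)
for basename, content_dict in reader:
    # e.g. basename "sample_1", content_dict {"jpg": b"...", "cls": b"..."}
    ...
```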
diff --git a/python/aistore/pytorch/utils.py b/python/aistore/pytorch/utils.py
index db97fd074b..fde07215f9 100644
--- a/python/aistore/pytorch/utils.py
+++ b/python/aistore/pytorch/utils.py
@@ -4,44 +4,11 @@
 Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 """
 
-from typing import List, Iterable, Tuple
 from urllib.parse import urlunparse
-from aistore.sdk import Client
-from aistore.sdk.ais_source import AISSource
-from aistore.sdk.list_object_flag import ListObjectFlag
-from aistore.sdk.object import Object
-from aistore.sdk.bucket import Bucket
-from aistore.sdk.types import ArchiveSettings
+from typing import Tuple
 from aistore.sdk.utils import parse_url as sdk_parse_url
 
 
-def list_objects(
-    client: Client, urls_list: List[str], ais_source_list: List[AISSource]
-) -> List[Object]:
-    """
-    Create a list of all the objects in the given URLs and AIS sources.
-
-    Args:
-        client (Client): AIStore client object
-        urls_list (List[str]): List of URLs
-        ais_source_list (List[AISSource]): List of AISSource objects to load data
-
-    Returns:
-        List[Object]: List of all the objects in the given URLs and AIS sources
-    """
-    samples = []
-
-    for url in urls_list:
-        provider, bck_name, path = parse_url(url)
-        bucket = client.bucket(bck_name=bck_name, provider=provider)
-        samples.extend([obj for obj in bucket.list_all_objects_iter(prefix=path)])
-
-    for source in ais_source_list:
-        samples.extend([obj.name for obj in source.list_all_objects_iter()])
-
-    return samples
-
-
 def unparse_url(provider: str, bck_name: str, obj_name: str) -> str:
     """
     Generate URL based on provider, bucket name, and object name.
@@ -57,61 +24,6 @@ def unparse_url(provider: str, bck_name: str, obj_name: str) -> str:
     return urlunparse([provider, bck_name, obj_name, "", "", ""])
 
 
-def list_objects_iterator(
-    client: Client, urls_list: List[str] = [], ais_source_list: List[AISSource] = []
-) -> Iterable[Object]:
-    """
-    Create an iterable over all the objects in the given URLs and AIS sources.
-
-    Args:
-        client (Client): AIStore client object
-        urls_list (List[str]): List of URLs
-        ais_source_list (List[AISSource]): List of AISSource objects to load data
-
-    Returns:
-        Iterable[Object]: Iterable over all the objects in the given URLs and AIS sources
-    """
-    for url in urls_list:
-        provider, bck_name, path = parse_url(url)
-        bucket = client.bucket(bck_name=bck_name, provider=provider)
-        yield from bucket.list_all_objects_iter(prefix=path)
-
-    for source in ais_source_list:
-        yield from source.list_all_objects_iter()
-
-
-def list_shard_objects_iterator(
-    bucket: Bucket, prefix: str = "", etl_name: str = ""
-) -> Iterable[bytes]:
-    """
-    Create an iterable over all the objects in the given shards.
- - Args: - bucket (Bucket): Bucket containing the shards - prefix (str): Prefix of the object names - etl_name (str): ETL name to apply on each object - - Returns: - Iterable[Object]: Iterable over all the objects in the given shards, - with each iteration returning a combined sample - """ - shards_iter = bucket.list_objects_iter(prefix=prefix, props="name") - - for shard in shards_iter: - path = shard.name - objects_iter = bucket.list_objects_iter( - prefix=path, props="name", flags=[ListObjectFlag.ARCH_DIR] - ) - - for obj in objects_iter: - if obj.name != path: - obj_name = obj.name.replace(f"{path}/", "", 1) - yield bucket.object(path).get( - etl_name=etl_name, - archive_settings=ArchiveSettings(archpath=obj_name), - ).read_all() - - def get_basename(name: str) -> str: """ Get the basename of the object name by stripping any directory information and suffix. @@ -126,81 +38,6 @@ def get_basename(name: str) -> str: return name.split("/")[-1].split(".")[0] -def __samples_from_bck_iter(shard_name: str, bucket: Bucket, etl_name: str): - """ - Helper function to create an iterator for all samples and contents over the given shard name. - - Args: - name (str): Name of shard object - bucket (Bucket): Bucket where the shard object is stored - etl_name (str): Name of ETL (Extract, Transform, Load) to apply to each object - - Returns: - Iterable[Tuple[str, dict(str, bytes)]]: Iterator over all the WDS basenames and content (file extension, data) - in shards from the given shard - """ - # get iterator of all objects in the shard - objects_iter = bucket.list_objects_iter( - prefix=shard_name, props="name", flags=[ListObjectFlag.ARCH_DIR] - ) - - # pool all files with the same basename into dictionary (basename, [file names]) - samples_dict = {} - for obj in objects_iter: - basename = get_basename(obj.name) - - # Original tar is included in basenames so only yield actual files - if basename != shard_name.split(".")[0]: - if basename not in samples_dict: - samples_dict[basename] = [] - samples_dict[basename].append(obj.name) - - # for each basename, get the byte data for each file and yield in dictionary - shard = bucket.object(shard_name) - for basename, files in samples_dict.items(): - content_dict = {} - for file_name in files: - file_prefix = file_name.split(".")[-1] - content_dict[file_prefix] = shard.get( - etl_name=etl_name, archive_settings=ArchiveSettings(archpath=file_name) - ).read_all() - yield basename, content_dict - - -def list_wds_samples_iter( - client: Client, - urls_list: List[str], - bucket_list: List[Bucket], - etl_name: str, -) -> Iterable[Tuple[str, bytes]]: - """ - Create an iterator over all of the shard sample basenames and sample contents. 
-
-    Args:
-        client (Client): AIStore Client for accessing buckets and objects
-        urls_list (List[str]): List of URLs, can be URLS for buckets and/or objects
-        bucket_list (List[Bucket]): List of Bucket objects containing the shards to load data
-        etl_name (str): Name of ETL (Extract, Transform, Load) to apply to each object
-
-    Returns:
-        Iterable[Tuple[str, dict(str, bytes)]]: Iterator over all the WDS basenames and content (file extension, data)
-        in shards from the given URLs and buckets
-    """
-
-    for item in urls_list:
-        provider, bck_name, path = parse_url(item)
-        bucket = client.bucket(bck_name=bck_name, provider=provider)
-        if path == None or path == "":
-            for shard in bucket.list_objects_iter():
-                yield from __samples_from_bck_iter(shard.name, bucket, etl_name)
-        else:
-            yield from __samples_from_bck_iter(path, bucket, etl_name)
-
-    for bucket in bucket_list:
-        for shard in bucket.list_objects_iter():
-            yield from __samples_from_bck_iter(shard.name, bucket, etl_name)
-
-
 def parse_url(url: str) -> Tuple[str, str, str]:
     """
     Wrapper of sdk/utils.py parse_url. Parse AIS URLs for bucket and object names.
diff --git a/python/pyaisloader/pyaisloader/pytorch_benchmark.py b/python/pyaisloader/pyaisloader/pytorch_benchmark.py
index f70515f33c..4c8b3c27af 100644
--- a/python/pyaisloader/pyaisloader/pytorch_benchmark.py
+++ b/python/pyaisloader/pyaisloader/pytorch_benchmark.py
@@ -11,8 +11,7 @@
 from pyaisloader.client_config import ENDPOINT
 from pyaisloader.benchmark import PutGetMixedBenchmark, BenchmarkStats
 
-from aistore.pytorch.dataset import AISDataset
-from aistore.pytorch.iter_dataset import AISIterDataset
+from aistore.pytorch import AISMapDataset, AISIterDataset
 
 
 class AISDatasetBenchmark(PutGetMixedBenchmark):
@@ -32,7 +31,7 @@ def run(self):
         print_results(result, title=self.__class__.__name__)
 
     def get_benchmark(self, duration):
-        dataset = AISDataset(
+        dataset = AISMapDataset(
             client_url=ENDPOINT,
-            urls_list=f"{self.bucket.provider}://{self.bucket.name}",
+            ais_source_list=self.bucket,
         )
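The benchmark hunk doubles as a migration example for downstream callers of the removed `urls_list` parameter; roughly (endpoint and bucket names hypothetical):

```python
from aistore.pytorch import AISMapDataset
from aistore.sdk import Client

AIS_ENDPOINT = "http://localhost:8080"  # hypothetical endpoint
bck = Client(AIS_ENDPOINT).bucket("bench-data")  # hypothetical bucket

# Before this patch: URL prefixes
# dataset = AISDataset(client_url=AIS_ENDPOINT, urls_list="ais://bench-data")

# After this patch: AISSource objects, optionally narrowed with prefix_map
dataset = AISMapDataset(client_url=AIS_ENDPOINT, ais_source_list=bck)
```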
diff --git a/python/tests/integration/pytorch/test_pytorch_plugin.py b/python/tests/integration/pytorch/test_pytorch_plugin.py
index b8c5e3dd4c..2ba03a4909 100644
--- a/python/tests/integration/pytorch/test_pytorch_plugin.py
+++ b/python/tests/integration/pytorch/test_pytorch_plugin.py
@@ -12,11 +12,11 @@
 from aistore.pytorch import (
     AISFileLister,
     AISFileLoader,
-    AISDataset,
+    AISMapDataset,
     AISIterDataset,
     AISMultiShardStream,
+    AISShardReader,
 )
-from aistore.pytorch.shard_reader import AISShardReader
 from tests.integration import CLUSTER_ENDPOINT
 from tests.utils import (
     create_and_put_object,
@@ -36,7 +36,8 @@ class TestPytorchPlugin(unittest.TestCase):
     def setUp(self) -> None:
         self.bck_name = random_string()
         self.client = Client(CLUSTER_ENDPOINT)
-        self.client.bucket(self.bck_name).create()
+        self.bck = self.client.bucket(self.bck_name)
+        self.bck.create()
         self.local_test_files = (
             Path().absolute().joinpath("pytorch-plugin-test-" + random_string(8))
         )
@@ -113,8 +114,8 @@ def test_ais_dataset(self):
             )
             content_dict[i] = content
 
-        ais_dataset = AISDataset(
-            client_url=CLUSTER_ENDPOINT, urls_list=["ais://" + self.bck_name]
+        ais_dataset = AISMapDataset(
+            client_url=CLUSTER_ENDPOINT, ais_source_list=[self.bck]
         )
         self.assertEqual(len(ais_dataset), num_objs)
         for i in range(num_objs):
@@ -132,7 +133,7 @@ def test_ais_iter_dataset(self):
             content_dict[i] = content
 
         ais_iter_dataset = AISIterDataset(
-            client_url=CLUSTER_ENDPOINT, urls_list=["ais://" + self.bck_name]
+            client_url=CLUSTER_ENDPOINT, ais_source_list=self.bck
         )
         self.assertEqual(len(ais_iter_dataset), num_objs)
         for i, (obj_name, content) in enumerate(ais_iter_dataset):
@@ -243,11 +244,12 @@ def test_shard_reader(self):
 
         sample_basenames = ["sample_1", "sample_2", "sample_3", "sample_4"]
 
-        # Test shard_reader with url params
-        url_one = f"{bucket.provider}://{bucket.name}/{shard_one_obj.name}"
-        url_two = f"{bucket.provider}://{bucket.name}/{shard_two_obj.name}"
+        # Test shard_reader with prefixes
+
         url_shard_reader = AISShardReader(
-            client_url=CLUSTER_ENDPOINT, urls_list=[url_one, url_two]
+            client_url=CLUSTER_ENDPOINT,
+            bucket_list=[bucket],
+            prefix_map={bucket: "shard_1.tar"},
         )
 
         for i, (basename, content_dict) in enumerate(url_shard_reader):
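The unit tests that follow patch the new hook points on the base classes rather than the removed utility functions; the same idea in isolation might look like this sketch (endpoint and names are illustrative):

```python
import unittest
from unittest.mock import Mock, patch

from aistore.pytorch import AISIterDataset
from aistore.sdk import Bucket


class TestIterNoCluster(unittest.TestCase):
    def test_yields_names_and_bytes(self):
        obj = Mock()
        obj.name = "obj-0"
        obj.get.return_value.read_all.return_value = b"payload"
        with patch("aistore.pytorch.base_iter_dataset.Client"), patch(
            "aistore.pytorch.base_iter_dataset.AISBaseIterDataset._create_samples_iter",
            return_value=iter([obj]),
        ):
            ds = AISIterDataset("http://example.com", Mock(Bucket))
            self.assertEqual(list(ds), [("obj-0", b"payload")])
```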
diff --git a/python/tests/unit/pytorch/test_datasets.py b/python/tests/unit/pytorch/test_datasets.py
index 4eeacedfaa..9e9c56db41 100644
--- a/python/tests/unit/pytorch/test_datasets.py
+++ b/python/tests/unit/pytorch/test_datasets.py
@@ -5,12 +5,11 @@
 import unittest
 from unittest.mock import patch, Mock, MagicMock
-from aistore.pytorch.dataset import AISDataset
+from aistore.pytorch.map_dataset import AISMapDataset
 from aistore.pytorch.iter_dataset import AISIterDataset
 from aistore.pytorch.multishard_dataset import AISMultiShardStream
 from aistore.pytorch.shard_reader import AISShardReader
-from aistore.pytorch.utils import list_wds_samples_iter
-from aistore.sdk.list_object_flag import ListObjectFlag
+from aistore.sdk import Bucket
 
 
 class TestAISDataset(unittest.TestCase):
@@ -22,38 +22,50 @@ def setUp(self) -> None:
             mock_obj,
             mock_obj,
         ]
+        self.mock_bck = Mock(Bucket)
 
-        self.patcher_list_objects_iterator = patch(
-            "aistore.pytorch.base_dataset.list_objects_iterator",
+        self.patcher_get_objects_iterator = patch(
+            "aistore.pytorch.base_iter_dataset.AISBaseIterDataset._create_samples_iter",
             return_value=iter(self.mock_objects),
         )
-        self.patcher_list_objects = patch(
-            "aistore.pytorch.base_dataset.list_objects", return_value=self.mock_objects
+        self.patcher_get_objects = patch(
+            "aistore.pytorch.base_map_dataset.AISBaseMapDataset._create_samples_list",
+            return_value=self.mock_objects,
         )
-        self.patcher_client = patch(
-            "aistore.pytorch.base_dataset.Client", return_value=self.mock_client
+        self.patcher_client_map = patch(
+            "aistore.pytorch.base_map_dataset.Client", return_value=self.mock_client
         )
-        self.patcher_list_objects_iterator.start()
-        self.patcher_list_objects.start()
-        self.patcher_client.start()
+        self.patcher_client_iter = patch(
+            "aistore.pytorch.base_iter_dataset.Client", return_value=self.mock_client
+        )
+        self.patcher_get_objects_iterator.start()
+        self.patcher_get_objects.start()
+        self.patcher_client_map.start()
+        self.patcher_client_iter.start()
 
     def tearDown(self) -> None:
-        self.patcher_list_objects_iterator.stop()
-        self.patcher_list_objects.stop()
-        self.patcher_client.stop()
+        self.patcher_get_objects_iterator.stop()
+        self.patcher_get_objects.stop()
+        self.patcher_client_map.stop()
+        self.patcher_client_iter.stop()
 
     def test_map_dataset(self):
-        ais_dataset = AISDataset(client_url="mock_client_url", urls_list="ais://test")
-        self.assertIsNone(ais_dataset.etl_name)
+        self.mock_bck.list_all_objects_iter.return_value = iter(self.mock_objects)
+
+        ais_dataset = AISMapDataset(
+            client_url="mock_client_url", ais_source_list=self.mock_bck
+        )
+
+        self.assertIsNone(ais_dataset._etl_name)
         self.assertEqual(len(ais_dataset), 2)
         self.assertEqual(ais_dataset[0][1], b"mock data")
 
     def test_iter_dataset(self):
         ais_iter_dataset = AISIterDataset(
-            client_url="mock_client_url", urls_list="ais://test"
+            client_url="mock_client_url", ais_source_list=self.mock_bck
         )
-        self.assertIsNone(ais_iter_dataset.etl_name)
+        self.assertIsNone(ais_iter_dataset._etl_name)
 
         self.assertEqual(len(ais_iter_dataset), 2)
@@ -61,15 +73,15 @@ def test_iter_dataset(self):
             self.assertEqual(obj, b"mock data")
 
     def test_multi_shard_stream(self):
-        self.patcher = unittest.mock.patch(
-            "aistore.pytorch.multishard_dataset.list_shard_objects_iterator"
+        self.patcher = patch(
+            "aistore.pytorch.AISMultiShardStream._get_shard_objects_iterator"
         )
-        self.mock_list_shard_objects_iterator = self.patcher.start()
+        self.mock_get_shard_objects_iterator = self.patcher.start()
 
         self.data1 = iter([b"data1_1", b"data1_2", b"data1_3"])
         self.data2 = iter([b"data2_1", b"data2_2", b"data2_3"])
         self.data3 = iter([b"data3_1", b"data3_2", b"data3_3"])
-        self.mock_list_shard_objects_iterator.side_effect = [
+        self.mock_get_shard_objects_iterator.side_effect = [
             self.data1,
             self.data2,
             self.data3,
@@ -89,29 +101,14 @@ def test_multi_shard_stream(self):
 
         self.assertEqual(results, expected_results)
 
-    @patch("aistore.sdk.bucket")
-    def test_list_wds_sample_iter(self, mock_bucket):
-        # Mock the list_objects_iter method
-        mock_shard = Mock()
-        mock_shard.name = "shard.tar"
-        mock_shard.list_objects_iter.return_value = iter([])
-        mock_bucket.list_objects_iter.return_value = [mock_shard]
+        self.patcher.stop()
 
-        # Call the function under test
-        list(list_wds_samples_iter(None, [], [mock_bucket], None))
+    def test_shard_reader(self):
+        # Mock _create_samples_iter
+        self.patcher = patch("aistore.pytorch.AISShardReader._create_samples_iter")
+        mock_create_samples_iter = self.patcher.start()
 
-        # Assert that list_objects_iter was called exactly once for the shard
-        mock_bucket.list_objects_iter.assert_any_call()
-
-        # Assert that list_objects_iter was called exactly once with arch params
-        mock_bucket.list_objects_iter.assert_called_with(
-            prefix=mock_shard.name, props="name", flags=[ListObjectFlag.ARCH_DIR]
-        )
-
-    @patch("aistore.pytorch.shard_reader.list_wds_samples_iter")
-    def test_shard_reader(self, mock_list_wds_samples_iter):
-        # Mock list_wds_samples_iter
-        mock_list_wds_samples_iter.return_value = [
+        mock_create_samples_iter.return_value = [
         ("sample_1", {"cls": b"Content of class"}),
         ("sample_2", {"png": b"Content of class"}),
         ("sample_3", {"jpg": b"Content of class"}),
 
         # Create shard reader and get results and compare
         shard_reader = AISShardReader(
-            client_url="http://example.com", urls_list="http://example.com/data"
+            client_url="http://example.com", bucket_list=self.mock_bck
         )
 
         result = list(shard_reader)
 
         self.assertEqual(result, expected_result)
 
         # Ensure the iter is called correctly
-        mock_list_wds_samples_iter.assert_called()
+        mock_create_samples_iter.assert_called()
+
+        self.patcher.stop()
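One behavioral note worth keeping in mind with the classes above: `len()` on the iterable datasets performs a counting pass over the object stream, and each new iteration re-lists the sources. A hedged sketch (endpoint and bucket hypothetical):

```python
from aistore.pytorch import AISIterDataset
from aistore.sdk import Client
from torch.utils.data import DataLoader

AIS_ENDPOINT = "http://localhost:8080"  # hypothetical endpoint
bck = Client(AIS_ENDPOINT).bucket("example-data")  # hypothetical bucket

ds = AISIterDataset(AIS_ENDPOINT, bck)
n = len(ds)  # triggers one listing pass over the bucket to count objects

# The datasets do not shard work via torch.utils.data.get_worker_info(),
# so multi-worker loading would duplicate samples; keep num_workers=0.
loader = DataLoader(ds, batch_size=16, num_workers=0)
for names, blobs in loader:
    ...  # default collate keeps str and bytes samples as per-batch lists
```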