python/pytorch: Refactor datasets and utils
Signed-off-by: Soham Manoli <[email protected]>
msoham123 committed Jul 2, 2024
1 parent a53d3b0 commit cef15ae
Showing 12 changed files with 370 additions and 375 deletions.
4 changes: 3 additions & 1 deletion python/aistore/pytorch/__init__.py
@@ -4,7 +4,9 @@
AISSourceLister,
)

from aistore.pytorch.dataset import AISDataset
from aistore.pytorch.map_dataset import AISMapDataset
from aistore.pytorch.multishard_dataset import AISMultiShardStream
from aistore.pytorch.iter_dataset import AISIterDataset
from aistore.pytorch.shard_reader import AISShardReader
from aistore.pytorch.base_map_dataset import AISBaseMapDataset
from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
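
For reference, a sketch of the refactored public import surface of aistore.pytorch after this change, based only on the __init__.py diff above:

from aistore.pytorch import (
    AISMapDataset,        # map-style dataset (replaces AISDataset)
    AISIterDataset,       # iterable-style dataset
    AISMultiShardStream,  # multi-shard stream
    AISShardReader,       # shard reader
    AISBaseMapDataset,    # abstract base for custom map-style datasets
    AISBaseIterDataset,   # abstract base for custom iterable-style datasets
)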
66 changes: 0 additions & 66 deletions python/aistore/pytorch/base_dataset.py

This file was deleted.

94 changes: 94 additions & 0 deletions python/aistore/pytorch/base_iter_dataset.py
@@ -0,0 +1,94 @@
"""
Base class for AIS Iterable Style Datasets
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from typing import List, Union, Iterable, Dict, Iterator
from aistore.sdk.ais_source import AISSource
from aistore.sdk import Client
from torch.utils.data import IterableDataset
from abc import ABC, abstractmethod


class AISBaseIterDataset(ABC, IterableDataset):
"""
A base class for creating AIS Iterable Datasets. Should not be instantiated directly. Subclasses
should implement :meth:`__iter__`, which yields the samples of the dataset, and can optionally
override other methods from torch IterableDataset such as :meth:`__len__`. Additionally,
to modify how samples are loaded from a source, override :meth:`_get_sample_iter_from_source`.
Args:
client_url (str): AIS endpoint URL
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to one or more prefixes;
only objects whose names start with one of the given prefixes are used from that source
"""

def __init__(
self,
client_url: str,
ais_source_list: Union[AISSource, List[AISSource]],
prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
) -> None:
if not ais_source_list:
raise ValueError("ais_source_list must be provided")
self._client = Client(client_url)
self._ais_source_list = (
[ais_source_list]
if isinstance(ais_source_list, AISSource)
else ais_source_list
)
self._prefix_map = prefix_map
self._iterator = None
self._reset_iterator()

def _get_sample_iter_from_source(self, source: AISSource, prefix: str) -> Iterable:
"""
Creates an iterable of samples from the AISSource and the objects stored within it. Must be able to handle
prefixes as well. The default implementation returns an iterable of Objects. This method can be overridden
to provide other functionality (such as reading the data and creating usable samples for different
file types).
Args:
source (AISSource): AISSource (:class:`aistore.sdk.ais_source.AISSource`) providing an interface for accessing a list of
AIS objects or their URLs
prefix (str): Prefix that determines which objects are included
Returns:
Iterable: Iterable over the content of the dataset
"""
yield from source.list_all_objects_iter(prefix=prefix)

def _create_samples_iter(self) -> Iterable:
"""
Create an iterable given the AIS sources and associated prefixes.
Returns:
Iterable: Iterable over the samples of the dataset
"""
for source in self._ais_source_list:
if source not in self._prefix_map or self._prefix_map[source] is None:
yield from self._get_sample_iter_from_source(source, "")
else:
prefixes = (
[self._prefix_map[source]]
if isinstance(self._prefix_map[source], str)
else self._prefix_map[source]
)
for prefix in prefixes:
yield from self._get_sample_iter_from_source(source, prefix)

@abstractmethod
def __iter__(self) -> Iterator:
"""
Return an iterator over the samples in this dataset.
Returns:
Iterator: Iterator of samples
"""
pass

def _reset_iterator(self):
"""Reset the iterator to start from the beginning"""
self._iterator = self._create_samples_iter()
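
As a usage illustration (not part of this commit), a minimal sketch of a concrete subclass of AISBaseIterDataset. The class name RawBytesIterDataset is hypothetical; it relies only on the _reset_iterator/_iterator members shown above and the Object.get()/read_all() SDK calls used elsewhere in this diff:

# Illustrative only -- a hypothetical subclass, not part of this commit.
from aistore.pytorch.base_iter_dataset import AISBaseIterDataset


class RawBytesIterDataset(AISBaseIterDataset):
    """Yields (object name, raw bytes) pairs from the configured AIS sources."""

    def __iter__(self):
        # Rebuild the underlying object iterator so each pass starts from the beginning.
        self._reset_iterator()
        for obj in self._iterator:
            # obj is an aistore.sdk.object.Object; read_all() returns its full payload.
            yield obj.name, obj.get().read_all()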
88 changes: 88 additions & 0 deletions python/aistore/pytorch/base_map_dataset.py
@@ -0,0 +1,88 @@
"""
Base class for AIS Map Style Datasets
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from typing import List, Union, Dict
from aistore.sdk.ais_source import AISSource
from aistore.sdk import Client
from aistore.sdk.object import Object
from torch.utils.data import Dataset
from abc import ABC, abstractmethod


class AISBaseMapDataset(ABC, Dataset):
"""
A base class for creating map-style AIS Datasets. Should not be instantiated directly. Subclasses
should implement :meth:`__getitem__`, which fetches a sample from the dataset given a key, and can optionally
override other methods from torch Dataset such as :meth:`__len__`. Additionally,
to modify how samples are loaded from a source, override :meth:`_get_sample_list_from_source`.
Args:
client_url (str): AIS endpoint URL
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to one or more prefixes;
only objects whose names start with one of the given prefixes are used from that source
"""

def __init__(
self,
client_url: str,
ais_source_list: Union[AISSource, List[AISSource]],
prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
) -> None:
if not ais_source_list:
raise ValueError("ais_source_list must be provided")
self._client = Client(client_url)
self._ais_source_list = (
[ais_source_list]
if isinstance(ais_source_list, AISSource)
else ais_source_list
)
self._prefix_map = prefix_map
self._samples = self._create_samples_list()

def _get_sample_list_from_source(self, source: AISSource, prefix: str) -> List:
"""
Creates a list of samples from the AISSource and the objects stored within it. Must be able to handle
prefixes as well. The default implementation returns a list of Objects. This method can be overridden
to provide other functionality (such as reading the data and creating usable samples for different
file types).
Args:
source (AISSource): AISSource (:class:`aistore.sdk.ais_source.AISSource`) providing an interface for accessing a list of
AIS objects or their URLs
prefix (str): Prefix that determines which objects are included
Returns:
List: List of samples from the source
"""
return [obj for obj in source.list_all_objects_iter(prefix=prefix)]

def _create_samples_list(self) -> List[Object]:
"""
Create a list of all the objects in the given AIS sources.
Returns:
List[Object]: List of all the objects in the given AIS sources
"""
samples = []

for source in self._ais_source_list:
if source not in self._prefix_map or self._prefix_map[source] is None:
samples.extend(self._get_sample_list_from_source(source, ""))
else:
prefixes = (
[self._prefix_map[source]]
if isinstance(self._prefix_map[source], str)
else self._prefix_map[source]
)
for prefix in prefixes:
samples.extend(self._get_sample_list_from_source(source, prefix))

return samples

@abstractmethod
def __getitem__(self, index):
pass
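
Similarly, a minimal sketch (not part of this commit) of a concrete subclass of AISBaseMapDataset; the name RawBytesMapDataset is hypothetical and the body uses only the _samples list populated above:

# Illustrative only -- a hypothetical subclass, not part of this commit.
from aistore.pytorch.base_map_dataset import AISBaseMapDataset


class RawBytesMapDataset(AISBaseMapDataset):
    """Indexes (object name, raw bytes) pairs from the configured AIS sources."""

    def __len__(self):
        # self._samples is populated eagerly by _create_samples_list() in __init__.
        return len(self._samples)

    def __getitem__(self, index):
        obj = self._samples[index]
        # read_all() fetches the object's full payload as bytes.
        return obj.name, obj.get().read_all()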
40 changes: 16 additions & 24 deletions python/aistore/pytorch/iter_dataset.py
@@ -4,21 +4,21 @@
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from aistore.pytorch.base_dataset import AISBaseClassIter
from torch.utils.data import IterableDataset
from typing import List, Union
from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
from typing import List, Union, Dict
from aistore.sdk.ais_source import AISSource


class AISIterDataset(AISBaseClassIter, IterableDataset):
class AISIterDataset(AISBaseIterDataset):
"""
An iterable-style dataset that iterates over objects in AIS.
If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to one or more prefixes;
only objects whose names start with one of the given prefixes are used from that source
etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
Note:
@@ -28,30 +28,22 @@ class AISIterDataset(AISBaseClassIter, IterableDataset):
def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]] = [],
ais_source_list: Union[AISSource, List[AISSource]] = [],
ais_source_list: Union[AISSource, List[AISSource]],
prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
etl_name: str = None,
):
if not urls_list and not ais_source_list:
raise ValueError(
"At least one of urls_list or ais_source_list must be provided."
)
super().__init__(client_url, urls_list, ais_source_list)
self.etl_name = etl_name
self.length = None
super().__init__(client_url, ais_source_list, prefix_map)
self._etl_name = etl_name
self._length = None

def __iter__(self):
self._reset_iterator()
self.length = 0
for obj in self._object_iter:
self.length += 1
yield obj.name, obj.get(etl_name=self.etl_name).read_all()
self._length = 0
for obj in self._iterator:
yield obj.name, obj.get(etl_name=self._etl_name).read_all()

def __len__(self):
if self.length is None:
if self._length is None:
self._reset_iterator()
self.length = self._calculate_len()
return self.length

def _calculate_len(self):
return sum(1 for _ in self._object_iter)
self._length = sum(1 for _ in self._iterator)
return self._length
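
The refactored AISIterDataset takes AISSource objects (e.g. buckets) and an optional prefix_map instead of URL prefixes. A minimal usage sketch of the new constructor, assuming a local AIS endpoint and a hypothetical bucket name:

# Illustrative usage only; the endpoint URL and bucket name are hypothetical.
from aistore.pytorch import AISIterDataset
from aistore.sdk import Client
from torch.utils.data import DataLoader

AIS_ENDPOINT = "http://localhost:8080"

client = Client(AIS_ENDPOINT)
bucket = client.bucket("my-bucket")  # a Bucket is an AISSource

dataset = AISIterDataset(
    client_url=AIS_ENDPOINT,
    ais_source_list=bucket,
    prefix_map={bucket: ["train/"]},  # only objects under the "train/" prefix
)

# Each sample is a (name, raw bytes) pair; the default collate returns lists per batch.
loader = DataLoader(dataset, batch_size=32)
for names, payloads in loader:
    pass  # decode payloads and feed them to a model, etc.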
python/aistore/pytorch/map_dataset.py
@@ -4,21 +4,21 @@
Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
"""

from typing import List, Union
from torch.utils.data import Dataset
from typing import List, Union, Dict
from aistore.sdk.ais_source import AISSource
from aistore.pytorch.base_dataset import AISBaseClass
from aistore.pytorch.base_map_dataset import AISBaseMapDataset


class AISDataset(AISBaseClass, Dataset):
class AISMapDataset(AISBaseMapDataset):
"""
A map-style dataset for objects in AIS.
If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
prefix_map (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to one or more prefixes;
only objects whose names start with one of the given prefixes are used from that source
etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
Note:
@@ -28,21 +28,17 @@ class AISDataset(AISBaseClass, Dataset):
def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]] = [],
ais_source_list: Union[AISSource, List[AISSource]] = [],
prefix_map: Dict[AISSource, Union[str, List[str]]] = {},
etl_name: str = None,
):
if not urls_list and not ais_source_list:
raise ValueError(
"At least one of urls_list or ais_source_list must be provided"
)
super().__init__(client_url, urls_list, ais_source_list)
self.etl_name = etl_name
super().__init__(client_url, ais_source_list, prefix_map)
self._etl_name = etl_name

def __len__(self):
return len(self._objects)
return len(self._samples)

def __getitem__(self, index: int):
obj = self._objects[index]
content = obj.get(etl_name=self.etl_name).read_all()
obj = self._samples[index]
content = obj.get(etl_name=self._etl_name).read_all()
return obj.name, content
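
And a matching usage sketch for the renamed map-style dataset (again, endpoint URL and bucket name are hypothetical):

# Illustrative usage only; the endpoint URL and bucket name are hypothetical.
from aistore.pytorch import AISMapDataset
from aistore.sdk import Client

AIS_ENDPOINT = "http://localhost:8080"

client = Client(AIS_ENDPOINT)
bucket = client.bucket("my-bucket")

dataset = AISMapDataset(
    client_url=AIS_ENDPOINT,
    ais_source_list=bucket,
)

print(len(dataset))         # object count, discovered eagerly at construction time
name, content = dataset[0]  # (object name, raw bytes)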