diff --git a/python/aistore/pytorch/__init__.py b/python/aistore/pytorch/__init__.py
index 8c4b80634a..aaa17d4c60 100644
--- a/python/aistore/pytorch/__init__.py
+++ b/python/aistore/pytorch/__init__.py
@@ -10,3 +10,4 @@
 from aistore.pytorch.shard_reader import AISShardReader
 from aistore.pytorch.base_map_dataset import AISBaseMapDataset
 from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
+from aistore.pytorch.dynamic_sampler import DynamicBatchSampler
diff --git a/python/aistore/pytorch/base_map_dataset.py b/python/aistore/pytorch/base_map_dataset.py
index 08cebe3c92..2105ea1ee7 100644
--- a/python/aistore/pytorch/base_map_dataset.py
+++ b/python/aistore/pytorch/base_map_dataset.py
@@ -16,7 +16,7 @@ class AISBaseMapDataset(ABC, Dataset):
     """
     A base class for creating map-style AIS Datasets. Should not be instantiated directly. Subclasses should implement
     :meth:`__getitem__` which fetches a samples given a key from the dataset and can optionally
-    override other methods from torch IterableDataset such as :meth:`__len__`. Additionally,
+    override other methods from torch Dataset such as :meth:`__len__` and :meth:`__getitems__`. Additionally,
     to modify the behavior of loading samples from a source, override :meth:`_get_sample_list_from_source`.
 
     Args:
@@ -87,3 +87,6 @@ def _create_samples_list(self) -> List[Object]:
     @abstractmethod
     def __getitem__(self, index):
         pass
+
+    def __getitems__(self, indices: List[int]):
+        return [self.__getitem__(index) for index in indices]
diff --git a/python/aistore/pytorch/dynamic_sampler.py b/python/aistore/pytorch/dynamic_sampler.py
new file mode 100644
index 0000000000..a9f7b1663c
--- /dev/null
+++ b/python/aistore/pytorch/dynamic_sampler.py
@@ -0,0 +1,102 @@
+"""
+Dynamic Batch Sampler for Dynamic Batch Sizing
+
+In scenarios where memory is a constraint, the DynamicBatchSampler
+can be used to generate mini-batches bounded by a maximum byte size,
+guaranteeing that each batch fits within memory while packing
+as many samples as possible into each batch.
+
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+"""
+
+from torch.utils.data import Sampler
+from typing import Iterator, List
+from aistore.pytorch.base_map_dataset import AISBaseMapDataset
+from logging import getLogger
+
+# Saturation of a batch needed to not be dropped with drop_last=True
+SATURATION_FACTOR = 0.8
+
+
+class DynamicBatchSampler(Sampler):
+    """
+
+    Dynamically adds samples to a mini-batch up to a maximum batch size.
+
+    NOTE: Using this sampler with AISBaseMapDatasets that use ObjectGroups
+    in their ais_source_lists will be slower than using it with Buckets, as
+    ObjectGroups will perform one extra API call per object to get size metadata.
+
+    Args:
+        data_source (AISBaseMapDataset): Base AIS map-style dataset to sample from to create dynamic mini-batches.
+        max_batch_size (float): Maximum size of mini-batch in bytes.
+        drop_last (bool, optional): If `True`, the last batch will be dropped if it is not at least 80% of `max_batch_size`.
+            Defaults to `False`.
+        allow_oversized_samples (bool, optional): If `True`, any sample that is larger than `max_batch_size` will be processed
+            in its own mini-batch instead of being dropped. Defaults to `False`.
+    """
+
+    def __init__(
+        self,
+        data_source: AISBaseMapDataset,
+        max_batch_size: float,
+        drop_last: bool = False,
+        allow_oversized_samples: bool = False,
+    ) -> None:
+        self._data_source = data_source
+        self._max_batch_size = max_batch_size
+        self._samples_list = data_source._create_samples_list()
+        self._drop_last = drop_last
+        self._allow_oversized_samples = allow_oversized_samples
+        self._logger = getLogger(__name__)
+
+    def __iter__(self) -> Iterator[List[int]]:
+        """
+        Returns an iterator containing mini-batches (lists of indices).
+        """
+        total_mem = 0
+        batch = []
+        index = 0
+
+        # get sample size for each index, check if there is space in the batch, and yield batches whenever full
+        # batch contents are decided greedily, without looking ahead
+        while index < len(self):
+            sample = self._samples_list[index]
+
+            if sample.size > self._max_batch_size:
+                if self._allow_oversized_samples is True:
+                    yield [index]
+                else:
+                    self._logger.warning(
+                        f"Sample {sample.name} cannot be processed as it is larger than the max batch size: {sample.size} bytes > {self._max_batch_size} bytes"
+                    )
+
+                index += 1
+                continue
+
+            if total_mem + sample.size < self._max_batch_size:
+                batch.append(index)
+                index += 1
+                total_mem += sample.size
+            else:
+
+                if total_mem + sample.size == self._max_batch_size:
+                    batch.append(index)
+                    index += 1
+
+                yield batch
+                batch = []
+                total_mem = 0
+
+        # if batch exists and we are not dropping last or if we are dropping last but the batch is saturated
+        # then yield the last batch
+        if (batch and not self._drop_last) or (
+            self._drop_last and total_mem / self._max_batch_size > SATURATION_FACTOR
+        ):
+            yield batch
+
+    def __len__(self) -> int:
+        """
+        Returns the total number of samples.
+        """
+        return len(self._samples_list)
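For illustration, a minimal usage sketch of the new sampler (assumptions: a reachable AIStore endpoint and an existing bucket name are placeholders here; the unit and integration tests later in this diff exercise the same pattern):

from aistore import Client
from aistore.pytorch import AISMapDataset, DynamicBatchSampler
from torch.utils.data import DataLoader

# Hypothetical endpoint and bucket name, used only for this sketch
client = Client("http://localhost:8080")
dataset = AISMapDataset(ais_source_list=client.bucket("my-bucket"))

loader = DataLoader(
    dataset=dataset,
    batch_sampler=DynamicBatchSampler(
        data_source=dataset,
        max_batch_size=1_500_000,  # bytes; every yielded mini-batch stays within this budget
    ),
)

for names, content in loader:
    pass  # each (names, content) batch holds at most ~1.5 MB of object data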
+ """ + + def __init__( + self, + data_source: AISBaseMapDataset, + max_batch_size: float, + drop_last: bool = False, + allow_oversized_samples: bool = False, + ) -> None: + self._data_source = data_source + self._max_batch_size = max_batch_size + self._samples_list = data_source._create_samples_list() + self._drop_last = drop_last + self._allow_oversized_samples = allow_oversized_samples + self._logger = getLogger(f"{__name__}.put_files") + + def __iter__(self) -> Iterator[List[int]]: + """ + Returns an iterator containing mini-batches (lists of indices). + """ + total_mem = 0 + batch = [] + index = 0 + + # get sample size for each index, check if there is space in the batch, and yield batches whenever full + # calculate spaces in batch non-preemptively + while index < len(self): + sample = self._samples_list[index] + + if sample.size > self._max_batch_size: + if self._allow_oversized_samples is True: + yield [index] + else: + self._logger.warn( + f"Sample {sample.name} cannot be processed as it is larger than the max batch size: {sample.size} bytes > {self._max_batch_size} bytes" + ) + + index += 1 + continue + + if total_mem + sample.size < self._max_batch_size: + batch.append(index) + index += 1 + total_mem += sample.size + else: + + if total_mem + sample.size == self._max_batch_size: + batch.append(index) + index += 1 + + yield batch + batch = [] + total_mem = 0 + + # if batch exists and we are not dropping last or if we are dropping last but the batch is saturated + # then yield the last batch + if (batch and not self._drop_last) or ( + self._drop_last and total_mem / self._max_batch_size > SATURATION_FACTOR + ): + yield batch + + def __len__(self) -> int: + """ + Returns the total number of samples. + """ + return len(self._samples_list) diff --git a/python/aistore/pytorch/utils.py b/python/aistore/pytorch/utils.py index fde07215f9..73477aeab4 100644 --- a/python/aistore/pytorch/utils.py +++ b/python/aistore/pytorch/utils.py @@ -7,6 +7,35 @@ from urllib.parse import urlunparse from typing import Tuple from aistore.sdk.utils import parse_url as sdk_parse_url +from math import floor + +MB_TO_B = 1000000 + + +def convert_mb_to_bytes(megabytes: float) -> int: + """ + Converts megabytes to bytes and truncates any extra bytes (floor). + + Args: + megabytes (float): number of megabytes to convert + + Returns: + int: number of bytes after conversion (floor of actual byte value) + """ + return floor(megabytes * MB_TO_B) + + +def convert_bytes_to_mb(bytes: int) -> float: + """ + Converts byes to megabytes. + + Args: + bytes (int): number of bytes to convert to megabytes + + Returns: + float: number of megabytes after conversion + """ + return bytes / MB_TO_B def unparse_url(provider: str, bck_name: str, obj_name: str) -> str: diff --git a/python/aistore/sdk/ais_source.py b/python/aistore/sdk/ais_source.py index 22e84c8d20..13f43eb4b1 100644 --- a/python/aistore/sdk/ais_source.py +++ b/python/aistore/sdk/ais_source.py @@ -19,12 +19,16 @@ def client(self) -> RequestClient: """The client bound to the AISSource.""" @abstractmethod - def list_all_objects_iter(self, prefix: str = "") -> Iterable[Object]: + def list_all_objects_iter( + self, prefix: str = "", props: str = "name,size" + ) -> Iterable[Object]: """ Get an iterable of all the objects contained in this source (bucket, group, etc.) Args: prefix (str, optional): Only include objects with names matching this prefix + props (str, optional): Comma-separated list of object properties to return. 
diff --git a/python/aistore/sdk/bucket.py b/python/aistore/sdk/bucket.py
index 6128a3212a..efb42f8bd9 100644
--- a/python/aistore/sdk/bucket.py
+++ b/python/aistore/sdk/bucket.py
@@ -149,19 +149,23 @@ def list_urls(self, prefix: str = "", etl_name: str = None) -> Iterable[str]:
         for entry in self.list_objects_iter(prefix=prefix, props="name"):
             yield self.object(entry.name).get_url(etl_name=etl_name)
 
-    def list_all_objects_iter(self, prefix: str = "") -> Iterable[Object]:
+    def list_all_objects_iter(
+        self, prefix: str = "", props: str = "name,size"
+    ) -> Iterable[Object]:
         """
         Implementation of the abstract method from AISSource that provides an iterator
-        of all the objects in this bucket matching the specified prefix
+        of all the objects in this bucket matching the specified prefix.
 
         Args:
             prefix (str, optional): Limit objects selected by a given string prefix
+            props (str, optional): Comma-separated list of object properties to return.
+                Defaults to "name,size".
 
         Returns:
             Iterator of all object URLs matching the prefix
         """
-        for entry in self.list_objects_iter(prefix=prefix, props="name"):
-            yield self.object(entry.name)
+        for entry in self.list_objects_iter(prefix=prefix, props=props):
+            yield self.object(entry.name, entry.size)
 
     def create(self, exist_ok=False):
         """
@@ -812,21 +816,19 @@ def _get_uploaded_obj_name(file, root_path, basename, prepend):
             return prepend + obj_name
         return obj_name
 
-    def object(self, obj_name: str) -> Object:
+    def object(self, obj_name: str, size: int = None) -> Object:
         """
         Factory constructor for an object in this bucket.
         Does not make any HTTP request, only instantiates an object in a bucket owned by the client.
 
         Args:
             obj_name (str): Name of object
+            size (int, optional): Size of object in bytes
 
         Returns:
             The object created.
         """
-        return Object(
-            bucket=self,
-            name=obj_name,
-        )
+        return Object(bucket=self, name=obj_name, size=size)
 
     def objects(
         self,
diff --git a/python/aistore/sdk/multiobj/object_group.py b/python/aistore/sdk/multiobj/object_group.py
index 2d1aab4f5c..0f16c8c326 100644
--- a/python/aistore/sdk/multiobj/object_group.py
+++ b/python/aistore/sdk/multiobj/object_group.py
@@ -94,19 +94,31 @@ def list_urls(self, prefix: str = "", etl_name: str = None) -> Iterable[str]:
         for obj_name in self._obj_collection:
             yield self.bck.object(obj_name).get_url(etl_name=etl_name)
 
-    def list_all_objects_iter(self, prefix: str = "") -> Iterable[Object]:
+    def list_all_objects_iter(
+        self, prefix: str = "", props: str = "name,size"
+    ) -> Iterable[Object]:
         """
         Implementation of the abstract method from AISSource that provides an iterator
-        of all the objects in this bucket matching the specified prefix
+        of all the objects in this group matching the specified prefix.
 
         Args:
             prefix (str, optional): Limit objects selected by a given string prefix
+            props (str, optional): Comma-separated list of object properties to return.
+                Defaults to "name,size".
 
         Returns:
             Iterator of all the objects in the group
         """
         for obj_name in self._obj_collection:
-            yield self.bck.object(obj_name)
+
+            obj = self.bck.object(obj_name)
+
+            # TODO: Use the head object API to pass an object of requested props
+            if props is not None and "size" in props.split(","):
+                size = int(obj.head()["Content-Length"])
+                obj = self.bck.object(obj_name, size=size)
+
+            yield obj
 
     def delete(self):
         """
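A short sketch of how the size-aware listing is meant to be consumed (assuming an existing Bucket handle named bck): with a Bucket source the sizes come back from the list-objects call itself, while an ObjectGroup source falls back to one HEAD request per object, as the sampler's docstring notes.

# Sum object sizes under a prefix without issuing a GET per object
total_bytes = 0
for obj in bck.list_all_objects_iter(prefix="train/", props="name,size"):
    total_bytes += obj.size  # populated from the listing's size property

print(f"{total_bytes} bytes under train/")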
diff --git a/python/aistore/sdk/object.py b/python/aistore/sdk/object.py
index 8b9e73dfd7..018b8051e0 100644
--- a/python/aistore/sdk/object.py
+++ b/python/aistore/sdk/object.py
@@ -50,16 +50,17 @@ class Object:
     Args:
         bucket (Bucket): Bucket to which this object belongs
         name (str): name of object
-
+        size (int, optional): size of object in bytes
     """
 
-    def __init__(self, bucket: "Bucket", name: str):
+    def __init__(self, bucket: "Bucket", name: str, size: int = None):
         self._bucket = bucket
         self._client = bucket.client
         self._bck_name = bucket.name
         self._qparams = bucket.qparam
         self._name = name
         self._object_path = f"{URL_PATH_OBJECTS}/{ self._bck_name}/{ self.name }"
+        self._size = size
 
     @property
     def bucket(self):
@@ -71,6 +72,11 @@ def name(self):
         """Name of this object"""
         return self._name
 
+    @property
+    def size(self):
+        """Size of this object in bytes"""
+        return self._size
+
     def head(self) -> Header:
         """
         Requests object properties.
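To illustrate the new optional size on Object (a sketch assuming a Bucket handle named bck): a handle constructed directly carries no size until one is supplied, for example from a listing entry or a HEAD request.

obj = bck.object("train/sample-0001")
assert obj.size is None  # no request made, size not known yet

# Carry the size on the handle, mirroring what ObjectGroup.list_all_objects_iter does
sized = bck.object("train/sample-0001", size=int(obj.head()["Content-Length"]))
assert isinstance(sized.size, int)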
diff --git a/python/tests/integration/pytorch/test_samplers.py b/python/tests/integration/pytorch/test_samplers.py
new file mode 100644
index 0000000000..0a7ea3a936
--- /dev/null
+++ b/python/tests/integration/pytorch/test_samplers.py
@@ -0,0 +1,78 @@
+"""
+Integration Test class for AIStore PyTorch Samplers
+
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+"""
+
+from unittest import TestCase
+from tests.integration import CLUSTER_ENDPOINT
+from tests.utils import destroy_bucket, random_string
+from aistore import Client
+from random import randint
+from aistore.pytorch import AISMapDataset, DynamicBatchSampler
+from torch.utils.data import DataLoader
+from sys import getsizeof
+from aistore.pytorch.utils import convert_bytes_to_mb
+
+
+MIN_OBJ_SIZE = 1000  # bytes = 1 KB
+MAX_OBJ_SIZE = 1000000  # bytes = 1 MB
+NUM_OBJECTS = 100
+MAX_BATCH_SIZE = 1500000  # bytes = 1.5 MB
+
+
+class TestAISSampler(TestCase):
+    """
+    Integration tests for the AIS PyTorch Samplers
+    """
+
+    def setUp(self) -> None:
+        self.bck_name = random_string()
+        self.client = Client(CLUSTER_ENDPOINT)
+        self.bck = self.client.bucket(self.bck_name)
+        self.bck.create()
+
+        for i in range(NUM_OBJECTS):
+            content = b"\0" * (randint(0, (MAX_OBJ_SIZE - MIN_OBJ_SIZE)) + MIN_OBJ_SIZE)
+            self.bck.object(f"object-{i}").put_content(content)
+
+        self.dataset = AISMapDataset(ais_source_list=self.bck)
+
+    def tearDown(self) -> None:
+        """
+        Cleanup after each test, destroy the bucket if it exists
+        """
+        destroy_bucket(self.client, self.bck_name)
+
+    def test_dynamic_sampler(self):
+        # Create dataloader using dynamic batch sampler
+        loader = DataLoader(
+            dataset=self.dataset,
+            batch_sampler=DynamicBatchSampler(
+                data_source=self.dataset,
+                max_batch_size=MAX_BATCH_SIZE,
+            ),
+        )
+
+        num_objects = 0
+        for names, content in loader:
+            # Test that batches are not empty and have consistent shape
+            self.assertTrue(names is not None)
+            self.assertTrue(content is not None)
+            self.assertEqual(len(names), len(content))
+
+            # Test that the size of each object is within the bounds
+            batch_size = 0
+            for data in content:
+                data_size = getsizeof(data)
+                self.assertTrue(data_size >= MIN_OBJ_SIZE and data_size < MAX_OBJ_SIZE)
+                batch_size += data_size
+
+            # Test that total batch size (in MB) is within the max batch size
+            batch_size = convert_bytes_to_mb(batch_size)
+            self.assertTrue(batch_size <= convert_bytes_to_mb(MAX_BATCH_SIZE))
+
+            num_objects += len(names)
+
+        # Test that all objects are included across the batches
+        self.assertEqual(num_objects, NUM_OBJECTS)
diff --git a/python/tests/unit/pytorch/test_datasets.py b/python/tests/unit/pytorch/test_datasets.py
index 6d8f58c465..ea3e840fc8 100644
--- a/python/tests/unit/pytorch/test_datasets.py
+++ b/python/tests/unit/pytorch/test_datasets.py
@@ -15,7 +15,6 @@ class TestAISDataset(unittest.TestCase):
     def setUp(self) -> None:
-        self.mock_client = Mock()
         mock_obj = Mock()
         mock_obj.get.return_value.read_all.return_value = b"mock data"
         self.mock_objects = [
diff --git a/python/tests/unit/pytorch/test_samplers.py b/python/tests/unit/pytorch/test_samplers.py
new file mode 100644
index 0000000000..b17d48ff47
--- /dev/null
+++ b/python/tests/unit/pytorch/test_samplers.py
@@ -0,0 +1,132 @@
+"""
+Unit Test class for AIStore PyTorch Samplers
+
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+"""
+
+import unittest
+from unittest.mock import patch, Mock
+import unittest.mock
+from aistore.sdk import Bucket
+from aistore.sdk.object import Object
+from aistore.pytorch import AISMapDataset, DynamicBatchSampler
+from torch.utils.data import DataLoader
+
+
+class TestAISSampler(unittest.TestCase):
+    def setUp(self) -> None:
+        mock_obj = Mock(Object)
+
+        self.data = b"\0" * 1000  # 1 KB
+        mock_obj.get.return_value.read_all.return_value = self.data
+        mock_obj.size = len(self.data)
+        mock_obj.name = "test_obj"
+
+        self.mock_objects = [mock_obj for _ in range(10)]  # 10 objects total
+
+        self.mock_bck = Mock(Bucket)
+
+        self.patcher_get_objects = patch(
+            "aistore.pytorch.base_map_dataset.AISBaseMapDataset._create_samples_list",
+            return_value=self.mock_objects,
+        )
+
+        self.patcher_get_objects.start()
+
+        self.mock_bck.list_all_objects_iter.return_value = iter(self.mock_objects)
+
+        self.ais_dataset = AISMapDataset(ais_source_list=self.mock_bck)
+
+    def tearDown(self) -> None:
+        self.patcher_get_objects.stop()
+
+    def test_dynamic_sampler(self):
+        loader = DataLoader(
+            self.ais_dataset,
+            batch_sampler=DynamicBatchSampler(
+                data_source=self.ais_dataset,
+                max_batch_size=2000,  # two objects (each is 1 KB) per batch, 5 batches
+            ),
+        )
+
+        num_batches = 0
+        for names, content in loader:
+            num_batches += 1
+            self.assertEqual(len(names), 2)
+
+            for data in content:
+                self.assertEqual(data, self.data)
+
+        self.assertEqual(num_batches, 5)
+
+    def test_dynamic_sampler_drop_last(self):
+        loader = DataLoader(
+            self.ais_dataset,
+            batch_sampler=DynamicBatchSampler(
+                data_source=self.ais_dataset,
+                max_batch_size=3000,  # three objects (each is 1 KB) per batch
+                drop_last=True,  # should result in 3 batches instead of 4
+            ),
+        )
+
+        num_batches = 0
+        for names, content in loader:
+            num_batches += 1
+            self.assertEqual(len(names), 3)
+
+            for data in content:
+                self.assertEqual(data, self.data)
+
+        self.assertEqual(num_batches, 3)
+
+    def test_dynamic_sampler_oversized(self):
+        loader = DataLoader(
+            self.ais_dataset,
+            batch_sampler=DynamicBatchSampler(
+                data_source=self.ais_dataset,
+                max_batch_size=500,  # smaller than every object, so each sample is oversized
+                allow_oversized_samples=True,  # should result in 10 batches
+            ),
+        )
+
+        num_batches = 0
+        for names, content in loader:
+            num_batches += 1
+            self.assertEqual(len(names), 1)
+
+            for data in content:
+                self.assertEqual(data, self.data)
+
+        self.assertEqual(num_batches, 10)
+
+    def test_dynamic_sampler_oversized_drop_last(self):
+
+        # add one odd-one-out 6 KB object
+        mock_obj = Mock(Object)
+        large_data = b"\0" * 6000  # 6 KB
+        mock_obj.get.return_value.read_all.return_value = large_data
+        mock_obj.size = len(large_data)
+        mock_obj.name = "test_obj"
+
+        self.mock_objects.append(mock_obj)
+        self.mock_bck.list_all_objects_iter.return_value = iter(self.mock_objects)
+
+        loader = DataLoader(
+            self.ais_dataset,
+            batch_sampler=DynamicBatchSampler(
+                data_source=self.ais_dataset,
+                max_batch_size=3000,  # three objects (each is 1 KB) per batch
+                drop_last=True,  # should result in 3 batches instead of 4
+                allow_oversized_samples=True,  # should add one batch for the 6 KB object added above
+            ),
+        )
+
+        num_batches = 0
+        for names, content in loader:
+            num_batches += 1
+            self.assertTrue(len(names) == 3 or len(names) == 1)
+
+            for data in content:
+                self.assertTrue(data == self.data or data == large_data)
+
+        self.assertEqual(num_batches, 4)
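The expected batch counts in the drop_last tests follow from the saturation rule: with drop_last=True a trailing batch is kept only if it fills more than SATURATION_FACTOR (0.8) of max_batch_size. An illustrative check of the numbers used above:

# 10 objects of 1000 bytes with max_batch_size=3000 -> three full batches, 1000 bytes left over
SATURATION_FACTOR = 0.8
leftover_bytes, max_batch_size = 1000, 3000

keep_last = leftover_bytes / max_batch_size > SATURATION_FACTOR
assert not keep_last  # 0.33 < 0.8, so the trailing batch is dropped and only 3 batches remain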
diff --git a/python/tests/unit/sdk/multiobj/test_object_group.py b/python/tests/unit/sdk/multiobj/test_object_group.py
index 1e7ee9d468..57fbb1219b 100644
--- a/python/tests/unit/sdk/multiobj/test_object_group.py
+++ b/python/tests/unit/sdk/multiobj/test_object_group.py
@@ -282,5 +282,5 @@ def test_list_urls(self):
         self.mock_bck.object.assert_has_calls(expected_obj_calls)
 
     def test_list_all_objects_iter(self):
-        res = self.object_group.list_all_objects_iter()
+        res = self.object_group.list_all_objects_iter(props=None)
         self.assertEqual(len(list(res)), len(self.obj_names))