python/pytorch: Implement dynamic sampler for map based datasets
Signed-off-by: Soham Manoli <[email protected]>
msoham123 committed Jul 18, 2024
1 parent 43d88d2 commit aaec09d
Showing 12 changed files with 386 additions and 18 deletions.
1 change: 1 addition & 0 deletions python/aistore/pytorch/__init__.py
@@ -10,3 +10,4 @@
from aistore.pytorch.shard_reader import AISShardReader
from aistore.pytorch.base_map_dataset import AISBaseMapDataset
from aistore.pytorch.base_iter_dataset import AISBaseIterDataset
from aistore.pytorch.dynamic_sampler import DynamicBatchSampler
5 changes: 4 additions & 1 deletion python/aistore/pytorch/base_map_dataset.py
@@ -16,7 +16,7 @@ class AISBaseMapDataset(ABC, Dataset):
"""
A base class for creating map-style AIS Datasets. Should not be instantiated directly. Subclasses
    should implement :meth:`__getitem__`, which fetches a sample given a key from the dataset, and can optionally
override other methods from torch IterableDataset such as :meth:`__len__`. Additionally,
override other methods from torch Dataset such as :meth:`__len__` and :meth:`__getitems__`. Additionally,
to modify the behavior of loading samples from a source, override :meth:`_get_sample_list_from_source`.
Args:
@@ -87,3 +87,6 @@ def _create_samples_list(self) -> List[Object]:
@abstractmethod
def __getitem__(self, index):
pass

def __getitems__(self, indices: List[int]):
return [self.__getitem__(index) for index in indices]
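For context, __getitems__ is the optional batched-fetch hook that recent versions of torch's DataLoader call when a sampler yields a list of indices; the default above simply loops over __getitem__, and subclasses can override it to fetch a whole mini-batch at once. A minimal sketch of the hook on a toy map-style dataset (all names below are illustrative, not part of this commit):

from typing import List
from torch.utils.data import Dataset


class InMemoryMapDataset(Dataset):
    """Toy map-style dataset used only to illustrate __getitems__."""

    def __init__(self, samples: List[bytes]):
        self._samples = samples

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, index: int) -> bytes:
        return self._samples[index]

    def __getitems__(self, indices: List[int]) -> List[bytes]:
        # Called with the full list of indices for a mini-batch, so a real
        # implementation could issue one bulk fetch instead of N single reads.
        return [self._samples[i] for i in indices]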
102 changes: 102 additions & 0 deletions python/aistore/pytorch/dynamic_sampler.py
@@ -0,0 +1,102 @@
"""
Dynamic Batch Sampler for Dynamic Batch Sizing
In scenarios where memory is a constraint, the DynamicBatchSampler
can be used to generate mini-batches under a maximum byte budget,
guaranteeing that each batch fits within memory while packing as many
samples as possible into each batch.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from torch.utils.data import Sampler
from typing import Iterator, List
from aistore.pytorch.base_map_dataset import AISBaseMapDataset
from logging import getLogger

# Minimum fraction of max_batch_size the final batch must reach to avoid being dropped when drop_last=True
SATURATION_FACTOR = 0.8


class DynamicBatchSampler(Sampler):
"""
    Dynamically adds samples to each mini-batch up to a maximum batch size (in bytes).
NOTE: Using this sampler with AISBaseMapDatasets that use ObjectGroups
in their ais_source_lists will be slower than using it with Buckets as
ObjectGroups will perform one extra API call per object to get size metadata.
Args:
data_source (AISBaseMapDataset): Base AIS map-style dataset to sample from to create dynamic mini-batches.
max_batch_size (float): Maximum size of mini-batch in bytes.
        drop_last (bool, optional): If `True`, drops the last batch if it is not at least 80% of `max_batch_size`.
            Defaults to `False`.
        allow_oversized_samples (bool, optional): If `True`, any sample larger than `max_batch_size` is processed
            in its own mini-batch instead of being dropped. Defaults to `False`.
"""

def __init__(
self,
data_source: AISBaseMapDataset,
max_batch_size: float,
drop_last: bool = False,
allow_oversized_samples: bool = False,
) -> None:
self._data_source = data_source
self._max_batch_size = max_batch_size
self._samples_list = data_source._create_samples_list()
self._drop_last = drop_last
self._allow_oversized_samples = allow_oversized_samples
        self._logger = getLogger(__name__)

def __iter__(self) -> Iterator[List[int]]:
"""
Returns an iterator containing mini-batches (lists of indices).
"""
total_mem = 0
batch = []
index = 0

        # Walk the samples in order: check each sample's size, add it to the
        # current batch if it still fits, and yield the batch once it is full
while index < len(self):
sample = self._samples_list[index]

if sample.size > self._max_batch_size:
                if self._allow_oversized_samples:
                    yield [index]
                else:
                    self._logger.warning(
                        f"Sample {sample.name} cannot be processed as it is larger than the max batch size: {sample.size} bytes > {self._max_batch_size} bytes"
                    )

index += 1
continue

if total_mem + sample.size < self._max_batch_size:
batch.append(index)
index += 1
total_mem += sample.size
else:

if total_mem + sample.size == self._max_batch_size:
batch.append(index)
index += 1

yield batch
batch = []
total_mem = 0

# if batch exists and we are not dropping last or if we are dropping last but the batch is saturated
# then yield the last batch
if (batch and not self._drop_last) or (
self._drop_last and total_mem / self._max_batch_size > SATURATION_FACTOR
):
yield batch

def __len__(self) -> int:
"""
Returns the total number of samples.
"""
return len(self._samples_list)
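A minimal usage sketch of the new sampler, mirroring the integration test added below (the endpoint and bucket name are placeholders):

from aistore import Client
from aistore.pytorch import AISMapDataset, DynamicBatchSampler
from aistore.pytorch.utils import convert_mb_to_bytes
from torch.utils.data import DataLoader

client = Client("http://localhost:8080")  # placeholder endpoint
dataset = AISMapDataset(ais_source_list=client.bucket("my-bucket"))  # placeholder bucket

loader = DataLoader(
    dataset=dataset,
    # Cap each mini-batch at roughly 1.5 MB of object data.
    batch_sampler=DynamicBatchSampler(
        data_source=dataset,
        max_batch_size=convert_mb_to_bytes(1.5),
    ),
)

for names, content in loader:
    pass  # every batch is guaranteed to fit within the byte budget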
29 changes: 29 additions & 0 deletions python/aistore/pytorch/utils.py
@@ -7,6 +7,35 @@
from urllib.parse import urlunparse
from typing import Tuple
from aistore.sdk.utils import parse_url as sdk_parse_url
from math import floor

MB_TO_B = 1000000


def convert_mb_to_bytes(megabytes: float) -> int:
"""
Converts megabytes to bytes and truncates any extra bytes (floor).
Args:
megabytes (float): number of megabytes to convert
Returns:
int: number of bytes after conversion (floor of actual byte value)
"""
return floor(megabytes * MB_TO_B)


def convert_bytes_to_mb(bytes: int) -> float:
"""
    Converts bytes to megabytes.
Args:
bytes (int): number of bytes to convert to megabytes
Returns:
float: number of megabytes after conversion
"""
return bytes / MB_TO_B
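Since MB_TO_B is 1,000,000 (decimal megabytes), the two helpers behave as follows; a quick illustrative check:

from aistore.pytorch.utils import convert_bytes_to_mb, convert_mb_to_bytes

assert convert_mb_to_bytes(1.5) == 1_500_000  # MB -> bytes
assert convert_bytes_to_mb(1_500_000) == 1.5  # bytes -> MB
assert convert_mb_to_bytes(0.0000015) == 1    # fractional bytes are floored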


def unparse_url(provider: str, bck_name: str, obj_name: str) -> str:
6 changes: 5 additions & 1 deletion python/aistore/sdk/ais_source.py
@@ -19,12 +19,16 @@ def client(self) -> RequestClient:
"""The client bound to the AISSource."""

@abstractmethod
def list_all_objects_iter(self, prefix: str = "") -> Iterable[Object]:
def list_all_objects_iter(
self, prefix: str = "", props: str = "name,size"
) -> Iterable[Object]:
"""
Get an iterable of all the objects contained in this source (bucket, group, etc.)
Args:
prefix (str, optional): Only include objects with names matching this prefix
            props (str, optional): Comma-separated list of object properties to return.
                Defaults to "name,size".
Returns:
Iterable over selected objects
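A short sketch of how a caller might use the new props parameter on any AISSource implementation (endpoint, bucket, and prefix are placeholders):

from aistore import Client

client = Client("http://localhost:8080")  # placeholder endpoint
bck = client.bucket("my-bucket")          # Bucket implements AISSource

# Ask only for the properties the dynamic sampler needs.
for obj in bck.list_all_objects_iter(prefix="train/", props="name,size"):
    print(obj.name, obj.size)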
20 changes: 11 additions & 9 deletions python/aistore/sdk/bucket.py
@@ -149,19 +149,23 @@ def list_urls(self, prefix: str = "", etl_name: str = None) -> Iterable[str]:
for entry in self.list_objects_iter(prefix=prefix, props="name"):
yield self.object(entry.name).get_url(etl_name=etl_name)

def list_all_objects_iter(self, prefix: str = "") -> Iterable[Object]:
def list_all_objects_iter(
self, prefix: str = "", props: str = "name,size"
) -> Iterable[Object]:
"""
Implementation of the abstract method from AISSource that provides an iterator
of all the objects in this bucket matching the specified prefix
of all the objects in this bucket matching the specified prefix.
Args:
prefix (str, optional): Limit objects selected by a given string prefix
            props (str, optional): Comma-separated list of object properties to return.
                Defaults to "name,size".
Returns:
            Iterator of all the objects in this bucket matching the prefix
"""
for entry in self.list_objects_iter(prefix=prefix, props="name"):
yield self.object(entry.name)
for entry in self.list_objects_iter(prefix=prefix, props=props):
yield self.object(entry.name, entry.size)

def create(self, exist_ok=False):
"""
@@ -812,21 +816,19 @@ def _get_uploaded_obj_name(file, root_path, basename, prepend):
return prepend + obj_name
return obj_name

def object(self, obj_name: str) -> Object:
def object(self, obj_name: str, size: int = None) -> Object:
"""
Factory constructor for an object in this bucket.
Does not make any HTTP request, only instantiates an object in a bucket owned by the client.
Args:
obj_name (str): Name of object
size (int, optional): Size of object in bytes
Returns:
The object created.
"""
return Object(
bucket=self,
name=obj_name,
)
return Object(bucket=self, name=obj_name, size=size)
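A small sketch of the extended factory (names are placeholders); passing size simply records it on the instance, so callers such as the dynamic sampler can read obj.size without an extra HEAD request:

from aistore import Client

bck = Client("http://localhost:8080").bucket("my-bucket")  # placeholders

obj = bck.object("train/sample-0001.tar", size=2048)
print(obj.name, obj.size)  # no HTTP request is made by the factory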

def objects(
self,
18 changes: 15 additions & 3 deletions python/aistore/sdk/multiobj/object_group.py
@@ -94,19 +94,31 @@ def list_urls(self, prefix: str = "", etl_name: str = None) -> Iterable[str]:
for obj_name in self._obj_collection:
yield self.bck.object(obj_name).get_url(etl_name=etl_name)

def list_all_objects_iter(self, prefix: str = "") -> Iterable[Object]:
def list_all_objects_iter(
self, prefix: str = "", props: str = "name,size"
) -> Iterable[Object]:
"""
Implementation of the abstract method from AISSource that provides an iterator
of all the objects in this bucket matching the specified prefix
of all the objects in this bucket matching the specified prefix.
Args:
prefix (str, optional): Limit objects selected by a given string prefix
            props (str, optional): Comma-separated list of object properties to return.
                Defaults to "name,size".
Returns:
Iterator of all the objects in the group
"""
for obj_name in self._obj_collection:
yield self.bck.object(obj_name)

obj = self.bck.object(obj_name)

# TODO: Use the head object API to pass an object of requested props
if props is not None and "size" in props.split(","):
                size = int(obj.head()["Content-Length"])
                obj = self.bck.object(obj_name, size)

yield obj
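This per-object HEAD lookup is why the sampler's docstring warns that ObjectGroup sources are slower than buckets; a hedged sketch of the equivalent manual lookup (endpoint, bucket, and object names are placeholders):

from aistore import Client

bck = Client("http://localhost:8080").bucket("my-bucket")  # placeholders

# One HEAD request per object just to learn its size.
size_in_bytes = int(bck.object("train/sample-0001.tar").head()["Content-Length"])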

def delete(self):
"""
10 changes: 8 additions & 2 deletions python/aistore/sdk/object.py
@@ -50,16 +50,17 @@ class Object:
Args:
bucket (Bucket): Bucket to which this object belongs
name (str): name of object
size (int, optional): size of object in bytes
"""

def __init__(self, bucket: "Bucket", name: str):
def __init__(self, bucket: "Bucket", name: str, size: int = None):
self._bucket = bucket
self._client = bucket.client
self._bck_name = bucket.name
self._qparams = bucket.qparam
self._name = name
self._object_path = f"{URL_PATH_OBJECTS}/{ self._bck_name}/{ self.name }"
self._size = size

@property
def bucket(self):
@@ -71,6 +72,11 @@ def name(self):
"""Name of this object"""
return self._name

@property
def size(self):
"""Size of this object in bytes"""
return self._size

def head(self) -> Header:
"""
Requests object properties.
78 changes: 78 additions & 0 deletions python/tests/integration/pytorch/test_samplers.py
@@ -0,0 +1,78 @@
"""
Integration Test class for AIStore PyTorch Samplers
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from unittest import TestCase
from tests.integration import CLUSTER_ENDPOINT
from tests.utils import destroy_bucket, random_string
from aistore import Client
from random import randint
from aistore.pytorch import AISMapDataset, DynamicBatchSampler
from torch.utils.data import DataLoader
from sys import getsizeof
from aistore.pytorch.utils import convert_bytes_to_mb


MIN_OBJ_SIZE = 1000  # bytes (1 KB)
MAX_OBJ_SIZE = 1000000  # bytes (1 MB)
NUM_OBJECTS = 100
MAX_BATCH_SIZE = 1500000  # bytes (1.5 MB)


class TestAISSampler(TestCase):
"""
Integration tests for the AIS Pytorch Samplers
"""

def setUp(self) -> None:
self.bck_name = random_string()
self.client = Client(CLUSTER_ENDPOINT)
self.bck = self.client.bucket(self.bck_name)
self.bck.create()

for i in range(NUM_OBJECTS):
content = b"\0" * (randint(0, (MAX_OBJ_SIZE - MIN_OBJ_SIZE)) + MIN_OBJ_SIZE)
self.bck.object(f"object-{i}").put_content(content)

self.dataset = AISMapDataset(ais_source_list=self.bck)

def tearDown(self) -> None:
"""
Cleanup after each test, destroy the bucket if it exists
"""
destroy_bucket(self.client, self.bck_name)

def test_dynamic_sampler(self):
# Create dataloader using dynamic batch sampler
loader = DataLoader(
dataset=self.dataset,
batch_sampler=DynamicBatchSampler(
data_source=self.dataset,
max_batch_size=MAX_BATCH_SIZE,
),
)

num_objects = 0
for names, content in loader:
# Test that batches are not empty and have consistent shape
self.assertTrue(names is not None)
self.assertTrue(content is not None)
self.assertEqual(len(names), len(content))

# Test that the size of each object is within the bounds
batch_size = 0
for data in content:
data_size = getsizeof(data)
self.assertTrue(data_size >= MIN_OBJ_SIZE and data_size < MAX_OBJ_SIZE)
batch_size += data_size

            # Test that total batch size is within the bounds (compare in megabytes)
            self.assertTrue(
                convert_bytes_to_mb(batch_size) <= convert_bytes_to_mb(MAX_BATCH_SIZE)
            )

num_objects += len(names)

# Test that all objects are included in our batch
self.assertEqual(num_objects, NUM_OBJECTS)
1 change: 0 additions & 1 deletion python/tests/unit/pytorch/test_datasets.py
@@ -15,7 +15,6 @@

class TestAISDataset(unittest.TestCase):
def setUp(self) -> None:
self.mock_client = Mock()
mock_obj = Mock()
mock_obj.get.return_value.read_all.return_value = b"mock data"
self.mock_objects = [