-
Notifications
You must be signed in to change notification settings - Fork 189
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
python/pytorch: Implement dynamic sampler for map based datasets
Signed-off-by: Soham Manoli <[email protected]>
- Loading branch information
Showing
12 changed files
with
386 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
""" | ||
Dynamic Batch Sampler for Dynamic Batch Sizing | ||
In scenarios where memory is a constraint, the DynamicBatchSampler | ||
can be used to generate mini-batches that fit within a memory constraint | ||
so that there is a guarantee that each batch fits within memory | ||
while attempting to fit the maximum number of samples in each batch. | ||
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
""" | ||
|
||
from torch.utils.data import Sampler | ||
from typing import Iterator, List | ||
from aistore.pytorch.base_map_dataset import AISBaseMapDataset | ||
from logging import getLogger | ||
|
||
# Minimum fill fraction (total batch bytes / max_batch_size) the final,
# partial batch must reach to still be yielded when drop_last=True.
SATURATION_FACTOR = 0.8
|
||
|
||
class DynamicBatchSampler(Sampler):
    """
    Dynamically adds samples to mini-batch up to a maximum batch size.

    NOTE: Using this sampler with AISBaseMapDatasets that use ObjectGroups
    in their ais_source_lists will be slower than using it with Buckets as
    ObjectGroups will perform one extra API call per object to get size metadata.

    Args:
        data_source (AISBaseMapDataset): Base AIS map-style dataset to sample from to create dynamic mini-batches.
        max_batch_size (float): Maximum size of mini-batch in bytes.
        drop_last (bool, optional): If `True`, then will drop last batch if the batch is not at least 80% of `max_batch_size`.
            Defaults to `False`.
        allow_oversized_samples (bool, optional): If `True`, then any sample that is larger than the `max_batch_size` will be processed
            in its own mini-batch by itself instead of being dropped. Defaults to `False`.
    """

    def __init__(
        self,
        data_source: AISBaseMapDataset,
        max_batch_size: float,
        drop_last: bool = False,
        allow_oversized_samples: bool = False,
    ) -> None:
        self._data_source = data_source
        self._max_batch_size = max_batch_size
        # Materialize the sample list once up front so each sample's size
        # metadata is available without further API calls during iteration.
        self._samples_list = data_source._create_samples_list()
        self._drop_last = drop_last
        self._allow_oversized_samples = allow_oversized_samples
        # Fixed: was getLogger(f"{__name__}.put_files") — a copy-paste from an
        # unrelated API; name the logger after this module so records are
        # attributed correctly.
        self._logger = getLogger(__name__)

    def __iter__(self) -> Iterator[List[int]]:
        """
        Returns an iterator containing mini-batches (lists of indices).
        """
        total_mem = 0
        batch = []
        index = 0

        # Get each sample's size, check if there is space in the current
        # batch, and yield batches whenever full.
        while index < len(self):
            sample = self._samples_list[index]

            if sample.size > self._max_batch_size:
                if self._allow_oversized_samples:
                    # An oversized sample gets a mini-batch of its own.
                    yield [index]
                else:
                    # Fixed: logger.warn() is deprecated; use warning().
                    self._logger.warning(
                        f"Sample {sample.name} cannot be processed as it is larger than the max batch size: {sample.size} bytes > {self._max_batch_size} bytes"
                    )

                index += 1
                continue

            if total_mem + sample.size <= self._max_batch_size:
                batch.append(index)
                index += 1
                total_mem += sample.size
                # Batch is exactly full — emit it immediately.
                if total_mem == self._max_batch_size:
                    yield batch
                    batch = []
                    total_mem = 0
            else:
                # Sample does not fit: emit the current batch and retry this
                # same index against a fresh, empty batch on the next pass.
                yield batch
                batch = []
                total_mem = 0

        # If a partial batch remains and we are not dropping last, or we are
        # dropping last but the batch is saturated, then yield the last batch.
        if (batch and not self._drop_last) or (
            self._drop_last and total_mem / self._max_batch_size > SATURATION_FACTOR
        ):
            yield batch

    def __len__(self) -> int:
        """
        Returns the total number of samples.
        """
        return len(self._samples_list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
""" | ||
Integration Test class for AIStore PyTorch Samplers | ||
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
""" | ||
|
||
from unittest import TestCase | ||
from tests.integration import CLUSTER_ENDPOINT | ||
from tests.utils import destroy_bucket, random_string | ||
from aistore import Client | ||
from random import randint | ||
from aistore.pytorch import AISMapDataset, DynamicBatchSampler | ||
from torch.utils.data import DataLoader | ||
from sys import getsizeof | ||
from aistore.pytorch.utils import convert_bytes_to_mb | ||
|
||
|
||
MIN_OBJ_SIZE = 1000  # bytes = 1kb; smallest object payload written in setUp
MAX_OBJ_SIZE = 1000000  # bytes = 1mb; largest object payload written in setUp
NUM_OBJECTS = 100  # number of objects put into the test bucket
MAX_BATCH_SIZE = 1500000  # bytes = 1.5mb; sampler's per-batch byte budget
|
||
|
||
class TestAISSampler(TestCase):
    """
    Integration tests for the AIS Pytorch Samplers
    """

    def setUp(self) -> None:
        """Create a fresh bucket holding NUM_OBJECTS objects of random size."""
        self.bck_name = random_string()
        self.client = Client(CLUSTER_ENDPOINT)
        self.bck = self.client.bucket(self.bck_name)
        self.bck.create()

        for i in range(NUM_OBJECTS):
            # Payload sizes are uniform in [MIN_OBJ_SIZE, MAX_OBJ_SIZE].
            content = b"\0" * (randint(0, (MAX_OBJ_SIZE - MIN_OBJ_SIZE)) + MIN_OBJ_SIZE)
            self.bck.object(f"object-{i}").put_content(content)

        self.dataset = AISMapDataset(ais_source_list=self.bck)

    def tearDown(self) -> None:
        """
        Cleanup after each test, destroy the bucket if it exists
        """
        destroy_bucket(self.client, self.bck_name)

    def test_dynamic_sampler(self):
        # Create dataloader using dynamic batch sampler
        loader = DataLoader(
            dataset=self.dataset,
            batch_sampler=DynamicBatchSampler(
                data_source=self.dataset,
                max_batch_size=MAX_BATCH_SIZE,
            ),
        )

        # getsizeof() reports payload plus fixed Python object overhead, so
        # allow that overhead as slack in the size bounds below.
        # NOTE(review): assumes each item in `content` is a bytes object —
        # confirm against the dataset's collate output.
        overhead = getsizeof(b"")

        num_objects = 0
        for names, content in loader:
            # Test that batches are not empty and have consistent shape
            self.assertIsNotNone(names)
            self.assertIsNotNone(content)
            self.assertEqual(len(names), len(content))

            # Test that the size of each object is within the bounds
            batch_size = 0
            for data in content:
                data_size = getsizeof(data)
                # Fixed: the strict upper bound `data_size < MAX_OBJ_SIZE`
                # flaked when a payload landed exactly on MAX_OBJ_SIZE, since
                # getsizeof adds object overhead on top of the payload.
                self.assertGreaterEqual(data_size, MIN_OBJ_SIZE)
                self.assertLessEqual(data_size, MAX_OBJ_SIZE + overhead)
                batch_size += data_size

            # Fixed: previously the accumulated byte count was converted to MB
            # (convert_bytes_to_mb) and then compared against MAX_BATCH_SIZE,
            # which is in bytes — making the assertion vacuously true.
            # Compare bytes to bytes, allowing per-object getsizeof overhead.
            self.assertLessEqual(batch_size, MAX_BATCH_SIZE + len(content) * overhead)

            num_objects += len(names)

        # Test that all objects are included in our batches
        self.assertEqual(num_objects, NUM_OBJECTS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.