From 8c0c45a51e62adfc9dc3c5b3827d59211e19f0c6 Mon Sep 17 00:00:00 2001
From: Bernard Han <103850980+bernardhan33@users.noreply.github.com>
Date: Thu, 14 Mar 2024 13:26:15 -0700
Subject: [PATCH] Add support for Dataflux Iterable Dataset (#17)

* add Dataflux Iterable Dataset

* rename to DataFluxIterableDataset

* add test case for multi-worker setup

* update license header year

* update license header year #2

* Address comments
---
 dataflux_client_python | 2 +-
 dataflux_pytorch/dataflux_iterable_dataset.py | 171 ++++++++++
 dataflux_pytorch/dataflux_mapstyle_dataset.py | 2 +-
 .../tests/test_dataflux_iterable_dataset.py | 293 ++++++++++++++++++
 .../tests/test_dataflux_mapstyle_dataset.py | 2 +-
 5 files changed, 467 insertions(+), 3 deletions(-)
 create mode 100644 dataflux_pytorch/dataflux_iterable_dataset.py
 create mode 100644 dataflux_pytorch/tests/test_dataflux_iterable_dataset.py

diff --git a/dataflux_client_python b/dataflux_client_python
index 2fc12b62..81a42aef 160000
--- a/dataflux_client_python
+++ b/dataflux_client_python
@@ -1 +1 @@
-Subproject commit 2fc12b62a3942b6bb9e637482ac7d428acad1b46
+Subproject commit 81a42aefae8760726d83287f7c4016f3cd37522b
diff --git a/dataflux_pytorch/dataflux_iterable_dataset.py b/dataflux_pytorch/dataflux_iterable_dataset.py
new file mode 100644
index 00000000..da0ce96a
--- /dev/null
+++ b/dataflux_pytorch/dataflux_iterable_dataset.py
@@ -0,0 +1,171 @@
+"""
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+import os
+import math
+import logging
+
+from torch.utils import data
+from google.cloud import storage
+from google.api_core.client_info import ClientInfo
+
+import dataflux_core
+
+
+class Config:
+    """Customizable configuration for the DataFluxIterableDataset.
+
+    Attributes:
+        sort_listing_results: A boolean flag indicating whether data listing results
+            are sorted alphabetically. Defaults to False.
+
+        max_composite_object_size: An integer cap on the maximum size of a composite
+            object in bytes. Defaults to 100000000 (100 MB).
+
+        num_processes: The number of processes to be used in the Dataflux algorithms.
+            Defaults to the number of CPUs available in the running environment.
+
+        prefix: The prefix used to filter the objects listed from the bucket.
+            Defaults to None, which lists all of the objects in the bucket.
+
+        max_listing_retries: An integer indicating the maximum number of retries
+            to attempt in case of any Python multiprocessing errors during
+            GCS object listing. Defaults to 3.
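+
+    Example (an illustrative sketch; the values shown are hypothetical, not defaults):
+
+        config = Config(
+            sort_listing_results=True,
+            num_processes=8,
+            prefix="training-shards/",
+            max_listing_retries=5,
+        )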
+    """
+
+    def __init__(
+        self,
+        sort_listing_results: bool = False,
+        max_composite_object_size: int = 100000000,
+        num_processes: int = os.cpu_count(),
+        prefix: str = None,
+        max_listing_retries: int = 3,
+    ):
+        self.sort_listing_results = sort_listing_results
+        self.max_composite_object_size = max_composite_object_size
+        self.num_processes = num_processes
+        self.prefix = prefix
+        self.max_listing_retries = max_listing_retries
+
+
+class DataFluxIterableDataset(data.IterableDataset):
+    def __init__(
+        self,
+        project_name,
+        bucket_name,
+        config=Config(),
+        data_format_fn=lambda data: data,
+        storage_client=None,
+    ):
+        """Initializes the DataFluxIterableDataset.
+
+        The initialization sets up the needed configuration and runs data
+        listing using the Dataflux algorithm.
+
+        Args:
+            project_name: The name of the GCP project.
+            bucket_name: The name of the GCS bucket that holds the objects to compose.
+                The Dataflux download algorithm also uploads the composed object to this bucket.
+            config: A dataflux_iterable_dataset.Config object that includes configuration
+                customizations. If not specified, a default config with default parameters is created.
+            data_format_fn: A function that formats the downloaded bytes to the desired format.
+                If not specified, the default formatting function leaves the data as-is.
+            storage_client: The google.cloud.storage.Client object initialized with sufficient
+                permissions to access the project and the bucket. If not specified, it will be
+                created during initialization.
+        """
+        super().__init__()
+        self.storage_client = storage_client
+        if not storage_client:
+            self.storage_client = storage.Client(
+                project=project_name,
+                client_info=ClientInfo(user_agent="dataflux/0.0"),
+            )
+        self.project_name = project_name
+        self.bucket_name = bucket_name
+        self.data_format_fn = data_format_fn
+        self.config = config
+        self.dataflux_download_optimization_params = (
+            dataflux_core.download.DataFluxDownloadOptimizationParams(
+                max_composite_object_size=self.config.max_composite_object_size
+            )
+        )
+
+        self.objects = self._list_GCS_blobs_with_retry()
+
+    def __iter__(self):
+        worker_info = data.get_worker_info()
+        if worker_info is None:
+            # Single-process data loading.
+            yield from [
+                self.data_format_fn(bytes_content)
+                for bytes_content in dataflux_core.download.dataflux_download_lazy(
+                    project_name=self.project_name,
+                    bucket_name=self.bucket_name,
+                    objects=self.objects,
+                    storage_client=self.storage_client,
+                    dataflux_download_optimization_params=self.dataflux_download_optimization_params,
+                )
+            ]
+        else:
+            # Multi-process data loading. Split the workload among workers.
+            # Ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset.
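+            # Worked example (hypothetical numbers, not taken from the tests): with
+            # 10 listed objects and num_workers=4, per_worker = ceil(10 / 4) = 3,
+            # so worker 0 iterates objects[0:3], worker 1 objects[3:6],
+            # worker 2 objects[6:9], and worker 3 objects[9:10].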
+            per_worker = int(
+                math.ceil(len(self.objects) / float(worker_info.num_workers))
+            )
+            worker_id = worker_info.id
+            start = worker_id * per_worker
+            end = min(start + per_worker, len(self.objects))
+            yield from [
+                self.data_format_fn(bytes_content)
+                for bytes_content in dataflux_core.download.dataflux_download_lazy(
+                    project_name=self.project_name,
+                    bucket_name=self.bucket_name,
+                    objects=self.objects[start:end],
+                    storage_client=self.storage_client,
+                    dataflux_download_optimization_params=self.dataflux_download_optimization_params,
+                )
+            ]
+
+    def _list_GCS_blobs_with_retry(self):
+        """Retries Dataflux listing upon exceptions, up to the retry count defined in self.config."""
+        error = None
+        listed_objects = []
+        for _ in range(self.config.max_listing_retries):
+            try:
+                listed_objects = dataflux_core.fast_list.ListingController(
+                    max_parallelism=self.config.num_processes,
+                    project=self.project_name,
+                    bucket=self.bucket_name,
+                    sort_results=self.config.sort_listing_results,
+                    prefix=self.config.prefix,
+                ).run()
+            except Exception as e:
+                logging.error(
+                    f"exception {str(e)} caught running Dataflux fast listing."
+                )
+                error = e
+                continue
+
+            # No exception -- we can immediately return the listed objects.
+            else:
+                return listed_objects
+
+        # Did not break the for loop, therefore all attempts
+        # raised an exception.
+        else:
+            raise error
diff --git a/dataflux_pytorch/dataflux_mapstyle_dataset.py b/dataflux_pytorch/dataflux_mapstyle_dataset.py
index b63bdf11..05963a4b 100644
--- a/dataflux_pytorch/dataflux_mapstyle_dataset.py
+++ b/dataflux_pytorch/dataflux_mapstyle_dataset.py
@@ -77,7 +77,7 @@ def __init__(
         Args:
             project_name: The name of the GCP project.
             bucket_name: The name of the GCS bucket that holds the objects to compose.
-                The function uploads the the composed object to this bucket too.
+                The Dataflux download algorithm also uploads the composed object to this bucket.
             destination_blob_name: The name of the composite object to be created.
             config: A dataflux_mapstyle_dataset.Config object that includes configuration
                 customizations. If not specified, a default config with default parameters is created.
diff --git a/dataflux_pytorch/tests/test_dataflux_iterable_dataset.py b/dataflux_pytorch/tests/test_dataflux_iterable_dataset.py
new file mode 100644
index 00000000..776e5b80
--- /dev/null
+++ b/dataflux_pytorch/tests/test_dataflux_iterable_dataset.py
@@ -0,0 +1,293 @@
+"""
+ Copyright 2024 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+import unittest
+from unittest import mock
+import math
+
+from dataflux_client_python.dataflux_core.tests import fake_gcs
+from dataflux_pytorch import dataflux_iterable_dataset
+
+
+class IterableDatasetTestCase(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.project_name = "foo"
+        self.bucket_name = "bar"
+        self.config = dataflux_iterable_dataset.Config(
+            num_processes=3, max_listing_retries=3, prefix="prefix/"
+        )
+        self.data_format_fn = lambda data: data
+        client = fake_gcs.Client()
+
+        self.want_objects = [("objectA", 1), ("objectB", 2)]
+        self.storage_client = client
+
+    @mock.patch("dataflux_pytorch.dataflux_iterable_dataset.dataflux_core")
+    def test_init(self, mock_dataflux_core):
+        """Tests that the DataFluxIterableDataset can be initialized with the expected listing results."""
+        # Arrange.
+        mock_listing_controller = mock.Mock()
+        mock_listing_controller.run.return_value = self.want_objects
+        mock_dataflux_core.fast_list.ListingController.return_value = (
+            mock_listing_controller
+        )
+
+        # Act.
+        ds = dataflux_iterable_dataset.DataFluxIterableDataset(
+            project_name=self.project_name,
+            bucket_name=self.bucket_name,
+            config=self.config,
+            data_format_fn=self.data_format_fn,
+            storage_client=self.storage_client,
+        )
+
+        # Assert.
+        self.assertEqual(
+            ds.objects,
+            self.want_objects,
+            f"got listed objects {ds.objects}, want {self.want_objects}",
+        )
+
+    @mock.patch("dataflux_pytorch.dataflux_iterable_dataset.dataflux_core")
+    def test_init_with_required_parameters(self, mock_dataflux_core):
+        """Tests that the DataFluxIterableDataset can be initialized with only the required parameters."""
+        # Arrange.
+        mock_listing_controller = mock.Mock()
+        mock_listing_controller.run.return_value = self.want_objects
+        mock_dataflux_core.fast_list.ListingController.return_value = (
+            mock_listing_controller
+        )
+
+        # Act.
+        ds = dataflux_iterable_dataset.DataFluxIterableDataset(
+            project_name=self.project_name,
+            bucket_name=self.bucket_name,
+            # storage_client is an optional param but is still needed here
+            # to avoid constructing an actual storage.Client.
+            storage_client=self.storage_client,
+        )
+
+        # Assert.
+        self.assertEqual(
+            ds.objects,
+            self.want_objects,
+            f"got listed objects {ds.objects}, want {self.want_objects}",
+        )
+
+    @mock.patch("dataflux_pytorch.dataflux_iterable_dataset.dataflux_core")
+    def test_init_retry_exception_passes(self, mock_dataflux_core):
+        """Tests that the initialization retries object listing upon an exception and succeeds."""
+        # Arrange.
+        mock_listing_controller = mock.Mock()
+
+        # Simulate that the first invocation raises an exception and the second invocation
+        # passes with the wanted results.
+        mock_listing_controller.run.side_effect = [
+            Exception(),
+            self.want_objects,
+            Exception(),
+        ]
+        mock_dataflux_core.fast_list.ListingController.return_value = (
+            mock_listing_controller
+        )
+
+        # Act.
+        ds = dataflux_iterable_dataset.DataFluxIterableDataset(
+            project_name=self.project_name,
+            bucket_name=self.bucket_name,
+            config=self.config,
+            data_format_fn=self.data_format_fn,
+            storage_client=self.storage_client,
+        )
+
+        # Assert.
+        self.assertEqual(
+            ds.objects,
+            self.want_objects,
+            f"got listed objects {ds.objects}, want {self.want_objects}",
+        )
+
+    @mock.patch("dataflux_pytorch.dataflux_iterable_dataset.dataflux_core")
+    def test_init_raises_exception_when_retries_exhaust(self, mock_dataflux_core):
+        """Tests that the initialization raises an exception when all listing retries are exhausted."""
+        # Arrange.
+        mock_listing_controller = mock.Mock()
+        want_exception = RuntimeError("123")
+
+        # Simulate that all retries return with exceptions.
+        mock_listing_controller.run.side_effect = [
+            want_exception for _ in range(self.config.max_listing_retries)
+        ]
+        mock_dataflux_core.fast_list.ListingController.return_value = (
+            mock_listing_controller
+        )
+
+        # Act & Assert.
+        with self.assertRaises(RuntimeError) as re:
+            ds = dataflux_iterable_dataset.DataFluxIterableDataset(
+                project_name=self.project_name,
+                bucket_name=self.bucket_name,
+                config=self.config,
+                data_format_fn=self.data_format_fn,
+                storage_client=self.storage_client,
+            )
+            self.assertIsNone(
+                ds.objects,
+                "got a non-None objects instance variable, want None when all listing retries are exhausted",
+            )
+
+        self.assertEqual(
+            re.exception,
+            want_exception,
+            f"got exception {re.exception}, want {want_exception}",
+        )
+
+    @mock.patch("dataflux_pytorch.dataflux_iterable_dataset.dataflux_core")
+    @mock.patch("torch.utils.data.get_worker_info")
+    def test_iter_single_process(self, mock_worker_info, mock_dataflux_core):
+        """Tests that iterating over the dataset downloads the correct objects in a single-process setup."""
+        # Arrange.
+        mock_listing_controller = mock.Mock()
+        mock_listing_controller.run.return_value = self.want_objects
+        mock_dataflux_core.fast_list.ListingController.return_value = (
+            mock_listing_controller
+        )
+        want_optimization_params = object()
+        mock_dataflux_core.download.DataFluxDownloadOptimizationParams.return_value = (
+            want_optimization_params
+        )
+        dataflux_download_return_val = [
+            bytes("contentA", "utf-8"),
+            bytes("contentBB", "utf-8"),
+        ]
+
+        mock_dataflux_core.download.dataflux_download_lazy.return_value = iter(
+            dataflux_download_return_val
+        )
+        mock_worker_info.return_value = None
+
+        data_format_fn = lambda content: len(content)
+        want_downloaded = [
+            data_format_fn(bytes_content)
+            for bytes_content in dataflux_download_return_val
+        ]
+
+        # Act.
+        ds = dataflux_iterable_dataset.DataFluxIterableDataset(
+            project_name=self.project_name,
+            bucket_name=self.bucket_name,
+            config=self.config,
+            data_format_fn=data_format_fn,
+            storage_client=self.storage_client,
+        )
+        got_downloaded = []
+        for downloaded in ds:
+            got_downloaded.append(downloaded)
+
+        # Assert.
+        self.assertEqual(
+            got_downloaded,
+            want_downloaded,
+        )
+        # Since this is a single-process setup, we expect dataflux_download_lazy to be
+        # called with the full list of objects.
+        mock_dataflux_core.download.dataflux_download_lazy.assert_called_with(
+            project_name=self.project_name,
+            bucket_name=self.bucket_name,
+            objects=self.want_objects,
+            storage_client=self.storage_client,
+            dataflux_download_optimization_params=want_optimization_params,
+        )
+
+    @mock.patch("dataflux_pytorch.dataflux_iterable_dataset.dataflux_core")
+    @mock.patch("torch.utils.data.get_worker_info")
+    def test_iter_multiple_processes(self, mock_worker_info, mock_dataflux_core):
+        """
+        Tests that iterating over the dataset downloads the correct objects in a multi-process setup.
+        Specifically, each worker should be assigned a different shard of the dataset to download.
+        """
+        # Arrange.
+        want_objects = [("objectA", 1), ("objectB", 2), ("objectC", 3), ("objectD", 4)]
+        mock_listing_controller = mock.Mock()
+        mock_listing_controller.run.return_value = want_objects
+        mock_dataflux_core.fast_list.ListingController.return_value = (
+            mock_listing_controller
+        )
+        want_optimization_params = object()
+        mock_dataflux_core.download.DataFluxDownloadOptimizationParams.return_value = (
+            want_optimization_params
+        )
+        dataflux_download_return_val = [
+            bytes("contentA", "utf-8"),
+            bytes("contentBB", "utf-8"),
+        ]
+
+        mock_dataflux_core.download.dataflux_download_lazy.return_value = iter(
+            dataflux_download_return_val
+        )
+
+        class _WorkerInfo:
+            """A fake WorkerInfo class for testing purposes."""
+
+            def __init__(self, num_workers, id):
+                self.num_workers = num_workers
+                self.id = id
+
+        num_workers = 2
+        id = 0
+        want_per_worker = math.ceil(len(want_objects) / num_workers)
+        want_start = id * want_per_worker
+        want_end = want_start + want_per_worker
+        worker_info = _WorkerInfo(num_workers=num_workers, id=id)
+        mock_worker_info.return_value = worker_info
+
+        data_format_fn = lambda content: len(content)
+        want_downloaded = [
+            data_format_fn(bytes_content)
+            for bytes_content in dataflux_download_return_val
+        ]
+
+        # Act.
+        ds = dataflux_iterable_dataset.DataFluxIterableDataset(
+            project_name=self.project_name,
+            bucket_name=self.bucket_name,
+            config=self.config,
+            data_format_fn=data_format_fn,
+            storage_client=self.storage_client,
+        )
+        got_downloaded = []
+        for downloaded in ds:
+            got_downloaded.append(downloaded)
+
+        # Assert.
+        self.assertEqual(
+            got_downloaded,
+            want_downloaded,
+        )
+        # Since this is a multi-process setup, we expect dataflux_download_lazy to be
+        # called only with a slice of the objects, want_objects[want_start:want_end].
+        mock_dataflux_core.download.dataflux_download_lazy.assert_called_with(
+            project_name=self.project_name,
+            bucket_name=self.bucket_name,
+            objects=want_objects[want_start:want_end],
+            storage_client=self.storage_client,
+            dataflux_download_optimization_params=want_optimization_params,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/dataflux_pytorch/tests/test_dataflux_mapstyle_dataset.py b/dataflux_pytorch/tests/test_dataflux_mapstyle_dataset.py
index 5f23e8cf..9d5ff062 100644
--- a/dataflux_pytorch/tests/test_dataflux_mapstyle_dataset.py
+++ b/dataflux_pytorch/tests/test_dataflux_mapstyle_dataset.py
@@ -1,5 +1,5 @@
 """
- Copyright 2023 Google LLC
+ Copyright 2024 Google LLC
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
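
A minimal usage sketch (not part of the patch; the project, bucket, and prefix values below are placeholders): the new DataFluxIterableDataset is constructed directly and can be wrapped in a PyTorch DataLoader, where num_workers > 0 triggers the per-worker sharding exercised by test_iter_multiple_processes above.

    from torch.utils.data import DataLoader

    from dataflux_pytorch import dataflux_iterable_dataset

    # Placeholder project and bucket names; Config and data_format_fn are optional.
    dataset = dataflux_iterable_dataset.DataFluxIterableDataset(
        project_name="my-gcp-project",
        bucket_name="my-training-bucket",
        config=dataflux_iterable_dataset.Config(prefix="images/"),
        data_format_fn=lambda raw_bytes: raw_bytes,  # keep raw bytes as-is
    )

    # With num_workers > 0, each DataLoader worker iterates a distinct slice of
    # the listed objects, as computed in DataFluxIterableDataset.__iter__.
    loader = DataLoader(dataset, batch_size=32, num_workers=2)
    for batch in loader:
        pass  # training / preprocessing would go here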