diff --git a/.flake8 b/.flake8 index 14f53be8e..0e17765c9 100644 --- a/.flake8 +++ b/.flake8 @@ -4,6 +4,7 @@ ignore = E203,E402,E501,F821,W503,W504, per-file-ignores = __init__.py: F401, F403, F405 test/*: F401 + _extension.py: F401 exclude = ./.git, ./third_party, diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b69739a6f..18b68c850 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,27 +29,77 @@ jobs: - 3.7 - 3.8 - 3.9 + with-s3: + - 1 + - 0 steps: - name: Setup additional system libraries if: startsWith( matrix.os, 'ubuntu' ) run: | sudo add-apt-repository multiverse sudo apt update - sudo apt install rar unrar + sudo apt install rar unrar libssl-dev libcurl4-openssl-dev zlib1g-dev - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: Setup msbuild on Windows + if: matrix.with-s3 == 1 && matrix.os == 'windows-latest' + uses: microsoft/setup-msbuild@v1.1 + - name: Set up Visual Studio shell + if: matrix.with-s3 == 1 && matrix.os == 'windows-latest' + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 - name: Check out source repository uses: actions/checkout@v2 - name: Install dependencies run: | pip3 install -r requirements.txt pip3 install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + pip3 install cmake ninja pybind11 + echo "/home/runner/.local/bin" >> $GITHUB_PATH + - name: Export AWS-SDK-CPP & PYBIND11 + if: matrix.with-s3 == 1 + shell: bash + run: | + if [[ ${{ matrix.os }} == 'windows-latest' ]]; then + AWSSDK_PATH="$GITHUB_WORKSPACE\\aws-sdk-cpp\\sdk-lib" + else + AWSSDK_PATH="$GITHUB_WORKSPACE/aws-sdk-cpp/sdk-lib" + fi + PYBIND11_PATH=`pybind11-config --cmakedir` + echo "::set-output name=awssdk::$AWSSDK_PATH" + echo "::set-output name=pybind11::$PYBIND11_PATH" + id: export_path + - name: Install AWS-SDK-CPP on Windows for S3 IO datapipes + if: matrix.with-s3 == 1 && matrix.os == 'windows-latest' + run: | + git clone --recurse-submodules https://github.com/aws/aws-sdk-cpp + cd aws-sdk-cpp + mkdir sdk-lib + cmake -S . -B build -GNinja -DBUILD_ONLY="s3;transfer" -DBUILD_SHARED_LIBS=OFF -DENABLE_TESTING=OFF -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=sdk-lib + cmake --build build --config Release + cmake --install build --config Release + - name: Install AWS-SDK-CPP on Non-Windows for S3 IO datapipes + if: matrix.with-s3 == 1 && matrix.os != 'windows-latest' + run: | + git clone --recurse-submodules https://github.com/aws/aws-sdk-cpp + cd aws-sdk-cpp/ + mkdir sdk-build sdk-lib + cd sdk-build + cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_ONLY="s3;transfer" -DENABLE_TESTING=OFF -DCMAKE_INSTALL_PREFIX=../sdk-lib + make + sudo make install + - name: Build TorchData + run: | + python setup.py develop + env: + BUILD_S3: ${{ matrix.with-s3 }} + pybind11_DIR: ${{ steps.export_path.outputs.pybind11 }} + AWSSDK_DIR: ${{ steps.export_path.outputs.awssdk }} - name: Install test requirements run: pip3 install expecttest fsspec iopath==0.1.9 numpy pytest rarfile - - name: Build TorchData - run: python setup.py develop - name: Run DataPipes tests with pytest if: ${{ ! 
contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} run: diff --git a/.gitignore b/.gitignore index 65607dd22..7c2168b9f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ build/* dist/* -torchdata.egg-info/* +*.egg-info/* torchdata/version.py torchdata/datapipes/iter/__init__.pyi @@ -17,6 +17,24 @@ torchdata/datapipes/iter/__init__.pyi # macOS dir files .DS_Store +## General + +*/*.so* +*/**/*.so* +torchdata/*.so* + +# Compiled Object files +*.slo +*.lo +*.o +*.cuo +*.obj + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + # Compiled python *.pyc *.pyd diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..531dca41c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,59 @@ +cmake_minimum_required(VERSION 3.13 FATAL_ERROR) + +# Most of the configurations are taken from PyTorch +# https://github.com/pytorch/pytorch/blob/0c9fb4aff0d60eaadb04e4d5d099fb1e1d5701a9/CMakeLists.txt + +# Use compiler ID "AppleClang" instead of "Clang" for XCode. +# Not setting this sometimes makes XCode C compiler gets detected as "Clang", +# even when the C++ one is detected as "AppleClang". +cmake_policy(SET CMP0010 NEW) +cmake_policy(SET CMP0025 NEW) + +# Suppress warning flags in default MSVC configuration. It's not +# mandatory that we do this (and we don't if cmake is old), but it's +# nice when it's possible, and it's possible on our Windows configs. +if(NOT CMAKE_VERSION VERSION_LESS 3.15.0) + cmake_policy(SET CMP0092 NEW) +endif() + +project(torchdata) + +# check and set CMAKE_CXX_STANDARD +string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard) +if(env_cxx_standard GREATER -1) + message( + WARNING "C++ standard version definition detected in environment variable." + "PyTorch requires -std=c++14. Please remove -std=c++ settings in your environment.") +endif() + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_C_STANDARD 11) + +# https://developercommunity.visualstudio.com/t/VS-16100-isnt-compatible-with-CUDA-11/1433342 +if(MSVC) + if(USE_CUDA) + set(CMAKE_CXX_STANDARD 17) + endif() +endif() + + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# Apple specific +if(APPLE) + # Get clang version on macOS + execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string ) + string(REGEX REPLACE "Apple LLVM version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION_STRING ${clang_full_version_string}) + message( STATUS "CLANG_VERSION_STRING: " ${CLANG_VERSION_STRING} ) + + # RPATH stuff + set(CMAKE_MACOSX_RPATH ON) + + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") +endif() + +# Options +option(BUILD_S3 "Build s3 io functionality" OFF) + +add_subdirectory(torchdata/csrc) diff --git a/setup.py b/setup.py index 464afc45b..fc3925005 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,17 @@ #!/usr/bin/env python # Copyright (c) Facebook, Inc. and its affiliates. 
+import distutils.command.clean import os +import shutil import subprocess +import sys from pathlib import Path from setuptools import find_packages, setup -from torchdata.datapipes.gen_pyi import gen_pyi +from tools import setup_helpers +from tools.gen_pyi import gen_pyi ROOT_DIR = Path(__file__).parent.resolve() @@ -52,12 +56,41 @@ def _export_version(version, sha): ] +class clean(distutils.command.clean.clean): + def run(self): + # Run default behavior first + distutils.command.clean.clean.run(self) + + # Remove torchdata extension + def remove_extension(pattern): + for path in (ROOT_DIR / "torchdata").glob(pattern): + print(f"removing extension '{path}'") + path.unlink() + + for ext in ["so", "dylib", "pyd"]: + remove_extension("**/*." + ext) + + # Remove build directory + build_dirs = [ + ROOT_DIR / "build", + ] + for path in build_dirs: + if path.exists(): + print(f"removing '{path}' (and everything under it)") + shutil.rmtree(str(path), ignore_errors=True) + + if __name__ == "__main__": VERSION, SHA = _get_version() _export_version(VERSION, SHA) print("-- Building version " + VERSION) + if sys.argv[1] != "clean": + gen_pyi() + # TODO: Fix #343 + os.chdir(ROOT_DIR) + setup( # Metadata name="torchdata", @@ -82,8 +115,18 @@ def _export_version(version, sha): "Programming Language :: Python :: Implementation :: CPython", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], + package_data={ + "torchdata": [ + "datapipes/iter/*.pyi", + ], + }, # Package Info - packages=find_packages(exclude=["test*", "examples*"]), + packages=find_packages(exclude=["test*", "examples*", "tools*", "torchdata.csrc*", "build*"]), zip_safe=False, + # C++ Extension Modules + ext_modules=setup_helpers.get_ext_modules(), + cmdclass={ + "build_ext": setup_helpers.CMakeBuild, + "clean": clean, + }, ) - gen_pyi() diff --git a/test/test_remote_io.py b/test/test_remote_io.py index 5bf3b92e4..179156676 100644 --- a/test/test_remote_io.py +++ b/test/test_remote_io.py @@ -6,9 +6,19 @@ import expecttest +import torchdata + from _utils._common_utils_for_test import check_hash_fn, create_temp_dir -from torchdata.datapipes.iter import EndOnDiskCacheHolder, FileOpener, HttpReader, IterableWrapper, OnDiskCacheHolder +from torchdata.datapipes.iter import ( + EndOnDiskCacheHolder, + FileOpener, + HttpReader, + IterableWrapper, + OnDiskCacheHolder, + S3FileLister, + S3FileLoader, +) class TestDataPipeRemoteIO(expecttest.TestCase): @@ -161,6 +171,110 @@ def _read_and_decode(x): self.assertTrue(os.path.exists(expected_csv_path)) self.assertEqual(expected_csv_path, csv_path) + def test_s3_io_iterdatapipe(self): + # sanity test + file_urls = ["s3://ai2-public-datasets"] + try: + s3_lister_dp = S3FileLister(IterableWrapper(file_urls)) + s3_loader_dp = S3FileLoader(IterableWrapper(file_urls)) + except ModuleNotFoundError: + warnings.warn( + "S3 IO datapipes or C++ extension '_torchdata' isn't built in the current 'torchdata' package" + ) + return + + # S3FileLister: different inputs + input_list = [ + [["s3://ai2-public-datasets"], 71], # bucket without '/' + [["s3://ai2-public-datasets/"], 71], # bucket with '/' + [["s3://ai2-public-datasets/charades"], 18], # folder without '/' + [["s3://ai2-public-datasets/charades/"], 18], # folder without '/' + [["s3://ai2-public-datasets/charad"], 18], # prefix + [ + [ + "s3://ai2-public-datasets/charades/Charades_v1", + "s3://ai2-public-datasets/charades/Charades_vu17", + ], + 12, + ], # prefixes + [["s3://ai2-public-datasets/charades/Charades_v1.zip"], 1], # single file + [ + [ + 
"s3://ai2-public-datasets/charades/Charades_v1.zip", + "s3://ai2-public-datasets/charades/Charades_v1_flow.tar", + "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar", + "s3://ai2-public-datasets/charades/Charades_v1_480.zip", + ], + 4, + ], # multiple files + [ + [ + "s3://ai2-public-datasets/charades/Charades_v1.zip", + "s3://ai2-public-datasets/charades/Charades_v1_flow.tar", + "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar", + "s3://ai2-public-datasets/charades/Charades_v1_480.zip", + "s3://ai2-public-datasets/charades/Charades_vu17", + ], + 10, + ], # files + prefixes + ] + for input in input_list: + s3_lister_dp = S3FileLister(IterableWrapper(input[0]), region="us-west-2") + self.assertEqual(sum(1 for _ in s3_lister_dp), input[1], f"{input[0]} failed") + + # S3FileLister: prefixes + different region + file_urls = [ + "s3://aft-vbi-pds/bin-images/111", + "s3://aft-vbi-pds/bin-images/222", + ] + s3_lister_dp = S3FileLister(IterableWrapper(file_urls), region="us-east-1") + self.assertEqual(sum(1 for _ in s3_lister_dp), 2212, f"{input} failed") + + # S3FileLister: incorrect inputs + input_list = [ + [""], + ["ai2-public-datasets"], + ["s3://"], + ["s3:///bin-images"], + ] + for input in input_list: + with self.assertRaises(ValueError, msg=f"{input} should raise ValueError."): + s3_lister_dp = S3FileLister(IterableWrapper(input), region="us-east-1") + for _ in s3_lister_dp: + pass + + # S3FileLoader: loader + input = [ + "s3://charades-tar-shards/charades-video-0.tar", + "s3://charades-tar-shards/charades-video-1.tar", + ] # multiple files + s3_loader_dp = S3FileLoader(input, region="us-west-2") + self.assertEqual(sum(1 for _ in s3_loader_dp), 2, f"{input} failed") + + input = [["s3://aft-vbi-pds/bin-images/100730.jpg"], 1] + s3_loader_dp = S3FileLoader(input[0], region="us-east-1") + self.assertEqual(sum(1 for _ in s3_loader_dp), input[1], f"{input[0]} failed") + + # S3FileLoader: incorrect inputs + input_list = [ + [""], + ["ai2-public-datasets"], + ["s3://"], + ["s3:///bin-images"], + ["s3://ai2-public-datasets/bin-image"], + ] + for input in input_list: + with self.assertRaises(ValueError, msg=f"{input} should raise ValueError."): + s3_loader_dp = S3FileLoader(input, region="us-east-1") + for _ in s3_loader_dp: + pass + + # integration test + input = [["s3://charades-tar-shards/"], 10] + s3_lister_dp = S3FileLister(IterableWrapper(input[0]), region="us-west-2") + s3_loader_dp = S3FileLoader(s3_lister_dp, region="us-west-2") + self.assertEqual(sum(1 for _ in s3_loader_dp), input[1], f"{input[0]} failed") + if __name__ == "__main__": unittest.main() diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchdata/datapipes/gen_pyi.py b/tools/gen_pyi.py similarity index 91% rename from torchdata/datapipes/gen_pyi.py rename to tools/gen_pyi.py index e08d19776..d3fa4b75a 100644 --- a/torchdata/datapipes/gen_pyi.py +++ b/tools/gen_pyi.py @@ -1,6 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
import os -import pathlib from pathlib import Path from typing import Dict, List, Optional, Set @@ -26,11 +25,11 @@ def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None) def gen_pyi() -> None: - ROOT_DIR = Path(__file__).parent.resolve() - print(f"Generating DataPipe Python interface file in {ROOT_DIR}") + DATAPIPE_DIR = Path(__file__).parent.parent.resolve() / "torchdata" / "datapipes" + print(f"Generating DataPipe Python interface file in {DATAPIPE_DIR}") iter_init_base = get_lines_base_file( - os.path.join(ROOT_DIR, "iter/__init__.py"), + os.path.join(DATAPIPE_DIR, "iter/__init__.py"), {"from torch.utils.data import IterDataPipe", "# Copyright (c) Facebook, Inc. and its affiliates."}, ) @@ -65,7 +64,7 @@ def gen_pyi() -> None: iterDP_deprecated_files, "IterDataPipe", iterDP_method_to_special_output_type, - root=str(pathlib.Path(__file__).parent.resolve()), + root=str(DATAPIPE_DIR), ) td_iter_method_definitions = [ @@ -77,7 +76,7 @@ def gen_pyi() -> None: replacements = [("${init_base}", iter_init_base, 0), ("${IterDataPipeMethods}", iter_method_definitions, 4)] gen_from_template( - dir=str(ROOT_DIR), + dir=str(DATAPIPE_DIR), template_name="iter/__init__.pyi.in", output_name="iter/__init__.pyi", replacements=replacements, diff --git a/tools/setup_helpers/__init__.py b/tools/setup_helpers/__init__.py new file mode 100644 index 000000000..7afa3f31c --- /dev/null +++ b/tools/setup_helpers/__init__.py @@ -0,0 +1 @@ +from .extension import * # noqa diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py new file mode 100644 index 000000000..7d63fead6 --- /dev/null +++ b/tools/setup_helpers/extension.py @@ -0,0 +1,123 @@ +import distutils.sysconfig +import os +import platform +import subprocess +import sys +from pathlib import Path + +from setuptools.command.build_ext import build_ext + +try: + from pybind11.setup_helpers import Pybind11Extension +except ImportError: + from setuptools import Extension as Pybind11Extension + + +__all__ = [ + "get_ext_modules", + "CMakeBuild", +] + + +_THIS_DIR = Path(__file__).parent.resolve() +_ROOT_DIR = _THIS_DIR.parent.parent.resolve() + + +def _get_build(var, default=False): + if var not in os.environ: + return default + + val = os.environ.get(var, "0") + trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"] + falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"] + if val in trues: + return True + if val not in falses: + print(f"WARNING: Unexpected environment variable value `{var}={val}`. " f"Expected one of {trues + falses}") + return False + + +_BUILD_S3 = _get_build("BUILD_S3", False) +_AWSSDK_DIR = os.environ.get("AWSSDK_DIR", None) + + +def get_ext_modules(): + if _BUILD_S3: + return [Pybind11Extension(name="torchdata._torchdata", sources=[])] + else: + return [] + + +class CMakeBuild(build_ext): + def run(self): + try: + subprocess.check_output(["cmake", "--version"]) + except OSError: + raise RuntimeError("CMake is not available.") from None + super().run() + + def build_extension(self, ext): + # Because the following `cmake` command will build all of `ext_modules`` at the same time, + # we would like to prevent multiple calls to `cmake`. + # Therefore, we call `cmake` only for `torchdata._torchdata`, + # in case `ext_modules` contains more than one module. 
+ if ext.name != "torchdata._torchdata": + return + + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + # required for auto-detection of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + cmake_args = [ + f"-DCMAKE_BUILD_TYPE={cfg}", + f"-DCMAKE_INSTALL_PREFIX={extdir}", + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY={extdir}", # For Windows + f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}", + f"-DBUILD_S3:BOOL={'ON' if _BUILD_S3 else 'OFF'}", + ] + + build_args = ["--config", cfg] + + if _BUILD_S3 and _AWSSDK_DIR: + cmake_args += [ + f"-DAWSSDK_DIR={_AWSSDK_DIR}", + ] + + # Default to Ninja + if "CMAKE_GENERATOR" not in os.environ or platform.system() == "Windows": + cmake_args += ["-GNinja"] + if platform.system() == "Windows": + python_version = sys.version_info + cmake_args += [ + "-DCMAKE_C_COMPILER=cl", + "-DCMAKE_CXX_COMPILER=cl", + f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", + ] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += [f"-j{self.parallel}"] + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + subprocess.check_call(["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp) + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp) + + def get_ext_filename(self, fullname): + ext_filename = super().get_ext_filename(fullname) + ext_filename_parts = ext_filename.split(".") + without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:] + ext_filename = ".".join(without_abi) + return ext_filename diff --git a/torchdata/__init__.py b/torchdata/__init__.py index 18a0daa88..d294f6886 100644 --- a/torchdata/__init__.py +++ b/torchdata/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. +from torchdata import _extension # noqa: F401 + from . 
import datapipes try: diff --git a/torchdata/_extension.py b/torchdata/_extension.py new file mode 100644 index 000000000..4c39c66a6 --- /dev/null +++ b/torchdata/_extension.py @@ -0,0 +1,28 @@ +import importlib +import os +from pathlib import Path + + +_LIB_DIR = Path(__file__).parent + + +def _init_extension(): + lib_dir = os.path.dirname(__file__) + + # TODO: If any extension had dependency of shared library, + # in order to support load these shred libraries dynamically, + # we need to add logic to load dll path on Windows + # See: https://github.com/pytorch/pytorch/blob/master/torch/__init__.py#L56-L140 + + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) # type: ignore[arg-type] + ext_specs = extfinder.find_spec("_torchdata") + + if ext_specs is None: + return + + from torchdata import _torchdata as _torchdata + + +_init_extension() diff --git a/torchdata/_torchdata/__init__.pyi b/torchdata/_torchdata/__init__.pyi new file mode 100644 index 000000000..372d877cd --- /dev/null +++ b/torchdata/_torchdata/__init__.pyi @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import List + +# TODO: Add pyi generate script +class S3Handler: + def __init__(self, request_timeout_ms: int, region: str) -> None: ... + def s3_read(self, url: str) -> bytes: ... + def list_files(self, prefix: str) -> List[str]: ... + def set_buffer_size(self, buffer_size: int) -> None: ... + def set_multi_part_download(self, multi_part_download: bool) -> None: ... + def clear_marker(self) -> None: ... diff --git a/torchdata/csrc/CMakeLists.txt b/torchdata/csrc/CMakeLists.txt new file mode 100644 index 000000000..88740fa24 --- /dev/null +++ b/torchdata/csrc/CMakeLists.txt @@ -0,0 +1,50 @@ +if(BUILD_S3) + + message(STATUS "Building S3 IO functionality") + + # To make the right CPython is built with on GitHub Actions, + # see https://github.com/actions/setup-python/issues/121#issuecomment-1014500503 + set(Python_FIND_FRAMEWORK "LAST") + + if (WIN32) + find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Interpreter Development) + set(ADDITIONAL_ITEMS Python3::Python) + else() + find_package(Python3 COMPONENTS Interpreter Development) + endif() + + list(APPEND CMAKE_PREFIX_PATH ${AWSSDK_DIR}) + find_package(AWSSDK REQUIRED COMPONENTS s3 transfer) + find_package(pybind11 CONFIG REQUIRED) + + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + + set( + EXTENSION_SOURCES + pybind/pybind.cpp + pybind/S3Handler/S3Handler.cpp + ) + + add_library(_torchdata MODULE ${EXTENSION_SOURCES}) + + target_include_directories(_torchdata PRIVATE ${PROJECT_SOURCE_DIR} ${Python_INCLUDE_DIR} {pybind11_INCLUDE_DIRS}) + + message(STATUS "AWSSDK linked libs: ${AWSSDK_LINK_LIBRARIES}") + target_link_libraries( + _torchdata + PRIVATE + pybind11::module + ${AWSSDK_LINK_LIBRARIES} + ${Python_LIBRARIES} + ${ADDITIONAL_ITEMS} + ) + + set_target_properties(_torchdata PROPERTIES PREFIX "") + if (MSVC) + set_target_properties(_torchdata PROPERTIES SUFFIX ".pyd") + endif(MSVC) + if (APPLE) + set_target_properties(_torchdata PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + endif(APPLE) + +endif() diff --git a/torchdata/csrc/pybind/S3Handler/S3Handler.cpp b/torchdata/csrc/pybind/S3Handler/S3Handler.cpp new file mode 100644 index 000000000..6d97457bf --- /dev/null +++ b/torchdata/csrc/pybind/S3Handler/S3Handler.cpp @@ -0,0 +1,397 @@ +#include "S3Handler.h" + +namespace torchdata { + +namespace { + +static 
const size_t S3DefaultBufferSize = 128 * 1024 * 1024; // 128 MB +static const uint64_t S3DefaultMultiPartDownloadChunkSize = + 5 * 1024 * 1024; // 5 MB +static const int executorPoolSize = 25; +static const std::string S3DefaultMarker = ""; + +std::shared_ptr setUpS3Config( + const long requestTimeoutMs, + const std::string region) { + std::shared_ptr cfg = + std::shared_ptr( + new Aws::Client::ClientConfiguration()); + Aws::String config_file; + const char* config_file_env = getenv("AWS_CONFIG_FILE"); + if (config_file_env) { + config_file = config_file_env; + } else { + const char* home_env = getenv("HOME"); + if (home_env) { + config_file = home_env; + config_file += "/.aws/config"; + } + } + Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); + loader.Load(); + + const char* use_https = getenv("S3_USE_HTTPS"); + if (use_https) { + if (use_https[0] == '0') { + cfg->scheme = Aws::Http::Scheme::HTTP; + } else { + cfg->scheme = Aws::Http::Scheme::HTTPS; + } + } + const char* verify_ssl = getenv("S3_VERIFY_SSL"); + if (verify_ssl) { + if (verify_ssl[0] == '0') { + cfg->verifySSL = false; + } else { + cfg->verifySSL = true; + } + } + const char* endpoint_url = getenv("S3_ENDPOINT_URL"); + if (endpoint_url) { + cfg->endpointOverride = endpoint_url; + } + if (region != "") { + cfg->region = region; + } else { + const char* env_region = getenv("AWS_REGION"); + if (env_region) { + cfg->region = env_region; + } + } + if (requestTimeoutMs > -1) { + cfg->requestTimeoutMs = requestTimeoutMs; + } + return cfg; +} + +void ShutdownClient(std::shared_ptr* s3_client) { + if (s3_client != nullptr) { + delete s3_client; + Aws::SDKOptions options; + Aws::ShutdownAPI(options); + } +} + +void ShutdownTransferManager( + std::shared_ptr* transfer_manager) { + if (transfer_manager != nullptr) { + delete transfer_manager; + } +} + +void ShutdownExecutor(Aws::Utils::Threading::PooledThreadExecutor* executor) { + if (executor != nullptr) { + delete executor; + } +} + +void parseS3Path( + const Aws::String& fname, + Aws::String* bucket, + Aws::String* object) { + if (fname.empty()) { + throw std::invalid_argument("The filename cannot be an empty string."); + } + + if (fname.size() < 5 || fname.substr(0, 5) != "s3://") { + throw std::invalid_argument("The filename must start with the S3 scheme."); + } + + std::string path = fname.substr(5); + + if (path.empty()) { + throw std::invalid_argument("The filename cannot be an empty string."); + } + + size_t pos = path.find_first_of('/'); + if (pos == 0) { + throw std::invalid_argument("The filename does not contain a bucket name."); + } + + *bucket = path.substr(0, pos); + *object = path.substr(pos + 1); + if (pos == std::string::npos) { + *object = ""; + } +} + +class S3FS { + private: + std::string bucket_name_; + std::string object_name_; + bool use_multi_part_download_; + std::shared_ptr s3_client_; + std::shared_ptr transfer_manager_; + + public: + S3FS( + const std::string& bucket, + const std::string& object, + const bool use_multi_part_download, + std::shared_ptr transfer_manager, + std::shared_ptr s3_client) + : bucket_name_(bucket), + object_name_(object), + use_multi_part_download_(use_multi_part_download), + transfer_manager_(transfer_manager), + s3_client_(s3_client) {} + + size_t Read(uint64_t offset, size_t n, char* buffer) { + if (use_multi_part_download_) { + return ReadTransferManager(offset, n, buffer); + } else { + return ReadS3Client(offset, n, buffer); + } + } + + size_t ReadS3Client(uint64_t offset, size_t n, char* buffer) { + 
Aws::S3::Model::GetObjectRequest getObjectRequest; + + getObjectRequest.WithBucket(bucket_name_.c_str()) + .WithKey(object_name_.c_str()); + + std::string bytes = "bytes="; + bytes += std::to_string(offset) + "-" + std::to_string(offset + n - 1); + + getObjectRequest.SetRange(bytes.c_str()); + + // When you don’t want to load the entire file into memory, + // you can use IOStreamFactory in AmazonWebServiceRequest to pass a + // lambda to create a string stream. + getObjectRequest.SetResponseStreamFactory( + []() { return Aws::New("S3IOAllocationTag"); }); + // get the object + Aws::S3::Model::GetObjectOutcome getObjectOutcome = + s3_client_->GetObject(getObjectRequest); + + if (!getObjectOutcome.IsSuccess()) { + Aws::S3::S3Error error = getObjectOutcome.GetError(); + std::cout << "ERROR: " << error.GetExceptionName() << ": " + << error.GetMessage() << std::endl; + return 0; + } else { + n = getObjectOutcome.GetResult().GetContentLength(); + // read data as a block: + getObjectOutcome.GetResult().GetBody().read(buffer, n); + return n; + } + } + + size_t ReadTransferManager(uint64_t offset, size_t n, char* buffer) { + auto create_stream_fn = [&]() { // create stream lambda fn + return Aws::New( + "S3ReadStream", + Aws::New( + "S3ReadStream", reinterpret_cast(buffer), n)); + }; // This buffer is what we used to initialize streambuf and is in memory + + std::shared_ptr downloadHandle = + transfer_manager_.get()->DownloadFile( + bucket_name_.c_str(), + object_name_.c_str(), + offset, + n, + create_stream_fn); + downloadHandle->WaitUntilFinished(); + + Aws::OFStream storeFile( + object_name_.c_str(), Aws::OFStream::out | Aws::OFStream::trunc); + + if (downloadHandle->GetStatus() != + Aws::Transfer::TransferStatus::COMPLETED) { + const Aws::Client::AWSError error = + downloadHandle->GetLastError(); + std::cout << "ERROR: " << error.GetExceptionName() << ": " + << error.GetMessage() << std::endl; + return 0; + } else { + return downloadHandle->GetBytesTransferred(); + } + } +}; + +} // namespace + +std::shared_ptr S3Handler::s3_handler_cfg_; + +S3Handler::S3Handler(const long requestTimeoutMs, const std::string region) + : s3_client_(nullptr, ShutdownClient), + transfer_manager_(nullptr, ShutdownTransferManager), + executor_(nullptr, ShutdownExecutor) { + initialization_lock_ = std::shared_ptr(new std::mutex()); + + // Load reading parameters + buffer_size_ = S3DefaultBufferSize; + const char* bufferSizeStr = getenv("S3_BUFFER_SIZE"); + if (bufferSizeStr) { + buffer_size_ = std::stoull(bufferSizeStr); + } + use_multi_part_download_ = true; + const char* use_multi_part_download_char = getenv("S3_MULTI_PART_DOWNLOAD"); + if (use_multi_part_download_char) { + std::string use_multi_part_download_str(use_multi_part_download_char); + if (use_multi_part_download_str == "OFF") { + use_multi_part_download_ = false; + } + } + + Aws::SDKOptions options; + Aws::InitAPI(options); + S3Handler::s3_handler_cfg_ = setUpS3Config(requestTimeoutMs, region); + InitializeS3Client(); + + last_marker_ = S3DefaultMarker; +} + +S3Handler::~S3Handler() {} + +void S3Handler::InitializeS3Client() { + std::lock_guard lock(*initialization_lock_); + s3_client_ = std::shared_ptr(new Aws::S3::S3Client( + *S3Handler::s3_handler_cfg_, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, + false)); +} + +void S3Handler::InitializeExecutor() { + executor_ = Aws::MakeShared( + "executor", executorPoolSize); +} + +void S3Handler::InitializeTransferManager() { + std::shared_ptr s3_client = GetS3Client(); + std::lock_guard 
lock(*initialization_lock_); + + Aws::Transfer::TransferManagerConfiguration transfer_config( + GetExecutor().get()); + transfer_config.s3Client = s3_client; + // This buffer is what we used to initialize streambuf and is in memory + transfer_config.bufferSize = S3DefaultMultiPartDownloadChunkSize; + transfer_config.transferBufferMaxHeapSize = + (executorPoolSize + 1) * S3DefaultMultiPartDownloadChunkSize; + transfer_manager_ = Aws::Transfer::TransferManager::Create(transfer_config); +} + +std::shared_ptr S3Handler::GetS3Client() { + if (s3_client_.get() == nullptr) { + InitializeS3Client(); + } + return s3_client_; +} + +std::shared_ptr +S3Handler::GetExecutor() { + if (executor_.get() == nullptr) { + InitializeExecutor(); + } + return executor_; +} + +std::shared_ptr +S3Handler::GetTransferManager() { + if (transfer_manager_.get() == nullptr) { + InitializeTransferManager(); + } + return transfer_manager_; +} + +size_t S3Handler::GetFileSize( + const std::string& bucket, + const std::string& object) { + Aws::S3::Model::HeadObjectRequest headObjectRequest; + headObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str()); + Aws::S3::Model::HeadObjectOutcome headObjectOutcome = + GetS3Client()->HeadObject(headObjectRequest); + if (headObjectOutcome.IsSuccess()) { + return headObjectOutcome.GetResult().GetContentLength(); + } else { + Aws::String const& error_aws = headObjectOutcome.GetError().GetMessage(); + std::string error_str(error_aws.c_str(), error_aws.size()); + throw std::invalid_argument(error_str); + return 0; + } +} + +void S3Handler::ClearMarker() { + last_marker_ = S3DefaultMarker; +} + +void S3Handler::S3Read(const std::string& file_url, std::string* result) { + std::string bucket, object; + parseS3Path(file_url, &bucket, &object); + S3FS s3fs( + bucket, + object, + use_multi_part_download_, + GetTransferManager(), + GetS3Client()); + + uint64_t offset = 0; + uint64_t result_size = 0; + uint64_t file_size = GetFileSize(bucket, object); + size_t part_count = + (std:: + max)(static_cast((file_size + buffer_size_ - 1) / buffer_size_), static_cast(1)); + result->resize(file_size); + + for (int i = 0; i < part_count; i++) { + offset = result_size; + + size_t buf_len = std::min(buffer_size_, file_size - result_size); + + size_t read_len = + s3fs.Read(offset, buf_len, (char*)(result->data()) + offset); + + result_size += read_len; + + if (result_size == file_size) { + break; + } + + if (read_len != buf_len) { + std::cout << "Result size and buffer size did not match"; + break; + } + } +} + +void S3Handler::ListFiles( + const std::string& file_url, + std::vector* filenames) { + Aws::String bucket, prefix; + parseS3Path(file_url, &bucket, &prefix); + + Aws::S3::Model::ListObjectsRequest listObjectsRequest; + listObjectsRequest.WithBucket(bucket).WithPrefix(prefix).WithMarker( + last_marker_); + + Aws::S3::Model::ListObjectsOutcome listObjectsOutcome = + GetS3Client()->ListObjects(listObjectsRequest); + if (!listObjectsOutcome.IsSuccess()) { + Aws::String const& error_aws = listObjectsOutcome.GetError().GetMessage(); + throw std::invalid_argument(error_aws); + } + + Aws::Vector objects = + listObjectsOutcome.GetResult().GetContents(); + if (objects.empty()) { + return; + } + for (const Aws::S3::Model::Object& object : objects) { + if (object.GetKey().back() == '/') // ignore folders + + { + continue; + } + Aws::String entry = "s3://" + bucket + "/" + object.GetKey(); + filenames->push_back(entry.c_str()); + } + last_marker_ = objects.back().GetKey(); + + // extreme cases when all 
objects are folders + if (filenames->size() == 0) { + ListFiles(file_url, filenames); + } +} + +} // namespace torchdata diff --git a/torchdata/csrc/pybind/S3Handler/S3Handler.h b/torchdata/csrc/pybind/S3Handler/S3Handler.h new file mode 100644 index 000000000..641359d69 --- /dev/null +++ b/torchdata/csrc/pybind/S3Handler/S3Handler.h @@ -0,0 +1,76 @@ +#include "precompile.h" + +namespace torchdata { + +// In memory stream implementation +class S3UnderlyingStream : public Aws::IOStream { + public: + using Base = Aws::IOStream; + + // provide a customer controlled streambuf, so as to put all transferred + // data into this in memory buffer. + S3UnderlyingStream(std::streambuf* buf) : Base(buf) {} + + virtual ~S3UnderlyingStream() = default; +}; + +class S3Handler { + private: + static std::shared_ptr s3_handler_cfg_; + + std::shared_ptr initialization_lock_; + std::shared_ptr s3_client_; + std::shared_ptr executor_; + std::shared_ptr transfer_manager_; + + Aws::String last_marker_; + size_t buffer_size_; + bool use_multi_part_download_; + + void InitializeS3Client(); + void InitializeExecutor(); + void InitializeTransferManager(); + + std::shared_ptr GetS3Client(); + std::shared_ptr GetExecutor(); + std::shared_ptr GetTransferManager(); + size_t GetFileSize(const std::string& bucket, const std::string& object); + + public: + S3Handler(const long requestTimeoutMs, const std::string region); + ~S3Handler(); + + void SetLastMarker(const Aws::String last_marker) { + this->last_marker_ = last_marker; + } + void SetBufferSize(const uint64_t buffer_size) { + this->buffer_size_ = buffer_size; + } + void SetMultiPartDownload(const bool multi_part_download) { + this->use_multi_part_download_ = multi_part_download; + } + void ClearMarker(); + + long GetRequestTimeoutMs() const { + return s3_handler_cfg_->requestTimeoutMs; + } + Aws::String GetRegion() const { + return s3_handler_cfg_->region; + } + Aws::String GetLastMarker() const { + return last_marker_; + } + bool GetUseMultiPartDownload() const { + return use_multi_part_download_; + } + size_t GetBufferSize() const { + return buffer_size_; + } + + void S3Read(const std::string& file_url, std::string* result); + void ListFiles( + const std::string& file_url, + std::vector* filenames); +}; + +} // namespace torchdata diff --git a/torchdata/csrc/pybind/S3Handler/precompile.h b/torchdata/csrc/pybind/S3Handler/precompile.h new file mode 100644 index 000000000..b1da9d499 --- /dev/null +++ b/torchdata/csrc/pybind/S3Handler/precompile.h @@ -0,0 +1,30 @@ +#ifndef TORCHDATA_S3_IO_H +#define TORCHDATA_S3_IO_H + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#endif // TORCHDATA_S3_IO_H diff --git a/torchdata/csrc/pybind/pybind.cpp b/torchdata/csrc/pybind/pybind.cpp new file mode 100644 index 000000000..c9f0f757e --- /dev/null +++ b/torchdata/csrc/pybind/pybind.cpp @@ -0,0 +1,64 @@ +#include +#include + +#include +#include + +#include + +namespace py = pybind11; +using torchdata::S3Handler; + +PYBIND11_MODULE(_torchdata, m) { + py::class_(m, "S3Handler") + .def(py::init()) + .def( + "s3_read", + [](S3Handler* self, const std::string& file_url) { + std::string result; + self->S3Read(file_url, &result); + return py::bytes(result); + }) + .def( + "list_files", + [](S3Handler* self, const std::string& file_url) { + std::vector filenames; + self->ListFiles(file_url, 
&filenames);
+            return filenames;
+          })
+      .def(
+          "set_buffer_size",
+          [](S3Handler* self, const uint64_t buffer_size) {
+            self->SetBufferSize(buffer_size);
+          })
+      .def(
+          "set_multi_part_download",
+          [](S3Handler* self, const bool multi_part_download) {
+            self->SetMultiPartDownload(multi_part_download);
+          })
+      .def("clear_marker", [](S3Handler* self) { self->ClearMarker(); })
+      .def(py::pickle(
+          [](const S3Handler& s3_handler) { // __getstate__
+            /* Return a tuple that fully encodes the state of the object */
+            return py::make_tuple(
+                s3_handler.GetRequestTimeoutMs(),
+                s3_handler.GetRegion(),
+                s3_handler.GetLastMarker(),
+                s3_handler.GetUseMultiPartDownload(),
+                s3_handler.GetBufferSize());
+          },
+          [](py::tuple t) { // __setstate__
+            if (t.size() != 5)
+              throw std::runtime_error("Invalid state!");
+
+            /* Create a new C++ instance */
+            S3Handler s3_handler(t[0].cast<long>(), t[1].cast<std::string>());
+
+            /* Assign any additional state */
+            s3_handler.SetLastMarker(t[2].cast<Aws::String>());
+            s3_handler.SetMultiPartDownload(t[3].cast<bool>());
+            s3_handler.SetBufferSize(t[4].cast<uint64_t>());
+
+            return s3_handler;
+          }));
+}
diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py
index 02085fa45..7fbe01a4f 100644
--- a/torchdata/datapipes/iter/__init__.py
+++ b/torchdata/datapipes/iter/__init__.py
@@ -38,6 +38,10 @@
     HTTPReaderIterDataPipe as HttpReader,
     OnlineReaderIterDataPipe as OnlineReader,
 )
+from torchdata.datapipes.iter.load.s3io import (
+    S3FileListerIterDataPipe as S3FileLister,
+    S3FileLoaderIterDataPipe as S3FileLoader,
+)
 from torchdata.datapipes.iter.transform.bucketbatcher import (
     BucketBatcherIterDataPipe as BucketBatcher,
     InBatchShufflerIterDataPipe as InBatchShuffler,
 )
@@ -148,6 +152,8 @@
     "RarArchiveLoader",
     "RoutedDecoder",
     "Rows2Columnar",
+    "S3FileLister",
+    "S3FileLoader",
     "SampleMultiplexer",
     "Sampler",
     "Saver",
diff --git a/torchdata/datapipes/iter/load/README.md b/torchdata/datapipes/iter/load/README.md
new file mode 100644
index 000000000..459ef84df
--- /dev/null
+++ b/torchdata/datapipes/iter/load/README.md
@@ -0,0 +1,77 @@
+# S3 IO Datapipe Documentation
+
+## Installation
+
+Torchdata S3 IO datapipes depend on [aws-sdk-cpp](https://github.com/aws/aws-sdk-cpp). The following is just one
+recommended way to install aws-sdk-cpp; please refer to the official documentation for detailed instructions.
+
+```bash
+git clone --recurse-submodules https://github.com/aws/aws-sdk-cpp
+cd aws-sdk-cpp/
+mkdir sdk-build
+cd sdk-build
+# need to add flag -DBUILD_SHARED_LIBS=OFF for static linking on Windows
+cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_ONLY="s3;transfer"
+make
+make install # may need sudo
+```
+
+`ninja` and `pybind11` are also required to link the Python implementation to the C++ source code.
+
+```bash
+conda install ninja pybind11
+```
+
+S3 IO datapipes aren't included in the build by default. To build them into `torchdata`, run the following commands at
+the `/data` root folder.
+
+```bash
+export BUILD_S3=ON
+pip uninstall torchdata -y
+python setup.py clean
+python setup.py install
+```
+
+## Using S3 IO datapipes
+
+### S3FileLister
+
+`S3FileLister` accepts a list of S3 prefixes and iterates over all matching S3 URLs. The functional API is `list_file_by_s3`.
+Acceptable prefixes include `s3://bucket-name`, `s3://bucket-name/`, `s3://bucket-name/folder`,
+`s3://bucket-name/folder/`, and `s3://bucket-name/prefix`. You may also set `length`, `request_timeout_ms` (default 3000
+ms in aws-sdk-cpp), and `region`. Note that:
+
+1. Input **must** be a list and direct S3 URLs are skipped.
+2. `length` is `-1` by default, and any call to `__len__()` is invalid, because the length is unknown until all files
+   are iterated.
+3. `request_timeout_ms` and `region` will overwrite settings in the configuration file or environment variables.
+
+### S3FileLoader
+
+`S3FileLoader` accepts a list of S3 URLs and iterates all files in `BytesIO` format with `(url, BytesIO)` tuples. The
+functional API is `load_file_by_s3`. You may also set `request_timeout_ms` (default 3000 ms in aws-sdk-cpp), `region`,
+`buffer_size` (default 120Mb), and `multi_part_download` (multi-part downloading is used by default). Note that:
+
+1. Input **must** be a list and S3 URLs must be valid.
+2. `request_timeout_ms` and `region` will overwrite settings in the configuration file or environment variables.
+
+### Example
+
+```py
+from torchdata.datapipes.iter import S3FileLister, S3FileLoader
+
+s3_prefixes = ['s3://bucket-name/folder/', ...]
+dp_s3_urls = S3FileLister(s3_prefixes)
+dp_s3_files = S3FileLoader(dp_s3_urls)  # outputs in (url, StreamWrapper(BytesIO))
+# more datapipes to convert loaded bytes, e.g.
+datapipe = dp_s3_files.parse_csv(delimiter=' ')
+
+for d in datapipe:  # Start loading data
+    pass
+```
+
+### Note
+
+It's recommended to set up a detailed configuration file with the `AWS_CONFIG_FILE` environment variable. The following
+environment variables are also parsed: `HOME`, `S3_USE_HTTPS`, `S3_VERIFY_SSL`, `S3_ENDPOINT_URL`, `AWS_REGION` (would
+be overwritten by the `region` variable).
diff --git a/torchdata/datapipes/iter/load/s3io.py b/torchdata/datapipes/iter/load/s3io.py
new file mode 100644
index 000000000..c59a1ee4e
--- /dev/null
+++ b/torchdata/datapipes/iter/load/s3io.py
@@ -0,0 +1,121 @@
+from io import BytesIO
+from typing import Iterator, Tuple
+
+import torchdata
+from torchdata.datapipes import functional_datapipe
+from torchdata.datapipes.iter import IterDataPipe
+from torchdata.datapipes.utils import StreamWrapper
+
+
+@functional_datapipe("list_file_by_s3")
+class S3FileListerIterDataPipe(IterDataPipe[str]):
+    r""":class:`S3FileListerIterDataPipe`.
+
+    Iterable DataPipe that lists URLs with the given prefixes (functional name: ``list_file_by_s3``).
+    Acceptable prefixes include `s3://bucket-name`, `s3://bucket-name/`, `s3://bucket-name/folder`,
+    `s3://bucket-name/folder/`, and `s3://bucket-name/prefix`. You may also set `length`, `request_timeout_ms`
+    (default 3000 ms in aws-sdk-cpp), and `region`. Note that:
+
+    1. Input **must** be a list and direct S3 URLs are skipped.
+    2. `length` is `-1` by default, and any call to `__len__()` is invalid, because the length is unknown until
+       all files are iterated.
+    3. `request_timeout_ms` and `region` will overwrite settings in the configuration file or environment variables.
+
+    Args:
+        source_datapipe: a DataPipe that contains URLs/URL prefixes to S3 files
+        length: Nominal length of the datapipe
+        request_timeout_ms: optional, overwrite the default timeout setting for this datapipe
+        region: optional, overwrite the default region inferred from credentials for this datapipe
+
+    Note:
+        AWS_CPP_SDK is necessary to use the S3 DataPipe(s).
+
+    Example:
+        >>> from torchdata.datapipes.iter import S3FileLister, S3FileLoader
+        >>> s3_prefixes = ['s3://bucket-name/folder/', ...]
+        >>> dp_s3_urls = S3FileLister(s3_prefixes)
+        >>> dp_s3_files = S3FileLoader(dp_s3_urls)  # outputs in (url, StreamWrapper(BytesIO))
+        >>> # more datapipes to convert loaded bytes, e.g.
+        >>> datapipe = dp_s3_files.parse_csv(delimiter=' ')
+        >>> for d in datapipe:  # Start loading data
+        ...     pass
+    """
+
+    def __init__(self, source_datapipe: IterDataPipe[str], length: int = -1, request_timeout_ms=-1, region="") -> None:
+        if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
+            raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
+
+        self.source_datapipe: IterDataPipe[str] = source_datapipe
+        self.length: int = length
+        self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region)
+
+    def __iter__(self) -> Iterator[str]:
+        for prefix in self.source_datapipe:
+            while True:
+                urls = self.handler.list_files(prefix)
+                yield from urls
+                if not urls:
+                    break
+            self.handler.clear_marker()
+
+    def __len__(self) -> int:
+        if self.length == -1:
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+        return self.length
+
+
+@functional_datapipe("load_file_by_s3")
+class S3FileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
+    r""":class:`S3FileLoaderIterDataPipe`.
+
+    Iterable DataPipe that loads S3 files given S3 URLs (functional name: ``load_file_by_s3``).
+    `S3FileLoader` iterates all given S3 URLs in `BytesIO` format with `(url, BytesIO)` tuples.
+    You may also set `request_timeout_ms` (default 3000 ms in aws-sdk-cpp), `region`,
+    `buffer_size` (default 120Mb), and `multi_part_download` (multi-part downloading is used by default). Note that:
+
+    1. Input **must** be a list and S3 URLs must be valid.
+    2. `request_timeout_ms` and `region` will overwrite settings in the configuration file or environment variables.
+
+    Args:
+        source_datapipe: a DataPipe that contains URLs to S3 files
+        request_timeout_ms: optional, overwrite the default timeout setting for this datapipe
+        region: optional, overwrite the default region inferred from credentials for this datapipe
+
+    Note:
+        AWS_CPP_SDK is necessary to use the S3 DataPipe(s).
+
+    Example:
+        >>> from torchdata.datapipes.iter import S3FileLister, S3FileLoader
+        >>> s3_prefixes = ['s3://bucket-name/folder/', ...]
+        >>> dp_s3_urls = S3FileLister(s3_prefixes)
+        >>> dp_s3_files = S3FileLoader(dp_s3_urls)  # outputs in (url, StreamWrapper(BytesIO))
+        >>> # more datapipes to convert loaded bytes, e.g.
+        >>> datapipe = dp_s3_files.parse_csv(delimiter=' ')
+        >>> for d in datapipe:  # Start loading data
+        ...     pass
+    """
+
+    def __init__(
+        self,
+        source_datapipe: IterDataPipe[str],
+        request_timeout_ms=-1,
+        region="",
+        buffer_size=None,
+        multi_part_download=None,
+    ) -> None:
+        if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
+            raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
+
+        self.source_datapipe: IterDataPipe[str] = source_datapipe
+        self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region)
+        if buffer_size:
+            self.handler.set_buffer_size(buffer_size)
+        if multi_part_download:
+            self.handler.set_multi_part_download(multi_part_download)
+
+    def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
+        for url in self.source_datapipe:
+            yield url, StreamWrapper(BytesIO(self.handler.s3_read(url)))
+
+    def __len__(self) -> int:
+        return len(self.source_datapipe)
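
A note on the bindings and datapipes added above: the `py::pickle` definition on `S3Handler` is what allows a datapipe that owns a handler to be serialized and rebuilt in another process. The sketch below is illustrative only and assumes the extension was built with `BUILD_S3=1`; the bucket prefix and region are hypothetical placeholders, not values taken from this change.

```py
import pickle

from torchdata.datapipes.iter import IterableWrapper, S3FileLister, S3FileLoader

# Hypothetical prefix/region for illustration only; substitute a bucket you can read.
prefixes = IterableWrapper(["s3://my-bucket/my-prefix/"])

try:
    urls_dp = S3FileLister(prefixes, region="us-west-2")
    files_dp = S3FileLoader(urls_dp, region="us-west-2")
except ModuleNotFoundError:
    # torchdata was built without BUILD_S3=1, so the _torchdata extension is absent.
    files_dp = None

if files_dp is not None:
    # The S3Handler state (timeout, region, marker, multi-part flag, buffer size)
    # round-trips through pickle via the __getstate__/__setstate__ bindings,
    # which is what multiprocessing DataLoader workers rely on.
    files_dp = pickle.loads(pickle.dumps(files_dp))

    for url, stream in files_dp:
        data = stream.read()  # stream wraps a BytesIO holding the object's bytes
        break
```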