diff --git a/test/requirements.txt b/test/requirements.txt
index 4f223b2e6..05a2336ab 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -1,6 +1,7 @@
 pytest
 expecttest
 fsspec
+s3fs
 iopath == 0.1.9
 numpy
 rarfile
diff --git a/test/test_remote_io.py b/test/test_remote_io.py
index 57c963297..1e432282e 100644
--- a/test/test_remote_io.py
+++ b/test/test_remote_io.py
@@ -19,6 +19,8 @@
 from torchdata.datapipes.iter import (
     EndOnDiskCacheHolder,
     FileOpener,
+    FSSpecFileLister,
+    FSSpecFileOpener,
     HttpReader,
     IterableWrapper,
     OnDiskCacheHolder,
@@ -26,6 +28,24 @@
     S3FileLoader,
 )
 
+try:
+    import fsspec
+    import s3fs
+
+    HAS_FSSPEC_S3 = True
+except ImportError:
+    HAS_FSSPEC_S3 = False
+skipIfNoFSSpecS3 = unittest.skipIf(not HAS_FSSPEC_S3, "no FSSpec with S3fs")
+
+try:
+    from torchdata._torchdata import S3Handler
+
+    HAS_AWS = True
+except ImportError:
+    HAS_AWS = False
+skipIfAWS = unittest.skipIf(HAS_AWS, "AWSSDK Enabled")
+skipIfNoAWS = unittest.skipIf(not HAS_AWS, "No AWSSDK Enabled")
+
 
 class TestDataPipeRemoteIO(expecttest.TestCase):
     def setUp(self):
@@ -183,20 +203,43 @@ def _read_and_decode(x):
         if not IS_WINDOWS:
             dl = DataLoader(file_cache_dp, num_workers=3, multiprocessing_context="fork", batch_size=1)
             expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * 3
-            self.assertEqual(sorted(expected), sorted(list(dl)))
+            res = list(dl)
+            self.assertEqual(sorted(expected), sorted(res))
 
-    def test_s3_io_iterdatapipe(self):
-        # sanity test
+    @skipIfNoFSSpecS3
+    def test_fsspec_io_iterdatapipe(self):
+        input_list = [
+            (["s3://ai2-public-datasets"], 39),  # bucket without '/'
+            (["s3://ai2-public-datasets/charades/"], 18),  # bucket with '/'
+            (
+                [
+                    "s3://ai2-public-datasets/charades/Charades_v1.zip",
+                    "s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
+                    "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
+                    "s3://ai2-public-datasets/charades/Charades_v1_480.zip",
+                ],
+                4,
+            ),  # multiple files
+        ]
+        for urls, num in input_list:
+            fsspec_lister_dp = FSSpecFileLister(IterableWrapper(urls), anon=True)
+            self.assertEqual(sum(1 for _ in fsspec_lister_dp), num, f"{urls} failed")
+
+        url = "s3://ai2-public-datasets/charades/"
+        fsspec_loader_dp = FSSpecFileOpener(FSSpecFileLister(IterableWrapper([url]), anon=True), anon=True)
+        res = list(fsspec_loader_dp)
+        self.assertEqual(len(res), 18, f"{url} failed")
+
+    @skipIfAWS
+    def test_disabled_s3_io_iterdatapipe(self):
         file_urls = ["s3://ai2-public-datasets"]
-        try:
-            s3_lister_dp = S3FileLister(IterableWrapper(file_urls))
-            s3_loader_dp = S3FileLoader(IterableWrapper(file_urls))
-        except ModuleNotFoundError:
-            warnings.warn(
-                "S3 IO datapipes or C++ extension '_torchdata' isn't built in the current 'torchdata' package"
-            )
-            return
+        with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"):
+            _ = S3FileLister(IterableWrapper(file_urls))
+        with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"):
+            _ = S3FileLoader(IterableWrapper(file_urls))
 
+    @skipIfNoAWS
+    def test_s3_io_iterdatapipe(self):
         # S3FileLister: different inputs
         input_list = [
             [["s3://ai2-public-datasets"], 77],  # bucket without '/'
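The new tests pass `anon=True` so the public `ai2-public-datasets` bucket can be read without AWS credentials. For reviewers, a minimal standalone sketch of what that flag does at the fsspec level (assumes `fsspec` and `s3fs` are installed; not part of this diff):

```python
import fsspec

# url_to_fs forwards extra keyword arguments to the underlying
# filesystem -- for "s3://" URLs that is s3fs.S3FileSystem, where
# anon=True requests unsigned (credential-free) access.
fs, path = fsspec.core.url_to_fs("s3://ai2-public-datasets/charades/", anon=True)

# This is the listing the test counts: 18 entries under charades/
# at the time the test was written.
for file_name in fs.ls(path):
    print(file_name)
```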
diff --git a/torchdata/datapipes/iter/load/fsspec.py b/torchdata/datapipes/iter/load/fsspec.py
index 2ded99eee..7a874b3aa 100644
--- a/torchdata/datapipes/iter/load/fsspec.py
+++ b/torchdata/datapipes/iter/load/fsspec.py
@@ -41,6 +41,8 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
     Args:
         root: The root `fsspec` path directory or list of path directories to list files from
         masks: Unix style filter string or string list for filtering file name(s)
+        kwargs: Extra options that make sense to a particular storage connection,
+            e.g. host, port, username, password, etc.
 
     Example:
         >>> from torchdata.datapipes.iter import FSSpecFileLister
@@ -51,6 +53,7 @@ def __init__(
         self,
         root: Union[str, Sequence[str], IterDataPipe],
         masks: Union[str, List[str]] = "",
+        **kwargs,
     ) -> None:
         _assert_fsspec()
 
@@ -63,10 +66,11 @@ def __init__(
         else:
             self.datapipe = root
         self.masks = masks
+        self.kwargs = kwargs
 
     def __iter__(self) -> Iterator[str]:
         for root in self.datapipe:
-            fs, path = fsspec.core.url_to_fs(root)
+            fs, path = fsspec.core.url_to_fs(root, **self.kwargs)
 
             if isinstance(fs.protocol, str):
                 protocol_list = [fs.protocol]
@@ -87,8 +91,10 @@ def __iter__(self) -> Iterator[str]:
                 else:
                     if is_local:
                         abs_path = os.path.join(path, file_name)
-                    else:
+                    elif not file_name.startswith(path):
                         abs_path = posixpath.join(path, file_name)
+                    else:
+                        abs_path = file_name
 
                 starts_with = False
                 for protocol in protocol_list:
@@ -110,6 +116,8 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     Args:
         source_datapipe: Iterable DataPipe that provides the pathnames or URLs
        mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default)
+        kwargs: Extra options that make sense to a particular storage connection,
+            e.g. host, port, username, password, etc.
 
     Example:
         >>> from torchdata.datapipes.iter import FSSpecFileLister
@@ -117,15 +125,16 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
         >>> file_dp = datapipe.open_files_by_fsspec()
     """
 
-    def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r") -> None:
+    def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", **kwargs) -> None:
         _assert_fsspec()
 
         self.source_datapipe: IterDataPipe[str] = source_datapipe
         self.mode: str = mode
+        self.kwargs = kwargs
 
     def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
         for file_uri in self.source_datapipe:
-            fs, path = fsspec.core.url_to_fs(file_uri)
+            fs, path = fsspec.core.url_to_fs(file_uri, **self.kwargs)
             file = fs.open(path, self.mode)
             yield file_uri, StreamWrapper(file)
 
@@ -148,6 +157,8 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
         source_datapipe: Iterable DataPipe with tuples of metadata and data
         mode: Mode in which the file will be opened for write the data (``"w"`` by default)
         filepath_fn: Function that takes in metadata and returns the target path of the new file
+        kwargs: Extra options that make sense to a particular storage connection,
+            e.g. host, port, username, password, etc.
 
     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper
@@ -164,17 +175,19 @@ def __init__(
         source_datapipe: IterDataPipe[Tuple[Any, U]],
         mode: str = "w",
         filepath_fn: Optional[Callable] = None,
+        **kwargs,
     ):
         _assert_fsspec()
 
         self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe
         self.mode: str = mode
         self.filepath_fn: Optional[Callable] = filepath_fn
+        self.kwargs = kwargs
 
     def __iter__(self) -> Iterator[str]:
         for meta, data in self.source_datapipe:
             filepath = meta if self.filepath_fn is None else self.filepath_fn(meta)
-            fs, path = fsspec.core.url_to_fs(filepath)
+            fs, path = fsspec.core.url_to_fs(filepath, **self.kwargs)
             with fs.open(path, self.mode) as f:
                 f.write(data)
             yield filepath
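Taken together, these changes let storage options flow from the caller through every fsspec datapipe. A sketch of end-to-end usage with the new pass-through kwargs (`anon=True` is s3fs-specific; other backends take options such as host, port, username, or password). Note that each datapipe stores its own kwargs, so the lister and the opener both need the flag:

```python
from torchdata.datapipes.iter import FSSpecFileLister, FSSpecFileOpener, IterableWrapper

# The keyword arguments are saved at construction time and replayed
# into fsspec.core.url_to_fs() on every iteration.
lister_dp = FSSpecFileLister(IterableWrapper(["s3://ai2-public-datasets/charades/"]), anon=True)
opener_dp = FSSpecFileOpener(lister_dp, mode="rb", anon=True)

for url, stream in opener_dp:
    print(url)  # e.g. s3://ai2-public-datasets/charades/Charades_v1.zip
    stream.close()
    break
```

The new `elif not file_name.startswith(path)` branch matters for this path: s3fs returns fully qualified keys from `ls()`, so joining them onto `path` again would duplicate the prefix.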
diff --git a/torchdata/datapipes/iter/load/s3io.py b/torchdata/datapipes/iter/load/s3io.py
index 086156e8c..8d4878bce 100644
--- a/torchdata/datapipes/iter/load/s3io.py
+++ b/torchdata/datapipes/iter/load/s3io.py
@@ -51,7 +51,7 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):
 
     def __init__(self, source_datapipe: IterDataPipe[str], length: int = -1, request_timeout_ms=-1, region="") -> None:
         if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
-            raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
+            raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.")
 
         self.source_datapipe: IterDataPipe[str] = source_datapipe
         self.length: int = length
@@ -113,7 +113,7 @@ def __init__(
         multi_part_download=None,
     ) -> None:
         if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
-            raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
+            raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.")
 
         self.source_datapipe: IterDataPipe[str] = source_datapipe
         self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region)
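For reference, the reworded message is exactly what the new `test_disabled_s3_io_iterdatapipe` matches with `assertRaisesRegex`. A quick sketch of the failure mode on a package built without the C++ extension (i.e. without `BUILD_S3=1`):

```python
from torchdata.datapipes.iter import IterableWrapper, S3FileLister

try:
    # On builds lacking _torchdata.S3Handler, construction fails
    # immediately rather than erroring later during iteration.
    dp = S3FileLister(IterableWrapper(["s3://ai2-public-datasets"]))
except ModuleNotFoundError as err:
    assert "TorchData must be built with" in str(err)
```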