Improve fsspec DataPipe to accept extra keyword arguments
ejguan committed Jun 3, 2022
1 parent 5aac88f commit d7337e7
Showing 4 changed files with 75 additions and 18 deletions.
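In effect, the extra keyword arguments are forwarded to fsspec.core.url_to_fs(), so storage-specific options such as anon=True for anonymous S3 access pass straight through the DataPipes. A minimal sketch of the resulting usage, mirroring the tests below (the bucket is the public one the test suite reads):

    from torchdata.datapipes.iter import FSSpecFileLister, FSSpecFileOpener, IterableWrapper

    # anon=True is forwarded to s3fs via fsspec.core.url_to_fs(), enabling
    # anonymous access to a public bucket.
    lister_dp = FSSpecFileLister(IterableWrapper(["s3://ai2-public-datasets/charades/"]), anon=True)
    # The opener resolves its own filesystem, so the kwargs are passed again.
    opener_dp = FSSpecFileOpener(lister_dp, mode="rb", anon=True)
    for url, stream in opener_dp:
        print(url)  # e.g. s3://ai2-public-datasets/charades/Charades_v1.zip
        stream.close()
        break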
1 change: 1 addition & 0 deletions test/requirements.txt
@@ -1,6 +1,7 @@
pytest
expecttest
fsspec
s3fs
iopath == 0.1.9
numpy
rarfile
65 changes: 54 additions & 11 deletions test/test_remote_io.py
@@ -19,13 +19,33 @@
from torchdata.datapipes.iter import (
EndOnDiskCacheHolder,
FileOpener,
FSSpecFileLister,
FSSpecFileOpener,
HttpReader,
IterableWrapper,
OnDiskCacheHolder,
S3FileLister,
S3FileLoader,
)

try:
import fsspec
import s3fs

HAS_FSSPEC_S3 = True
except ImportError:
HAS_FSSPEC_S3 = False
skipIfNoFSSpecS3 = unittest.skipIf(not HAS_FSSPEC_S3, "no FSSpec with S3fs")

try:
from torchdata._torchdata import S3Handler

HAS_AWS = True
except ImportError:
HAS_AWS = False
skipIfAWS = unittest.skipIf(HAS_AWS, "AWSSDK Enabled")
skipIfNoAWS = unittest.skipIf(not HAS_AWS, "No AWSSDK Enabled")


class TestDataPipeRemoteIO(expecttest.TestCase):
def setUp(self):
@@ -183,20 +203,43 @@ def _read_and_decode(x):
if not IS_WINDOWS:
dl = DataLoader(file_cache_dp, num_workers=3, multiprocessing_context="fork", batch_size=1)
expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * 3
self.assertEqual(sorted(expected), sorted(list(dl)))
res = list(dl)
self.assertEqual(sorted(expected), sorted(res))

def test_s3_io_iterdatapipe(self):
# sanity test
@skipIfNoFSSpecS3
def test_fsspec_io_iterdatapipe(self):
input_list = [
(["s3://ai2-public-datasets"], 39), # bucket without '/'
(["s3://ai2-public-datasets/charades/"], 18), # bucket with '/'
(
[
"s3://ai2-public-datasets/charades/Charades_v1.zip",
"s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
"s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
"s3://ai2-public-datasets/charades/Charades_v1_480.zip",
],
4,
), # multiple files
]
for urls, num in input_list:
fsspec_lister_dp = FSSpecFileLister(IterableWrapper(urls), anon=True)
self.assertEqual(sum(1 for _ in fsspec_lister_dp), num, f"{urls} failed")

url = "s3://ai2-public-datasets/charades/"
fsspec_loader_dp = FSSpecFileOpener(FSSpecFileLister(IterableWrapper([url]), anon=True), anon=True)
res = list(fsspec_loader_dp)
self.assertEqual(len(res), 18, f"{url} failed")

@skipIfAWS
def test_disabled_s3_io_iterdatapipe(self):
file_urls = ["s3://ai2-public-datasets"]
try:
s3_lister_dp = S3FileLister(IterableWrapper(file_urls))
s3_loader_dp = S3FileLoader(IterableWrapper(file_urls))
except ModuleNotFoundError:
warnings.warn(
"S3 IO datapipes or C++ extension '_torchdata' isn't built in the current 'torchdata' package"
)
return
with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"):
_ = S3FileLister(IterableWrapper(file_urls))
with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"):
_ = S3FileLoader(IterableWrapper(file_urls))

@skipIfNoAWS
def test_s3_io_iterdatapipe(self):
# S3FileLister: different inputs
input_list = [
[["s3://ai2-public-datasets"], 77], # bucket without '/'
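The only plumbing these tests rely on is that the kwargs reach fsspec.core.url_to_fs(), which hands them to the resolved filesystem's constructor. A minimal sketch of that mechanism in isolation (assuming fsspec and s3fs are installed):

    import fsspec

    # The extra kwargs become constructor arguments for the resolved
    # filesystem class: here s3fs.S3FileSystem(anon=True).
    fs, path = fsspec.core.url_to_fs("s3://ai2-public-datasets/charades/", anon=True)
    print(type(fs).__name__)  # S3FileSystem
    print(path)               # e.g. ai2-public-datasets/charades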
23 changes: 18 additions & 5 deletions torchdata/datapipes/iter/load/fsspec.py
@@ -41,6 +41,8 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
Args:
root: The root `fsspec` path directory or list of path directories to list files from
masks: Unix style filter string or string list for filtering file name(s)
kwargs: Extra options that make sense to a particular storage connection,
e.g. host, port, username, password, etc.
Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
@@ -51,6 +53,7 @@ def __init__(
self,
root: Union[str, Sequence[str], IterDataPipe],
masks: Union[str, List[str]] = "",
**kwargs,
) -> None:
_assert_fsspec()

@@ -63,10 +66,11 @@ def __init__(
else:
self.datapipe = root
self.masks = masks
self.kwargs = kwargs

def __iter__(self) -> Iterator[str]:
for root in self.datapipe:
fs, path = fsspec.core.url_to_fs(root)
fs, path = fsspec.core.url_to_fs(root, **self.kwargs)

if isinstance(fs.protocol, str):
protocol_list = [fs.protocol]
@@ -87,8 +91,10 @@ def __iter__(self) -> Iterator[str]:
else:
if is_local:
abs_path = os.path.join(path, file_name)
else:
elif not file_name.startswith(path):
abs_path = posixpath.join(path, file_name)
else:
abs_path = file_name

starts_with = False
for protocol in protocol_list:
@@ -110,22 +116,25 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: Iterable DataPipe that provides the pathnames or URLs
mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default)
kwargs: Extra options that make sense to a particular storage connection,
e.g. host, port, username, password, etc.
Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)
>>> file_dp = datapipe.open_files_by_fsspec()
"""

def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r") -> None:
def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", **kwargs) -> None:
_assert_fsspec()

self.source_datapipe: IterDataPipe[str] = source_datapipe
self.mode: str = mode
self.kwargs = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for file_uri in self.source_datapipe:
fs, path = fsspec.core.url_to_fs(file_uri)
fs, path = fsspec.core.url_to_fs(file_uri, **self.kwargs)
file = fs.open(path, self.mode)
yield file_uri, StreamWrapper(file)

@@ -148,6 +157,8 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
source_datapipe: Iterable DataPipe with tuples of metadata and data
mode: Mode in which the file will be opened to write the data (``"w"`` by default)
filepath_fn: Function that takes in metadata and returns the target path of the new file
kwargs: Extra options that make sense to a particular storage connection,
e.g. host, port, username, password, etc.
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
@@ -164,17 +175,19 @@ def __init__(
source_datapipe: IterDataPipe[Tuple[Any, U]],
mode: str = "w",
filepath_fn: Optional[Callable] = None,
**kwargs,
):
_assert_fsspec()

self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe
self.mode: str = mode
self.filepath_fn: Optional[Callable] = filepath_fn
self.kwargs = kwargs

def __iter__(self) -> Iterator[str]:
for meta, data in self.source_datapipe:
filepath = meta if self.filepath_fn is None else self.filepath_fn(meta)
fs, path = fsspec.core.url_to_fs(filepath)
fs, path = fsspec.core.url_to_fs(filepath, **self.kwargs)
with fs.open(path, self.mode) as f:
f.write(data)
yield filepath
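The saver follows the same pattern through its save_by_fsspec functional form. A minimal, runnable sketch using fsspec's in-memory filesystem so no credentials are needed; for a real store such as S3, extra kwargs (e.g. key=..., secret=...) would be forwarded the same way:

    from torchdata.datapipes.iter import IterableWrapper

    # filepath_fn maps each metadata item to a target URL; any kwargs given
    # to save_by_fsspec are forwarded to fsspec.core.url_to_fs().
    source_dp = IterableWrapper([("1.txt", b"DATA1"), ("2.txt", b"DATA2")])
    saver_dp = source_dp.save_by_fsspec(mode="wb", filepath_fn=lambda name: "memory://" + name)
    for saved_path in saver_dp:
        print(saved_path)  # memory://1.txt, then memory://2.txt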
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/load/s3io.py
@@ -51,7 +51,7 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):

def __init__(self, source_datapipe: IterDataPipe[str], length: int = -1, request_timeout_ms=-1, region="") -> None:
if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.")

self.source_datapipe: IterDataPipe[str] = source_datapipe
self.length: int = length
@@ -113,7 +113,7 @@ def __init__(
multi_part_download=None,
) -> None:
if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.")

self.source_datapipe: IterDataPipe[str] = source_datapipe
self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region)
