Improve fsspec DataPipe to accept extra keyword arguments #495

Closed
wants to merge 1 commit
1 change: 1 addition & 0 deletions test/requirements.txt
@@ -1,6 +1,7 @@
pytest
expecttest
fsspec
s3fs
iopath == 0.1.9
numpy
rarfile
65 changes: 54 additions & 11 deletions test/test_remote_io.py
@@ -19,13 +19,33 @@
from torchdata.datapipes.iter import (
EndOnDiskCacheHolder,
FileOpener,
FSSpecFileLister,
FSSpecFileOpener,
HttpReader,
IterableWrapper,
OnDiskCacheHolder,
S3FileLister,
S3FileLoader,
)

try:
import fsspec
import s3fs

HAS_FSSPEC_S3 = True
except ImportError:
HAS_FSSPEC_S3 = False
skipIfNoFSSpecS3 = unittest.skipIf(not HAS_FSSPEC_S3, "no FSSpec with S3fs")

try:
from torchdata._torchdata import S3Handler

HAS_AWS = True
except ImportError:
HAS_AWS = False
skipIfAWS = unittest.skipIf(HAS_AWS, "AWSSDK Enabled")
skipIfNoAWS = unittest.skipIf(not HAS_AWS, "No AWSSDK Enabled")


class TestDataPipeRemoteIO(expecttest.TestCase):
def setUp(self):
@@ -183,20 +203,43 @@ def _read_and_decode(x):
if not IS_WINDOWS:
dl = DataLoader(file_cache_dp, num_workers=3, multiprocessing_context="fork", batch_size=1)
expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * 3
self.assertEqual(sorted(expected), sorted(list(dl)))
res = list(dl)
self.assertEqual(sorted(expected), sorted(res))

def test_s3_io_iterdatapipe(self):
# sanity test
@skipIfNoFSSpecS3
def test_fsspec_io_iterdatapipe(self):
input_list = [
(["s3://ai2-public-datasets"], 39), # bucket without '/'
(["s3://ai2-public-datasets/charades/"], 18), # bucket with '/'
(
[
"s3://ai2-public-datasets/charades/Charades_v1.zip",
"s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
"s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
"s3://ai2-public-datasets/charades/Charades_v1_480.zip",
],
4,
), # multiple files
]
for urls, num in input_list:
fsspec_lister_dp = FSSpecFileLister(IterableWrapper(urls), anon=True)
self.assertEqual(sum(1 for _ in fsspec_lister_dp), num, f"{urls} failed")

url = "s3://ai2-public-datasets/charades/"
fsspec_loader_dp = FSSpecFileOpener(FSSpecFileLister(IterableWrapper([url]), anon=True), anon=True)
res = list(fsspec_loader_dp)
self.assertEqual(len(res), 18, f"{url} failed")

@skipIfAWS
def test_disabled_s3_io_iterdatapipe(self):
file_urls = ["s3://ai2-public-datasets"]
try:
s3_lister_dp = S3FileLister(IterableWrapper(file_urls))
s3_loader_dp = S3FileLoader(IterableWrapper(file_urls))
except ModuleNotFoundError:
warnings.warn(
"S3 IO datapipes or C++ extension '_torchdata' isn't built in the current 'torchdata' package"
)
return
with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"):
_ = S3FileLister(IterableWrapper(file_urls))
with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"):
_ = S3FileLoader(IterableWrapper(file_urls))

@skipIfNoAWS
def test_s3_io_iterdatapipe(self):
# S3FileLister: different inputs
input_list = [
[["s3://ai2-public-datasets"], 77], # bucket without '/'
23 changes: 18 additions & 5 deletions torchdata/datapipes/iter/load/fsspec.py
@@ -42,6 +42,8 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
Args:
root: The root `fsspec` path directory or list of path directories to list files from
masks: Unix style filter string or string list for filtering file name(s)
kwargs: Extra options that make sense to a particular storage connection,
e.g. host, port, username, password, etc.

Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
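As a usage sketch grounded in the new test above (the public ai2-public-datasets bucket and anon=True come from that test), any extra keyword arguments are forwarded verbatim to fsspec:

from torchdata.datapipes.iter import FSSpecFileLister

# anon=True is not consumed by the DataPipe itself; it is passed through
# to fsspec.core.url_to_fs, where s3fs interprets it as anonymous access.
lister = FSSpecFileLister(root="s3://ai2-public-datasets/charades/", anon=True)
for file_url in lister:
    print(file_url)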
@@ -52,6 +54,7 @@ def __init__(
self,
root: Union[str, Sequence[str], IterDataPipe],
masks: Union[str, List[str]] = "",
**kwargs,
) -> None:
_assert_fsspec()

@@ -64,10 +67,11 @@ def __init__(
else:
self.datapipe = root
self.masks = masks
self.kwargs = kwargs

def __iter__(self) -> Iterator[str]:
for root in self.datapipe:
fs, path = fsspec.core.url_to_fs(root)
fs, path = fsspec.core.url_to_fs(root, **self.kwargs)

if isinstance(fs.protocol, str):
protocol_list = [fs.protocol]
@@ -88,8 +92,10 @@ def __iter__(self) -> Iterator[str]:
else:
if is_local:
abs_path = os.path.join(path, file_name)
else:
elif not file_name.startswith(path):
abs_path = posixpath.join(path, file_name)
else:
abs_path = file_name
Comment on lines +97 to +98

Contributor: Just to confirm - this never happens if is_local == True?

Contributor Author: Good question. I can't guarantee it, since I am not aware of all use cases. Based on my experience with fsspec, I have only encountered the problem when the input is an s3 URL. For local files, I would trust the test here: https://github.com/pytorch/data/blob/main/test/test_fsspec.py#L56-L65. As long as it doesn't break, I think we are fine.

starts_with = False
for protocol in protocol_list:
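To make the new branch concrete: remote backends such as s3fs typically return listing entries that already carry the root prefix, so joining unconditionally would duplicate it. A standalone sketch of the resolution logic (the helper name and values are illustrative, not part of the diff):

import os
import posixpath

def resolve_entry(path, file_name, is_local):
    # Local listings return bare names, so join them onto the root.
    if is_local:
        return os.path.join(path, file_name)
    # Remote listings (e.g. from s3fs) often come back already prefixed
    # with the root path; joining again would duplicate the prefix.
    if not file_name.startswith(path):
        return posixpath.join(path, file_name)
    return file_name

# s3fs lists "bucket/dir/file" under root "bucket/dir", so no re-join:
assert resolve_entry("bucket/dir", "bucket/dir/file", False) == "bucket/dir/file"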
@@ -111,22 +117,25 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
Args:
source_datapipe: Iterable DataPipe that provides the pathnames or URLs
mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default)
kwargs: Extra options that make sense to a particular storage connection,
e.g. host, port, username, password, etc.

Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)
>>> file_dp = datapipe.open_files_by_fsspec()
"""

def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r") -> None:
def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", **kwargs) -> None:
_assert_fsspec()

self.source_datapipe: IterDataPipe[str] = source_datapipe
self.mode: str = mode
self.kwargs = kwargs

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for file_uri in self.source_datapipe:
fs, path = fsspec.core.url_to_fs(file_uri)
fs, path = fsspec.core.url_to_fs(file_uri, **self.kwargs)
file = fs.open(path, self.mode)
yield file_uri, StreamWrapper(file)
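A chained sketch using the functional form from the docstring above (mode="rb" is illustrative; anon=True again mirrors the new test):

from torchdata.datapipes.iter import FSSpecFileLister

datapipe = FSSpecFileLister(root="s3://ai2-public-datasets/charades/", anon=True)
# Keyword arguments given here are forwarded to fsspec.core.url_to_fs
# for each URL before fs.open() is called with the requested mode.
file_dp = datapipe.open_files_by_fsspec(mode="rb", anon=True)
for url, stream in file_dp:
    ...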

@@ -149,6 +158,8 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
source_datapipe: Iterable DataPipe with tuples of metadata and data
mode: Mode in which the file will be opened for write the data (``"w"`` by default)
filepath_fn: Function that takes in metadata and returns the target path of the new file
kwargs: Extra options that make sense to a particular storage connection,
e.g. host, port, username, password, etc.

Example:
>>> from torchdata.datapipes.iter import IterableWrapper
@@ -165,17 +176,19 @@ def __init__(
source_datapipe: IterDataPipe[Tuple[Any, U]],
mode: str = "w",
filepath_fn: Optional[Callable] = None,
**kwargs,
):
_assert_fsspec()

self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe
self.mode: str = mode
self.filepath_fn: Optional[Callable] = filepath_fn
self.kwargs = kwargs

def __iter__(self) -> Iterator[str]:
for meta, data in self.source_datapipe:
filepath = meta if self.filepath_fn is None else self.filepath_fn(meta)
fs, path = fsspec.core.url_to_fs(filepath)
fs, path = fsspec.core.url_to_fs(filepath, **self.kwargs)
with fs.open(path, self.mode) as f:
f.write(data)
yield filepath
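A saver sketch that stays self-contained by writing to fsspec's built-in memory filesystem (the filepath_fn and target URLs are illustrative):

from torchdata.datapipes.iter import IterableWrapper

source = IterableWrapper([("1.txt", b"DATA1"), ("2.txt", b"DATA2")])
# filepath_fn maps each metadata item to a target URL; any extra kwargs
# would be forwarded to fsspec.core.url_to_fs for every file written.
saver = source.save_by_fsspec(mode="wb", filepath_fn=lambda name: "memory://" + name)
print(list(saver))  # ['memory://1.txt', 'memory://2.txt']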
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/load/s3io.py
@@ -51,7 +51,7 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):

def __init__(self, source_datapipe: IterDataPipe[str], length: int = -1, request_timeout_ms=-1, region="") -> None:
if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.")

self.source_datapipe: IterDataPipe[str] = source_datapipe
self.length: int = length
@@ -113,7 +113,7 @@ def __init__(
multi_part_download=None,
) -> None:
if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"):
raise ModuleNotFoundError("Torchdata must be built with BUILD_S3=1 to use this datapipe.")
raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.")

self.source_datapipe: IterDataPipe[str] = source_datapipe
self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region)
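For completeness, user code can replicate the availability check these constructors rely on (a sketch mirroring the hasattr guard above and the test's import probe):

import torchdata

# Both the C++ extension and its S3Handler must be present, which
# requires a torchdata build with BUILD_S3=1.
HAS_AWS = hasattr(torchdata, "_torchdata") and hasattr(torchdata._torchdata, "S3Handler")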