Skip to content

Commit

Permalink
Fix is_local for paths starting with az:// (#849)
Browse files Browse the repository at this point in the history
Summary:
Signed-off-by: Marco Pfirrmann <[email protected]>

Fixes #840.
The module used to extract the protocol does not differentiate between `abfs://` and `az://` paths, always returning `abfs`. This fix adds `az` to the protocol list to be checked.

Pull Request resolved: #849

Reviewed By: wenleix

Differential Revision: D41345945

Pulled By: NivekT

fbshipit-source-id: 229a2176c5370129192e0c1a458312142cb75b1a
  • Loading branch information
mpfirrmann authored and facebook-github-bot committed Nov 16, 2022
1 parent daa2fc9 commit 594b30e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ rarfile
protobuf >= 3.9.2, < 3.20
datasets
graphviz
adlfs
15 changes: 15 additions & 0 deletions test/test_remote_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,21 @@ def test_fsspec_io_iterdatapipe(self):
res = list(fsspec_loader_dp)
self.assertEqual(len(res), 18, f"{input} failed")

@skipIfNoFSSpecS3
def test_fsspec_azure_blob(self):
url = "public/curated/covid-19/ecdc_cases/latest/ecdc_cases.csv"
account_name = "pandemicdatalake"
azure_prefixes = ["abfs", "az"]
fsspec_loader_dp = {}

for prefix in azure_prefixes:
fsspec_lister_dp = FSSpecFileLister(f"{prefix}://{url}", account_name=account_name)
fsspec_loader_dp[prefix] = FSSpecFileOpener(fsspec_lister_dp, account_name=account_name).parse_csv()

res_abfs = list(fsspec_loader_dp["abfs"])[0]
res_az = list(fsspec_loader_dp["az"])[0]
self.assertEqual(res_abfs, res_az, f"{input} failed")

@skipIfAWS
def test_disabled_s3_io_iterdatapipe(self):
file_urls = ["s3://ai2-public-datasets"]
Expand Down
4 changes: 4 additions & 0 deletions torchdata/datapipes/iter/load/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def __iter__(self) -> Iterator[str]:
else:
protocol_list = fs.protocol

# fspec.core.url_to_fs will return "abfs" for both, "az://" and "abfs://" urls
if "abfs" in protocol_list:
protocol_list.append("az")

is_local = fs.protocol == "file" or not any(root.startswith(protocol) for protocol in protocol_list)
if fs.isfile(path):
yield root
Expand Down

0 comments on commit 594b30e

Please sign in to comment.