diff --git a/src/datatrove/io.py b/src/datatrove/io.py index fdaafea1..1fba5fac 100644 --- a/src/datatrove/io.py +++ b/src/datatrove/io.py @@ -6,6 +6,7 @@ from fsspec import open as fsspec_open from fsspec.callbacks import NoOpCallback, TqdmCallback from fsspec.core import get_fs_token_paths, strip_protocol, url_to_fs +from fsspec.implementations.cached import CachingFileSystem from fsspec.implementations.dirfs import DirFileSystem from fsspec.implementations.local import LocalFileSystem from huggingface_hub import HfFileSystem, cached_assets_path @@ -137,7 +138,7 @@ def list_files( # makes it slightly easier for file extensions glob_pattern = f"*{glob_pattern}" extra_options = {} - if isinstance(self.fs, HfFileSystem): + if isinstance(_get_true_fs(self.fs), HfFileSystem): extra_options["expand_info"] = False # speed up if include_directories and not glob_pattern: extra_options["withdirs"] = True @@ -374,3 +375,11 @@ def get_shard_from_paths_file(paths_file: DataFileLike, rank: int, world_size): for pathi, path in enumerate(f): if (pathi - rank) % world_size == 0: yield path.strip() + + +def _get_true_fs(fs: AbstractFileSystem): + if isinstance(fs, CachingFileSystem): + # We have to unwrap the cached filesystem to get the real fs + return fs.fs + + return fs