diff --git a/mypy.ini b/mypy.ini
index 8bb1e7024..d5d430068 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -27,6 +27,9 @@ ignore_missing_imports = True
 [mypy-expecttest.*]
 ignore_missing_imports = True
 
+[mypy-datasets.*]
+ignore_missing_imports = True
+
 [mypy-rarfile.*]
 ignore_missing_imports = True
 
diff --git a/test/requirements.txt b/test/requirements.txt
index 4f223b2e6..716cd7a11 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -5,3 +5,4 @@ iopath == 0.1.9
 numpy
 rarfile
 protobuf < 4
+datasets
diff --git a/test/test_huggingface_datasets.py b/test/test_huggingface_datasets.py
new file mode 100644
index 000000000..51655e2c0
--- /dev/null
+++ b/test/test_huggingface_datasets.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+import warnings
+
+import expecttest
+
+from _utils._common_utils_for_test import create_temp_dir, create_temp_files
+
+from torchdata.datapipes.iter import HuggingFaceHubReader
+
+try:
+    import datasets
+
+    HAS_DATASETS = True
+
+except ImportError:
+    HAS_DATASETS = False
+skipIfNoDatasets = unittest.skipIf(not HAS_DATASETS, "no datasets")
+
+
+class TestHuggingFaceHubReader(expecttest.TestCase):
+    def setUp(self):
+        self.temp_dir = create_temp_dir()
+        self.temp_files = create_temp_files(self.temp_dir)
+        self.temp_sub_dir = create_temp_dir(self.temp_dir.name)
+        self.temp_sub_files = create_temp_files(self.temp_sub_dir, 4, False)
+
+        self.temp_dir_2 = create_temp_dir()
+        self.temp_files_2 = create_temp_files(self.temp_dir_2)
+        self.temp_sub_dir_2 = create_temp_dir(self.temp_dir_2.name)
+        self.temp_sub_files_2 = create_temp_files(self.temp_sub_dir_2, 4, False)
+
+    def tearDown(self):
+        try:
+            self.temp_sub_dir.cleanup()
+            self.temp_dir.cleanup()
+            self.temp_sub_dir_2.cleanup()
+            self.temp_dir_2.cleanup()
+        except Exception as e:
+            warnings.warn(f"TestHuggingFaceHubReader was not able to clean up temp dir due to {e}")
+
+    @skipIfNoDatasets
+    def test_huggingface_hubreader(self):
+        datapipe = HuggingFaceHubReader(dataset="lhoestq/demo1", revision="main", streaming=True)
+        elem = next(iter(datapipe))
+        assert type(elem) is dict
+        assert elem["package_name"] == "com.mantz_it.rfanalyzer"
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 2237e63f4..f867c1acf 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -26,6 +26,11 @@
 
 dill.extend(use_dill=False)
 
+try:
+    import datasets
+except ImportError:
+    datasets = None
+
 try:
     import fsspec
 except ImportError:
@@ -74,6 +79,8 @@ def _filepath_fn(name: str, dir) -> str:
 
 def _filter_by_module_availability(datapipes):
     filter_set = set()
+    if datasets is None:
+        filter_set.update([iterdp.HuggingFaceHubReader])
     if fsspec is None:
         filter_set.update([iterdp.FSSpecFileLister, iterdp.FSSpecFileOpener, iterdp.FSSpecSaver])
     if iopath is None:
@@ -195,6 +202,7 @@ def test_serializable(self):
             (iterdp.HashChecker, None, ({},), {}),
             (iterdp.Header, None, (3,), {}),
             (iterdp.HttpReader, None, (), {}),
+            (iterdp.HuggingFaceHubReader, None, (), {}),
             # TODO (ejguan): Deterministic serialization is required
             # (iterdp.InBatchShuffler, IterableWrapper(range(10)).batch(3), (), {}),
             (iterdp.InMemoryCacheHolder, None, (), {}),
@@ -298,6 +306,7 @@ def test_serializable(self):
             iterdp.IoPathFileOpener,
             iterdp.HashChecker,
             iterdp.HttpReader,
+            iterdp.HuggingFaceHubReader,
             iterdp.OnDiskCacheHolder,
             iterdp.OnlineReader,
             iterdp.ParquetDataFrameLoader,
diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py
index a4109cced..95ea54718 100644
--- a/torchdata/datapipes/iter/__init__.py
+++ b/torchdata/datapipes/iter/__init__.py
@@ -38,6 +38,9 @@
     FSSpecFileOpenerIterDataPipe as FSSpecFileOpener,
     FSSpecSaverIterDataPipe as FSSpecSaver,
 )
+
+from torchdata.datapipes.iter.load.huggingface import HuggingFaceHubReaderIterDataPipe as HuggingFaceHubReader
+
 from torchdata.datapipes.iter.load.iopath import (
     IoPathFileListerIterDataPipe as IoPathFileLister,
     IoPathFileOpenerIterDataPipe as IoPathFileOpener,
@@ -150,6 +153,7 @@
     "HashChecker",
     "Header",
     "HttpReader",
+    "HuggingFaceHubReader",
     "InBatchShuffler",
     "InMemoryCacheHolder",
     "IndexAdder",
diff --git a/torchdata/datapipes/iter/load/huggingface.py b/torchdata/datapipes/iter/load/huggingface.py
new file mode 100644
index 000000000..e0764356d
--- /dev/null
+++ b/torchdata/datapipes/iter/load/huggingface.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Any, Dict, Iterator, Optional
+
+from torchdata.datapipes.iter import IterDataPipe
+
+try:
+    import datasets
+except ImportError:
+    datasets = None
+
+
+def _get_response_from_huggingface_hub(
+    dataset: str, split: str, revision: str, streaming: bool, data_files: Optional[Dict[str, str]]
+) -> Iterator[Any]:
+    hf_dataset = datasets.load_dataset(
+        dataset, split=split, revision=revision, streaming=streaming, data_files=data_files
+    )
+    return iter(hf_dataset)
+
+
+class HuggingFaceHubReaderIterDataPipe(IterDataPipe[Dict[str, Any]]):
+    r"""
+    Takes in a dataset name and yields the rows of the corresponding HuggingFace dataset.
+    Arguments follow the same format as https://huggingface.co/docs/datasets/loading
+    Args:
+        dataset: path or name of the dataset, as accepted by the HuggingFace datasets library
+        revision: the specific dataset version to load
+        split: the dataset split to load, e.g. "train" or "test"
+        streaming: stream the dataset instead of downloading it in one go
+        data_files: optional dict mapping custom train/test/validation splits to data files
+    Example:
+        >>> from torchdata.datapipes.iter import HuggingFaceHubReader
+        >>> huggingface_reader_dp = HuggingFaceHubReader(dataset="lhoestq/demo1", revision="main")
+        >>> elem = next(iter(huggingface_reader_dp))
+        >>> elem["package_name"]
+        com.mantz_it.rfanalyzer
+    """
+
+    def __init__(
+        self,
+        dataset: str,
+        *,
+        split: str = "train",
+        revision: str = "main",
+        streaming: bool = True,
+        data_files: Optional[Dict[str, str]] = None,
+    ) -> None:
+        if datasets is None:
+            raise ModuleNotFoundError(
+                "Package `datasets` is required to be installed to use this datapipe. "
+ "Please use `pip install datasets` or `conda install -c conda-forge datasets`" + "to install the package" + ) + + self.dataset = dataset + self.split = split + self.revision = revision + self.streaming = streaming + self.data_files = data_files + + def __iter__(self) -> Iterator[Any]: + return _get_response_from_huggingface_hub( + dataset=self.dataset, + split=self.split, + revision=self.revision, + streaming=self.streaming, + data_files=self.data_files, + ) + + def __len__(self) -> int: + raise TypeError(f"{type(self).__name__} instance doesn't have valid length")