diff --git a/.github/workflows/domain_ci.yml b/.github/workflows/domain_ci.yml
index 4efc560ff..7729848d9 100644
--- a/.github/workflows/domain_ci.yml
+++ b/.github/workflows/domain_ci.yml
@@ -39,7 +39,9 @@ jobs:
         uses: actions/checkout@v2
       - name: Install torchdata
-        run: python setup.py install
+        run: |
+          pip install -r requirements.txt
+          python setup.py install
       - name: Install test requirements
         run: pip install pytest pytest-mock scipy iopath pycocotools h5py
@@ -85,7 +87,9 @@ jobs:
         uses: actions/checkout@v2
       - name: Install torchdata
-        run: python setup.py install
+        run: |
+          pip install -r requirements.txt
+          python setup.py install
       - name: Install test requirements
         run: pip install dill expecttest pytest iopath
@@ -126,7 +130,9 @@ jobs:
         uses: actions/checkout@v2
       - name: Install torchdata
-        run: python setup.py install
+        run: |
+          pip install -r requirements.txt
+          python setup.py install
       - name: Install test requirements
         run: pip install dill expecttest numpy pytest
diff --git a/.gitignore b/.gitignore
index d1ab534f5..65607dd22 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 dist/*
 torchdata.egg-info/*
 torchdata/version.py
+torchdata/datapipes/iter/__init__.pyi
 
 # Editor temporaries
 *.swn
diff --git a/setup.py b/setup.py
index f89a4b7b8..0fad8ab88 100644
--- a/setup.py
+++ b/setup.py
@@ -9,6 +9,8 @@ from pathlib import Path
 
 from setuptools import find_packages, setup
 
+from torchdata.datapipes.gen_pyi import gen_pyi
+
 ROOT_DIR = Path(__file__).parent.resolve()
@@ -110,3 +112,4 @@ def get_parser():
         packages=find_packages(exclude=["test*", "examples*"]),
         zip_safe=False,
     )
+    gen_pyi()
diff --git a/torchdata/datapipes/__init__.py b/torchdata/datapipes/__init__.py
index 2316f7333..aebccb3d0 100644
--- a/torchdata/datapipes/__init__.py
+++ b/torchdata/datapipes/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
-from torch.utils.data import functional_datapipe
+from torch.utils.data import DataChunk, functional_datapipe
 
 from . import iter, map, utils
 
-__all__ = ["functional_datapipe", "iter", "map", "utils"]
+__all__ = ["DataChunk", "functional_datapipe", "iter", "map", "utils"]
diff --git a/torchdata/datapipes/gen_pyi.py b/torchdata/datapipes/gen_pyi.py
index ad28915d2..e08d19776 100644
--- a/torchdata/datapipes/gen_pyi.py
+++ b/torchdata/datapipes/gen_pyi.py
@@ -1,9 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
+import os
 import pathlib
+from pathlib import Path
 from typing import Dict, List, Optional, Set
 
-import torch.utils.data.gen_pyi as core_gen_pyi
-from torch.utils.data.gen_pyi import FileManager, get_method_definitions
+import torch.utils.data.datapipes.gen_pyi as core_gen_pyi
+from torch.utils.data.datapipes.gen_pyi import gen_from_template, get_method_definitions
 
 
 def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None):
@@ -18,14 +20,17 @@ def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None)
                 if skip_line in line:
                     skip_flag = True
             if not skip_flag:
+                line = line.replace("\n", "")
                 res.append(line)
         return res
 
 
-def main() -> None:
+def gen_pyi() -> None:
+    ROOT_DIR = Path(__file__).parent.resolve()
+    print(f"Generating DataPipe Python interface file in {ROOT_DIR}")
     iter_init_base = get_lines_base_file(
-        "iter/__init__.py",
+        os.path.join(ROOT_DIR, "iter/__init__.py"),
         {"from torch.utils.data import IterDataPipe", "# Copyright (c) Facebook, Inc. and its affiliates."},
     )
@@ -69,14 +74,16 @@ def main() -> None:
     iter_method_definitions = core_iter_method_definitions + td_iter_method_definitions
 
-    fm = FileManager(install_dir=".", template_dir=".", dry_run=False)
-    fm.write_with_template(
-        filename="iter/__init__.pyi",
-        template_fn="iter/__init__.pyi.in",
-        env_callable=lambda: {"init_base": iter_init_base, "IterDataPipeMethods": iter_method_definitions},
+    replacements = [("${init_base}", iter_init_base, 0), ("${IterDataPipeMethods}", iter_method_definitions, 4)]
+
+    gen_from_template(
+        dir=str(ROOT_DIR),
+        template_name="iter/__init__.pyi.in",
+        output_name="iter/__init__.pyi",
+        replacements=replacements,
     )
 
     # TODO: Add map_method_definitions when there are MapDataPipes defined in this library
 
 
 if __name__ == "__main__":
-    main()  # TODO: Run this script automatically within the build and CI process
+    gen_pyi()
diff --git a/torchdata/datapipes/iter/__init__.pyi b/torchdata/datapipes/iter/__init__.pyi
deleted file mode 100644
index 9ee4eceab..000000000
--- a/torchdata/datapipes/iter/__init__.pyi
+++ /dev/null
@@ -1,363 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-
-from torch.utils.data.datapipes.iter import (
-    Batcher,
-    Collator,
-    Concater,
-    Demultiplexer,
-    FileLister,
-    FileOpener,
-    Filter,
-    Forker,
-    Grouper,
-    IterableWrapper,
-    Mapper,
-    Multiplexer,
-    RoutedDecoder,
-    Sampler,
-    ShardingFilter,
-    Shuffler,
-    StreamReader,
-    UnBatcher,
-    Zipper,
-)
-from torchdata.datapipes.iter.load.fsspec import (
-    FSSpecFileListerIterDataPipe as FSSpecFileLister,
-    FSSpecFileOpenerIterDataPipe as FSSpecFileOpener,
-    FSSpecSaverIterDataPipe as FSSpecSaver,
-)
-from torchdata.datapipes.iter.load.iopath import (
-    IoPathFileListerIterDataPipe as IoPathFileLister,
-    IoPathFileOpenerIterDataPipe as IoPathFileOpener,
-    IoPathSaverIterDataPipe as IoPathSaver,
-)
-
-from torchdata.datapipes.iter.load.online import (
-    GDriveReaderDataPipe as GDriveReader,
-    HTTPReaderIterDataPipe as HttpReader,
-    OnlineReaderIterDataPipe as OnlineReader,
-)
-from torchdata.datapipes.iter.transform.bucketbatcher import BucketBatcherIterDataPipe as BucketBatcher
-from torchdata.datapipes.iter.transform.flatmap import FlatMapperIterDataPipe as FlatMapper
-from torchdata.datapipes.iter.util.cacheholder import (
-    EndOnDiskCacheHolderIterDataPipe as EndOnDiskCacheHolder,
-    InMemoryCacheHolderIterDataPipe as InMemoryCacheHolder,
-    OnDiskCacheHolderIterDataPipe as OnDiskCacheHolder,
-)
-from torchdata.datapipes.iter.util.combining import (
-    IterKeyZipperIterDataPipe as IterKeyZipper,
-    MapKeyZipperIterDataPipe as MapKeyZipper,
-)
-from torchdata.datapipes.iter.util.cycler import CyclerIterDataPipe as Cycler
-from torchdata.datapipes.iter.util.dataframemaker import (
-    DataFrameMakerIterDataPipe as DataFrameMaker,
-    ParquetDFLoaderIterDataPipe as ParquetDataFrameLoader,
-)
-from torchdata.datapipes.iter.util.decompressor import (
-    DecompressorIterDataPipe as Decompressor,
-    ExtractorIterDataPipe as Extractor,
-)
-from torchdata.datapipes.iter.util.hashchecker import HashCheckerIterDataPipe as HashChecker
-from torchdata.datapipes.iter.util.header import HeaderIterDataPipe as Header
-from torchdata.datapipes.iter.util.indexadder import (
-    EnumeratorIterDataPipe as Enumerator,
-    IndexAdderIterDataPipe as IndexAdder,
-)
-from torchdata.datapipes.iter.util.jsonparser import JsonParserIterDataPipe as JsonParser
-from torchdata.datapipes.iter.util.paragraphaggregator import ParagraphAggregatorIterDataPipe as ParagraphAggregator
-from torchdata.datapipes.iter.util.plain_text_reader import (
-    CSVDictParserIterDataPipe as CSVDictParser,
-    CSVParserIterDataPipe as CSVParser,
-    LineReaderIterDataPipe as LineReader,
-)
-from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
-from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
-from torchdata.datapipes.iter.util.samplemultiplexer import SampleMultiplexerDataPipe as SampleMultiplexer
-from torchdata.datapipes.iter.util.saver import SaverIterDataPipe as Saver
-from torchdata.datapipes.iter.util.tararchiveloader import (
-    TarArchiveLoaderIterDataPipe as TarArchiveLoader,
-    TarArchiveReaderIterDataPipe as TarArchiveReader,
-)
-from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper
-from torchdata.datapipes.iter.util.xzfileloader import (
-    XzFileLoaderIterDataPipe as XzFileLoader,
-    XzFileReaderIterDataPipe as XzFileReader,
-)
-from torchdata.datapipes.iter.util.ziparchiveloader import (
-    ZipArchiveLoaderIterDataPipe as ZipArchiveLoader,
-    ZipArchiveReaderIterDataPipe as ZipArchiveReader,
-)
-
-###############################################################################
-# Reference From PyTorch Core
-###############################################################################
-
-__all__ = [
-    "Batcher",
-    "BucketBatcher",
-    "CSVDictParser",
-    "CSVParser",
-    "Collator",
-    "Concater",
-    "Cycler",
-    "DataFrameMaker",
-    "Decompressor",
-    "Demultiplexer",
-    "EndOnDiskCacheHolder",
-    "Enumerator",
-    "Extractor",
-    "FSSpecFileLister",
-    "FSSpecFileOpener",
-    "FSSpecSaver",
-    "FileLister",
-    "FileOpener",
-    "Filter",
-    "FlatMapper",
-    "Forker",
-    "GDriveReader",
-    "Grouper",
-    "HashChecker",
-    "Header",
-    "HttpReader",
-    "InMemoryCacheHolder",
-    "IndexAdder",
-    "IoPathFileLister",
-    "IoPathFileOpener",
-    "IoPathSaver",
-    "IterDataPipe",
-    "IterKeyZipper",
-    "IterableWrapper",
-    "JsonParser",
-    "LineReader",
-    "MapKeyZipper",
-    "Mapper",
-    "Multiplexer",
-    "OnDiskCacheHolder",
-    "OnlineReader",
-    "ParagraphAggregator",
-    "ParquetDataFrameLoader",
-    "RarArchiveLoader",
-    "RoutedDecoder",
-    "Rows2Columnar",
-    "SampleMultiplexer",
-    "Sampler",
-    "Saver",
-    "ShardingFilter",
-    "Shuffler",
-    "StreamReader",
-    "TarArchiveLoader",
-    "TarArchiveReader",
-    "UnBatcher",
-    "UnZipper",
-    "XzFileLoader",
-    "XzFileReader",
-    "ZipArchiveLoader",
-    "ZipArchiveReader",
-    "Zipper",
-]
-
-# Please keep this list sorted
-assert __all__ == sorted(__all__)
-
-from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
-
-from torch.utils.data import DataChunk, IterableDataset
-from torch.utils.data._typing import _DataPipeMeta
-from torchdata.datapipes.map import MapDataPipe
-
-########################################################################################################################
-# The part below is generated by parsing through the Python files where IterDataPipes are defined.
-# This base template ("__init__.pyi.in") is generated from mypy stubgen with minimal editing for code injection
-# The output file will be "__init__.pyi".
-# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other
-# classes/objects here, even though we are not injecting extra code into them at the moment.
-
-from .util.decompressor import CompressionType
-
-try:
-    import torcharrow
-except ImportError:
-    torcharrow = None
-
-T_co = TypeVar("T_co", covariant=True)
-
-class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
-    functions: Dict[str, Callable] = ...
-    reduce_ex_hook: Optional[Callable] = ...
-    getstate_hook: Optional[Callable] = ...
-    def __getattr__(self, attribute_name: Any): ...
-    @classmethod
-    def register_function(cls, function_name: Any, function: Any) -> None: ...
-    @classmethod
-    def register_datapipe_as_function(
-        cls, function_name: Any, cls_to_register: Any, enable_df_api_tracing: bool = ...
-    ): ...
-    def __getstate__(self): ...
-    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
-    @classmethod
-    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
-    @classmethod
-    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
-    # Functional form of 'BatcherIterDataPipe'
-    def batch(self, batch_size: int, drop_last: bool = False, wrapper_class=DataChunk) -> IterDataPipe: ...
-    # Functional form of 'CollatorIterDataPipe'
-    def collate(self, collate_fn: Callable = ...) -> IterDataPipe: ...
-    # Functional form of 'ConcaterIterDataPipe'
-    def concat(self, *datapipes: IterDataPipe) -> IterDataPipe: ...
-    # Functional form of 'DemultiplexerIterDataPipe'
-    def demux(
-        self,
-        num_instances: int,
-        classifier_fn: Callable[[T_co], Optional[int]],
-        drop_none: bool = False,
-        buffer_size: int = 1000,
-    ) -> List[IterDataPipe]: ...
-    # Functional form of 'FilterIterDataPipe'
-    def filter(self, filter_fn: Callable, drop_empty_batches: bool = True) -> IterDataPipe: ...
-    # Functional form of 'ForkerIterDataPipe'
-    def fork(self, num_instances: int, buffer_size: int = 1000) -> List[IterDataPipe]: ...
-    # Functional form of 'GrouperIterDataPipe'
-    def groupby(
-        self,
-        group_key_fn: Callable,
-        *,
-        buffer_size: int = 10000,
-        group_size: Optional[int] = None,
-        guaranteed_group_size: Optional[int] = None,
-        drop_remaining: bool = False,
-    ) -> IterDataPipe: ...
-    # Functional form of 'MapperIterDataPipe'
-    def map(self, fn: Callable, input_col=None, output_col=None) -> IterDataPipe: ...
-    # Functional form of 'MultiplexerIterDataPipe'
-    def mux(self, *datapipes) -> IterDataPipe: ...
-    # Functional form of 'RoutedDecoderIterDataPipe'
-    def routed_decode(self, *handlers: Callable, key_fn: Callable = ...) -> IterDataPipe: ...
-    # Functional form of 'ShardingFilterIterDataPipe'
-    def sharding_filter(self) -> IterDataPipe: ...
-    # Functional form of 'ShufflerIterDataPipe'
-    def shuffle(self, *, default: bool = True, buffer_size: int = 10000, unbatch_level: int = 0) -> IterDataPipe: ...
-    # Functional form of 'UnBatcherIterDataPipe'
-    def unbatch(self, unbatch_level: int = 1) -> IterDataPipe: ...
-    # Functional form of 'ZipperIterDataPipe'
-    def zip(self, *datapipes: IterDataPipe) -> IterDataPipe: ...
-    # Functional form of 'IndexAdderIterDataPipe'
-    def add_index(self, index_name: str = "index") -> IterDataPipe: ...
-    # Functional form of 'BucketBatcherIterDataPipe'
-    def bucketbatch(
-        self,
-        batch_size: int,
-        drop_last: bool = False,
-        batch_num: int = 100,
-        bucket_num: int = 1,
-        sort_key: Optional[Callable] = None,
-        in_batch_shuffle: bool = True,
-    ) -> IterDataPipe: ...
-    # Functional form of 'HashCheckerIterDataPipe'
-    def check_hash(self, hash_dict: Dict[str, str], hash_type: str = "sha256", rewind: bool = True) -> IterDataPipe: ...
-    # Functional form of 'CyclerIterDataPipe'
-    def cycle(self, count: Optional[int] = None) -> IterDataPipe: ...
-    # Functional form of 'DataFrameMakerIterDataPipe'
-    def dataframe(
-        self, dataframe_size: int = 1000, dtype=None, columns: Optional[List[str]] = None, device: str = ""
-    ) -> torcharrow.DataFrame: ...
-    # Functional form of 'DecompressorIterDataPipe'
-    def decompress(self, file_type: Optional[Union[str, CompressionType]] = None) -> IterDataPipe: ...
-    # Functional form of 'EndOnDiskCacheHolderIterDataPipe'
-    def end_caching(self, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False) -> IterDataPipe: ...
-    # Functional form of 'EnumeratorIterDataPipe'
-    def enumerate(self, starting_index: int = 0) -> IterDataPipe: ...
-    # Functional form of 'FlatMapperIterDataPipe'
-    def flatmap(self, fn: Callable) -> IterDataPipe: ...
-    # Functional form of 'HeaderIterDataPipe'
-    def header(self, limit: int = 10) -> IterDataPipe: ...
-    # Functional form of 'InMemoryCacheHolderIterDataPipe'
-    def in_memory_cache(self, size: Optional[int] = None) -> IterDataPipe: ...
-    # Functional form of 'ParagraphAggregatorIterDataPipe'
-    def lines_to_paragraphs(self, joiner: Callable = ...) -> IterDataPipe: ...
-    # Functional form of 'RarArchiveLoaderIterDataPipe'
-    def load_from_rar(self, *, length: int = -1) -> IterDataPipe: ...
-    # Functional form of 'TarArchiveLoaderIterDataPipe'
-    def load_from_tar(self, mode: str = "r:*", length: int = -1) -> IterDataPipe: ...
-    # Functional form of 'XzFileLoaderIterDataPipe'
-    def load_from_xz(self, length: int = -1) -> IterDataPipe: ...
-    # Functional form of 'ZipArchiveLoaderIterDataPipe'
-    def load_from_zip(self, length: int = -1) -> IterDataPipe: ...
-    # Functional form of 'ParquetDFLoaderIterDataPipe'
-    def load_parquet_as_df(
-        self, dtype=None, columns: Optional[List[str]] = None, device: str = "", use_threads: bool = False
-    ) -> IterDataPipe: ...
-    # Functional form of 'OnDiskCacheHolderIterDataPipe'
-    def on_disk_cache(
-        self,
-        filepath_fn: Optional[Callable] = None,
-        hash_dict: Dict[str, str] = None,
-        hash_type: str = "sha256",
-        extra_check_fn: Optional[Callable[[str], bool]] = None,
-    ) -> IterDataPipe: ...
-    # Functional form of 'FSSpecFileOpenerIterDataPipe'
-    def open_file_by_fsspec(self, mode: str = "r") -> IterDataPipe: ...
-    # Functional form of 'IoPathFileOpenerIterDataPipe'
-    def open_file_by_iopath(self, mode: str = "r", pathmgr=None) -> IterDataPipe: ...
-    # Functional form of 'CSVParserIterDataPipe'
-    def parse_csv(
-        self,
-        *,
-        skip_lines: int = 0,
-        decode: bool = True,
-        encoding: str = "utf-8",
-        errors: str = "ignore",
-        return_path: bool = False,
-        **fmtparams,
-    ) -> IterDataPipe: ...
-    # Functional form of 'CSVDictParserIterDataPipe'
-    def parse_csv_as_dict(
-        self,
-        *,
-        skip_lines: int = 0,
-        decode: bool = True,
-        encoding: str = "utf-8",
-        errors: str = "ignore",
-        return_path: bool = False,
-        **fmtparams,
-    ) -> IterDataPipe: ...
-    # Functional form of 'JsonParserIterDataPipe'
-    def parse_json_files(self, **kwargs) -> IterDataPipe: ...
-    # Functional form of 'LineReaderIterDataPipe'
-    def readlines(
-        self,
-        *,
-        skip_lines: int = 0,
-        strip_newline: bool = True,
-        decode: bool = False,
-        encoding="utf-8",
-        errors: str = "ignore",
-        return_path: bool = True,
-    ) -> IterDataPipe: ...
-    # Functional form of 'Rows2ColumnarIterDataPipe'
-    def rows2columnar(self, column_names: List[str] = None) -> IterDataPipe: ...
-    # Functional form of 'FSSpecSaverIterDataPipe'
-    def save_by_fsspec(self, mode: str = "w", filepath_fn: Optional[Callable] = None) -> IterDataPipe: ...
-    # Functional form of 'IoPathSaverIterDataPipe'
-    def save_by_iopath(
-        self, mode: str = "w", filepath_fn: Optional[Callable] = None, *, pathmgr=None
-    ) -> IterDataPipe: ...
-    # Functional form of 'SaverIterDataPipe'
-    def save_to_disk(self, mode: str = "w", filepath_fn: Optional[Callable] = None) -> IterDataPipe: ...
-    # Functional form of 'UnZipperIterDataPipe'
-    def unzip(
-        self, sequence_length: int, buffer_size: int = 1000, columns_to_skip: Optional[Sequence[int]] = None
-    ) -> List[IterDataPipe]: ...
-    # Functional form of 'IterKeyZipperIterDataPipe'
-    def zip_with_iter(
-        self,
-        ref_datapipe: IterDataPipe,
-        key_fn: Callable,
-        ref_key_fn: Optional[Callable] = None,
-        keep_key: bool = False,
-        buffer_size: int = 10000,
-        merge_fn: Optional[Callable] = None,
-    ) -> IterDataPipe: ...
-    # Functional form of 'MapKeyZipperIterDataPipe'
-    def zip_with_map(
-        self, map_datapipe: MapDataPipe, key_fn: Callable, merge_fn: Optional[Callable] = None
-    ) -> IterDataPipe: ...
diff --git a/torchdata/datapipes/iter/__init__.pyi.in b/torchdata/datapipes/iter/__init__.pyi.in
index 3b3241bd5..ec5500e9d 100644
--- a/torchdata/datapipes/iter/__init__.pyi.in
+++ b/torchdata/datapipes/iter/__init__.pyi.in
@@ -4,14 +4,14 @@ ${init_base}
 ########################################################################################################################
 # The part below is generated by parsing through the Python files where IterDataPipes are defined.
 # This base template ("__init__.pyi.in") is generated from mypy stubgen with minimal editing for code injection
-# The output file will be "__init__.pyi".
+# The output file will be "__init__.pyi". The generation function is called by "setup.py".
 # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other
 # classes/objects here, even though we are not injecting extra code into them at the moment.
 
 from .util.decompressor import CompressionType
 from torchdata.datapipes.map import MapDataPipe
 from torch.utils.data import DataChunk, IterableDataset
-from torch.utils.data._typing import _DataPipeMeta
+from torch.utils.data.datapipes._typing import _DataPipeMeta
 
 from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
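Note on the template substitution above: the `replacements` triples passed to `gen_from_template` have the form `(placeholder, lines, indentation)`. As a rough sketch of what that substitution implies — the `expand_template` helper below is hypothetical, written only to illustrate the triple format, and is not the actual implementation in `torch.utils.data.datapipes.gen_pyi`:

from typing import List, Tuple

def expand_template(template: str, replacements: List[Tuple[str, List[str], int]]) -> str:
    # Each triple is (placeholder, lines, indent): the placeholder token in the
    # template is replaced by the generated lines, each indented by `indent` spaces.
    for placeholder, lines, indent in replacements:
        block = "\n".join(" " * indent + line for line in lines)
        template = template.replace(placeholder, block)
    return template

# Mirrors the call in gen_pyi(): init_base lines at indent 0, method stubs at
# indent 4 so they land inside the IterDataPipe class body of the .pyi output.
template = "${init_base}\n\nclass IterDataPipe:\n${IterDataPipeMethods}\n"
print(expand_template(template, [
    ("${init_base}", ["from torch.utils.data import IterDataPipe"], 0),
    ("${IterDataPipeMethods}", ["def shuffle(self) -> IterDataPipe: ..."], 4),
]))

Because `setup.py` now regenerates `iter/__init__.pyi` on every build, the checked-in copy is deleted and the path is added to `.gitignore`; the CI jobs gain the `pip install -r requirements.txt` step so that the imports `setup.py` needs (via `torchdata.datapipes.gen_pyi`) are available before installation.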