Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BE Hackathon] Automatically generate datapipe.pyi in setup.py #290

Closed
wants to merge 9 commits into [base branch] from [head branch]
Closed
12 changes: 9 additions & 3 deletions .github/workflows/domain_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ jobs:
uses: actions/checkout@v2

- name: Install torchdata
run: python setup.py install
run: |
pip install -r requirements.txt
python setup.py install

- name: Install test requirements
run: pip install pytest pytest-mock scipy iopath pycocotools h5py
Expand Down Expand Up @@ -85,7 +87,9 @@ jobs:
uses: actions/checkout@v2

- name: Install torchdata
run: python setup.py install
run: |
pip install -r requirements.txt
python setup.py install

- name: Install test requirements
run: pip install dill expecttest pytest iopath
Expand Down Expand Up @@ -126,7 +130,9 @@ jobs:
uses: actions/checkout@v2

- name: Install torchdata
run: python setup.py install
run: |
pip install -r requirements.txt
python setup.py install

- name: Install test requirements
run: pip install dill expecttest numpy pytest
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ dist/*
torchdata.egg-info/*

torchdata/version.py
torchdata/datapipes/iter/__init__.pyi

# Editor temporaries
*.swn
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from pathlib import Path

from setuptools import find_packages, setup
from torchdata.datapipes.gen_pyi import gen_pyi


ROOT_DIR = Path(__file__).parent.resolve()

Expand Down Expand Up @@ -110,3 +112,4 @@ def get_parser():
packages=find_packages(exclude=["test*", "examples*"]),
zip_safe=False,
)
gen_pyi()
4 changes: 2 additions & 2 deletions torchdata/datapipes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from torch.utils.data import functional_datapipe
from torch.utils.data import DataChunk, functional_datapipe

from . import iter, map, utils

__all__ = ["functional_datapipe", "iter", "map", "utils"]
__all__ = ["DataChunk", "functional_datapipe", "iter", "map", "utils"]
27 changes: 17 additions & 10 deletions torchdata/datapipes/gen_pyi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import pathlib
from pathlib import Path
from typing import Dict, List, Optional, Set

import torch.utils.data.gen_pyi as core_gen_pyi
from torch.utils.data.gen_pyi import FileManager, get_method_definitions
import torch.utils.data.datapipes.gen_pyi as core_gen_pyi
from torch.utils.data.datapipes.gen_pyi import gen_from_template, get_method_definitions


def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None):
Expand All @@ -18,14 +20,17 @@ def get_lines_base_file(base_file_path: str, to_skip: Optional[Set[str]] = None)
if skip_line in line:
skip_flag = True
if not skip_flag:
line = line.replace("\n", "")
res.append(line)
return res


def main() -> None:
def gen_pyi() -> None:
ROOT_DIR = Path(__file__).parent.resolve()
print(f"Generating DataPipe Python interface file in {ROOT_DIR}")

iter_init_base = get_lines_base_file(
"iter/__init__.py",
os.path.join(ROOT_DIR, "iter/__init__.py"),
{"from torch.utils.data import IterDataPipe", "# Copyright (c) Facebook, Inc. and its affiliates."},
)

Expand Down Expand Up @@ -69,14 +74,16 @@ def main() -> None:

iter_method_definitions = core_iter_method_definitions + td_iter_method_definitions

fm = FileManager(install_dir=".", template_dir=".", dry_run=False)
fm.write_with_template(
filename="iter/__init__.pyi",
template_fn="iter/__init__.pyi.in",
env_callable=lambda: {"init_base": iter_init_base, "IterDataPipeMethods": iter_method_definitions},
replacements = [("${init_base}", iter_init_base, 0), ("${IterDataPipeMethods}", iter_method_definitions, 4)]

gen_from_template(
dir=str(ROOT_DIR),
template_name="iter/__init__.pyi.in",
output_name="iter/__init__.pyi",
replacements=replacements,
)
# TODO: Add map_method_definitions when there are MapDataPipes defined in this library


if __name__ == "__main__":
main() # TODO: Run this script automatically within the build and CI process
gen_pyi()
Loading