Skip to content

Commit

Permalink
stubgen: properly sort & add newlines to imports in generated stubs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
laggykiller authored Mar 12, 2024
1 parent 754e1b2 commit df8996a
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 13 deletions.
41 changes: 36 additions & 5 deletions src/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class and repeatedly call ``.put()`` to register modules or contents within the
import textwrap
import importlib
import importlib.machinery
import importlib.util
import types
import typing
from dataclasses import dataclass
Expand Down Expand Up @@ -1089,13 +1090,43 @@ def type_str(self, tp: Union[List[Any], Tuple[Any, ...], Dict[Any, Any], Any]) -
result = repr(tp)
return self.simplify_types(result)

def check_party(self, module: str) -> Literal[0, 1, 2]:
"""
Check source of module
0 = From stdlib
1 = From 3rd party package
2 = From the package being built
"""
if module.startswith(".") or module == self.module.__name__.split('.')[0]:
return 2

try:
spec = importlib.util.find_spec(module)
except ModuleNotFoundError:
return 1

if spec:
if spec.origin and "site-packages" in spec.origin:
return 1
else:
return 0
else:
return 1

def get(self) -> str:
"""Generate the final stub output"""
s = ""
last_party = None

for module in sorted(self.imports):
for module in sorted(self.imports, key=lambda i: str(self.check_party(i)) + i):
imports = self.imports[module]
items: List[str] = []
party = self.check_party(module)

if party != last_party:
if last_party is not None:
s += "\n"
last_party = party

for (k, v1), v2 in imports.items():
if k is None:
Expand All @@ -1108,15 +1139,16 @@ def get(self) -> str:
items.append(f"{k} as {v2}")
else:
items.append(k)


items = sorted(items)
if items:
items_v0 = ", ".join(items)
items_v0 = f"from {module} import {items_v0}\n"
items_v1 = "(\n " + ",\n ".join(items) + "\n)"
items_v1 = f"from {module} import {items_v1}\n"
s += items_v0 if len(items_v0) <= 70 else items_v1
if s:
s += "\n"

s += "\n\n"
s += self.put_abstract_enum_class()

# Append the main generated stub
Expand Down Expand Up @@ -1335,7 +1367,6 @@ def add_pattern(query: str, lines: List[str]):

def main(args: Optional[List[str]] = None) -> None:
import sys
import os

# Ensure that the current directory is on the path
if "" not in sys.path and "." not in sys.path:
Expand Down
4 changes: 2 additions & 2 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
is_pypy = platform.python_implementation() == 'PyPy'
is_darwin = platform.system() == 'Darwin'

def collect():
def collect() -> None:
if is_pypy:
for i in range(3):
for _ in range(3):
gc.collect()
else:
gc.collect()
Expand Down
3 changes: 2 additions & 1 deletion tests/py_stub_test.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable
from typing import overload, TypeVar
from typing import TypeVar, overload


class AClass:
__annotations__: dict = {'STATIC_VAR' : int}
Expand Down
1 change: 1 addition & 0 deletions tests/test_classes_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import overload


class A:
def __init__(self, arg: int, /) -> None: ...

Expand Down
1 change: 1 addition & 0 deletions tests/test_enum_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import overload


class _Enum:
def __init__(self, arg: object, /) -> None: ...
def __repr__(self, /) -> str: ...
Expand Down
3 changes: 2 additions & 1 deletion tests/test_functions_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Callable
import types
from typing import overload, Annotated, Any
from typing import Annotated, Any, overload


def call_guard_value() -> int: ...

Expand Down
1 change: 1 addition & 0 deletions tests/test_make_iterator_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Iterator, Mapping
from typing import overload


class IdentityMap:
def __init__(self) -> None: ...

Expand Down
4 changes: 3 additions & 1 deletion tests/test_ndarray_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from numpy.typing import ArrayLike
from typing import Annotated, overload

from numpy.typing import ArrayLike


class Cls:
def __init__(self) -> None: ...

Expand Down
3 changes: 2 additions & 1 deletion tests/test_stl_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections.abc import Sequence, Callable, Mapping, Set
from collections.abc import Callable, Mapping, Sequence, Set
import os
import pathlib
from typing import overload


class ClassWithMovableField:
def __init__(self) -> None: ...

Expand Down
6 changes: 4 additions & 2 deletions tests/test_typing_ext.pyi.ref
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections.abc import Iterable
from typing import Generic, Optional, Self, TypeAlias, TypeVar

from . import submodule as submodule
from .submodule import F as F, f as f2
from collections.abc import Iterable
from typing import Self, Optional, TypeAlias, TypeVar, Generic


# a prefix

Expand Down

0 comments on commit df8996a

Please sign in to comment.