Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support type stub generation for staticmethod #14934

Merged
merged 13 commits into from
Jan 8, 2024
1 change: 1 addition & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,7 @@ def generate_stubs(options: Options) -> None:
doc_dir=options.doc_dir,
include_private=options.include_private,
export_less=options.export_less,
include_docstrings=options.include_docstrings,
)
num_modules = len(all_modules)
if not options.quiet and num_modules > 0:
Expand Down
14 changes: 9 additions & 5 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,12 +508,14 @@ def is_classmethod(self, class_info: ClassInfo, name: str, obj: object) -> bool:
return inspect.ismethod(obj)

def is_staticmethod(self, class_info: ClassInfo | None, name: str, obj: object) -> bool:
if self.is_c_module:
if class_info is None:
return False
elif self.is_c_module:
raw_lookup = getattr(class_info.cls, "__dict__") # noqa: B009
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why has this been # noqa'd? Wouldn't it be cleaner to do this, as the linter suggests?

Suggested change
raw_lookup = getattr(class_info.cls, "__dict__") # noqa: B009
raw_lookup = class_info.cls.__dict__

Copy link
Contributor Author

@WeilerMarcel WeilerMarcel Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I wrote it that way because there are two other places in stubgenc.py where this pattern was used:

  • mypy/mypy/stubgenc.py

    Lines 453 to 455 in 8121e3c

    def get_members(self, obj: object) -> list[tuple[str, Any]]:
    obj_dict: Mapping[str, Any] = getattr(obj, "__dict__") # noqa: B009
    results = []
  • mypy/mypy/stubgenc.py

    Lines 718 to 725 in 8121e3c

    def generate_class_stub(self, class_name: str, cls: type, output: list[str]) -> None:
    """Generate stub for a single class using runtime introspection.
    The result lines will be appended to 'output'. If necessary, any
    required names will be added to 'imports'.
    """
    raw_lookup = getattr(cls, "__dict__") # noqa: B009
    items = self.get_members(cls)

    I will check the file history to see if I can find a specific reason why getattr is used.

Copy link
Contributor Author

@WeilerMarcel WeilerMarcel Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The getattr call originally comes from commit 6bcaf40 and was added to make typeshed happy:

mypy/mypy/stubgenc.py

Lines 140 to 142 in 6bcaf40

# typeshed gives obj.__dict__ the not quite correct type Dict[str, Any]
# (it could be a mappingproxy!), which makes mypyc mad, so obfuscate it.
obj_dict = getattr(obj, '__dict__') # type: Mapping[str, Any]

The # noqa was added in 4287af4:

obj_dict = getattr(obj, '__dict__') # type: Mapping[str, Any] # noqa

raw_value = raw_lookup.get(name, obj)
return type(raw_value).__name__ == "staticmethod"
WeilerMarcel marked this conversation as resolved.
Show resolved Hide resolved
else:
return class_info is not None and isinstance(
inspect.getattr_static(class_info.cls, name), staticmethod
)
return isinstance(inspect.getattr_static(class_info.cls, name), staticmethod)

@staticmethod
def is_abstract_method(obj: object) -> bool:
Expand Down Expand Up @@ -751,7 +753,9 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
continue
attr = "__init__"
# FIXME: make this nicer
if self.is_classmethod(class_info, attr, value):
if self.is_staticmethod(class_info, attr, value):
class_info.self_var = ""
elif self.is_classmethod(class_info, attr, value):
class_info.self_var = "cls"
else:
class_info.self_var = "self"
Expand Down
15 changes: 15 additions & 0 deletions test-data/pybind11_mypy_demo/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ const Point Point::y_axis = Point(0, 1);
Point::LengthUnit Point::length_unit = Point::LengthUnit::mm;
Point::AngleUnit Point::angle_unit = Point::AngleUnit::radian;

struct Foo
{
static int some_static_method(int a, int b) { return a * 42 + b; }
static int overloaded_static_method(int value) { return value * 42; }
static double overloaded_static_method(double value) { return value * 42; }
};

} // namespace: basics

void bind_basics(py::module& basics) {
Expand Down Expand Up @@ -159,6 +166,14 @@ void bind_basics(py::module& basics) {
.value("radian", Point::AngleUnit::radian)
.value("degree", Point::AngleUnit::degree);

// Static methods
py::class_<Foo> pyFoo(basics, "Foo");

pyFoo
.def_static("some_static_method", &Foo::some_static_method, R"#(None)#", py::arg("a"), py::arg("b"))
.def_static("overloaded_static_method", py::overload_cast<int>(&Foo::overloaded_static_method), py::arg("value"))
.def_static("overloaded_static_method", py::overload_cast<double>(&Foo::overloaded_static_method), py::arg("value"));

// Module-level attributes
basics.attr("PI") = std::acos(-1);
basics.attr("__version__") = "0.0.1";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import basics as basics
Original file line number Diff line number Diff line change
@@ -1,7 +1,34 @@
from typing import ClassVar
from typing import ClassVar, overload

from typing import overload
PI: float
__version__: str

class Foo:
def __init__(self, *args, **kwargs) -> None:
"""Initialize self. See help(type(self)) for accurate signature."""
@overload
@staticmethod
def overloaded_static_method(value: int) -> int:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.

1. overloaded_static_method(value: int) -> int

2. overloaded_static_method(value: float) -> float"""
@overload
@staticmethod
def overloaded_static_method(value: float) -> float:
"""overloaded_static_method(*args, **kwargs)
Overloaded function.

1. overloaded_static_method(value: int) -> int

2. overloaded_static_method(value: float) -> float"""
@staticmethod
def some_static_method(a: int, b: int) -> int:
"""some_static_method(a: int, b: int) -> int

None"""

class Point:
class AngleUnit:
Expand All @@ -13,8 +40,6 @@ class Point:
"""__init__(self: pybind11_mypy_demo.basics.Point.AngleUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __getstate__(self) -> int:
"""__getstate__(self: object) -> int"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
Expand All @@ -23,8 +48,6 @@ class Point:
"""__int__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
def __setstate__(self, state: int) -> None:
"""__setstate__(self: pybind11_mypy_demo.basics.Point.AngleUnit, state: int) -> None"""
@property
def name(self) -> str: ...
@property
Expand All @@ -40,8 +63,6 @@ class Point:
"""__init__(self: pybind11_mypy_demo.basics.Point.LengthUnit, value: int) -> None"""
def __eq__(self, other: object) -> bool:
"""__eq__(self: object, other: object) -> bool"""
def __getstate__(self) -> int:
"""__getstate__(self: object) -> int"""
def __hash__(self) -> int:
"""__hash__(self: object) -> int"""
def __index__(self) -> int:
Expand All @@ -50,8 +71,6 @@ class Point:
"""__int__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int"""
def __ne__(self, other: object) -> bool:
"""__ne__(self: object, other: object) -> bool"""
def __setstate__(self, state: int) -> None:
"""__setstate__(self: pybind11_mypy_demo.basics.Point.LengthUnit, state: int) -> None"""
@property
def name(self) -> str: ...
@property
Expand All @@ -70,43 +89,46 @@ class Point:

1. __init__(self: pybind11_mypy_demo.basics.Point) -> None

2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None"""
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
"""
@overload
def __init__(self, x: float, y: float) -> None:
"""__init__(*args, **kwargs)
Overloaded function.

1. __init__(self: pybind11_mypy_demo.basics.Point) -> None

2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None"""
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
"""
@overload
def distance_to(self, x: float, y: float) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.

1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float

2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float"""
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
"""
@overload
def distance_to(self, other: Point) -> float:
"""distance_to(*args, **kwargs)
Overloaded function.

1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float

2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float"""
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
"""
@property
def length(self) -> float: ...

def answer() -> int:
'''answer() -> int
"""answer() -> int"""

answer docstring, with end quote"'''
def midpoint(left: float, right: float) -> float:
"""midpoint(left: float, right: float) -> float"""

def sum(arg0: int, arg1: int) -> int:
'''sum(arg0: int, arg1: int) -> int
"""sum(arg0: int, arg1: int) -> int"""

multiline docstring test, edge case quotes """\'\'\''''
def weighted_midpoint(left: float, right: float, alpha: float = ...) -> float:
"""weighted_midpoint(left: float, right: float, alpha: float = 0.5) -> float"""
11 changes: 11 additions & 0 deletions test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ from typing import ClassVar, overload
PI: float
__version__: str

class Foo:
def __init__(self, *args, **kwargs) -> None: ...
@overload
@staticmethod
def overloaded_static_method(value: int) -> int: ...
@overload
@staticmethod
def overloaded_static_method(value: float) -> float: ...
@staticmethod
def some_static_method(a: int, b: int) -> int: ...

class Point:
class AngleUnit:
__members__: ClassVar[dict] = ... # read-only
Expand Down