diff --git a/mypy/stubgen.py b/mypy/stubgen.py index aca836c52ce82..ca72494657468 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -657,6 +657,7 @@ def __init__( self.defined_names: set[str] = set() # Short names of methods defined in the body of the current class self.method_names: set[str] = set() + self.processing_dataclass = False def visit_mypy_file(self, o: MypyFile) -> None: self.module = o.fullname # Current module being processed @@ -706,6 +707,12 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: self.clear_decorators() def visit_func_def(self, o: FuncDef) -> None: + is_dataclass_generated = ( + self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated + ) + if is_dataclass_generated and o.name != "__init__": + # Skip methods generated by the @dataclass decorator (except for __init__) + return if ( self.is_private_name(o.name, o.fullname) or self.is_not_in_all(o.name) @@ -771,6 +778,12 @@ def visit_func_def(self, o: FuncDef) -> None: else: arg = name + annotation args.append(arg) + if o.name == "__init__" and is_dataclass_generated and "**" in args: + # The dataclass plugin generates invalid nameless "*" and "**" arguments + new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "") + args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique + args[args.index("**")] = f"**{new_name}__" # same here + retname = None if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType): if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType): @@ -899,6 +912,9 @@ def visit_class_def(self, o: ClassDef) -> None: if not self._indent and self._state != EMPTY: sep = len(self._output) self.add("\n") + decorators = self.get_class_decorators(o) + for d in decorators: + self.add(f"{self._indent}@{d}\n") self.add(f"{self._indent}class {o.name}") self.record_name(o.name) base_types = self.get_base_types(o) @@ -934,6 +950,7 @@ def visit_class_def(self, o: ClassDef) -> None: else: self._state = CLASS self.method_names = set() + self.processing_dataclass = False self._current_class = None def get_base_types(self, cdef: ClassDef) -> list[str]: @@ -979,6 +996,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]: base_types.append(f"{name}={value.accept(p)}") return base_types + def get_class_decorators(self, cdef: ClassDef) -> list[str]: + decorators: list[str] = [] + p = AliasPrinter(self) + for d in cdef.decorators: + if self.is_dataclass(d): + decorators.append(d.accept(p)) + self.import_tracker.require_name(get_qualified_name(d)) + self.processing_dataclass = True + return decorators + + def is_dataclass(self, expr: Expression) -> bool: + if isinstance(expr, CallExpr): + expr = expr.callee + return self.get_fullname(expr) == "dataclasses.dataclass" + def visit_block(self, o: Block) -> None: # Unreachable statements may be partially uninitialized and that may # cause trouble. @@ -1336,6 +1368,9 @@ def get_init( # Final without type argument is invalid in stubs. final_arg = self.get_str_type_of_node(rvalue) typename += f"[{final_arg}]" + elif self.processing_dataclass: + # attribute without annotation is not a dataclass field, don't add annotation. + return f"{self._indent}{lvalue} = ...\n" else: typename = self.get_str_type_of_node(rvalue) initializer = self.get_assign_initializer(rvalue) @@ -1343,12 +1378,20 @@ def get_init( def get_assign_initializer(self, rvalue: Expression) -> str: """Does this rvalue need some special initializer value?""" - if self._current_class and self._current_class.info: - # Current rules - # 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value - if self._current_class.info.is_named_tuple and not isinstance(rvalue, TempNode): - return " = ..." - # TODO: support other possible cases, where initializer is important + if not self._current_class: + return "" + # Current rules + # 1. Return `...` if we are dealing with `NamedTuple` or `dataclass` field and + # it has an existing default value + if ( + self._current_class.info + and self._current_class.info.is_named_tuple + and not isinstance(rvalue, TempNode) + ): + return " = ..." + if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs): + return " = ..." + # TODO: support other possible cases, where initializer is important # By default, no initializer is required: return "" @@ -1410,6 +1453,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool: return False if fullname in EXTRA_EXPORTED: return False + if name == "_": + return False return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS) def is_private_member(self, fullname: str) -> bool: diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 79d380785a39d..7e30515ac8926 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -724,11 +724,22 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None: def parse_flags(self, program_text: str, extra: list[str]) -> Options: flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE) + pyversion = None if flags: flag_list = flags.group(1).split() + for i, flag in enumerate(flag_list): + if flag.startswith("--python-version="): + pyversion = flag.split("=", 1)[1] + del flag_list[i] + break else: flag_list = [] options = parse_options(flag_list + extra) + if pyversion: + # A hack to allow testing old python versions with new language constructs + # This should be rarely used in general as stubgen output should not be version-specific + major, minor = pyversion.split(".", 1) + options.pyversion = (int(major), int(minor)) if "--verbose" not in flag_list: options.quiet = True else: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 774a17b76161d..828680fadcf23 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -3512,3 +3512,185 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ... class X(_Incomplete): ... class Y(_Incomplete): ... + +[case testDataclass] +import dataclasses +import dataclasses as dcs +from dataclasses import dataclass, InitVar, KW_ONLY +from dataclasses import dataclass as dc +from typing import ClassVar + +@dataclasses.dataclass +class X: + a: int + b: str = "hello" + c: ClassVar + d: ClassVar = 200 + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) + _: KW_ONLY + h: int = 1 + i: InitVar[str] + j: InitVar = 100 + non_field = None + +@dcs.dataclass +class Y: ... + +@dataclass +class Z: ... + +@dc +class W: ... + +@dataclass(init=False, repr=False) +class V: ... + +[out] +import dataclasses +import dataclasses as dcs +from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc +from typing import ClassVar + +@dataclasses.dataclass +class X: + a: int + b: str = ... + c: ClassVar + d: ClassVar = ... + f: list[int] = ... + g: int = ... + _: KW_ONLY + h: int = ... + i: InitVar[str] + j: InitVar = ... + non_field = ... + +@dcs.dataclass +class Y: ... +@dataclass +class Z: ... +@dc +class W: ... +@dataclass(init=False, repr=False) +class V: ... + +[case testDataclass_semanal] +from dataclasses import dataclass, InitVar +from typing import ClassVar + +@dataclass +class X: + a: int + b: str = "hello" + c: ClassVar + d: ClassVar = 200 + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) + h: int = 1 + i: InitVar[str] + j: InitVar = 100 + non_field = None + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... + +[out] +from dataclasses import InitVar, dataclass +from typing import ClassVar + +@dataclass +class X: + a: int + b: str = ... + c: ClassVar + d: ClassVar = ... + f: list[int] = ... + g: int = ... + h: int = ... + i: InitVar[str] + j: InitVar = ... + non_field = ... + def __init__(self, a, b, f, g, h, i, j) -> None: ... + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... + +[case testDataclassWithKwOnlyField_semanal] +# flags: --python-version=3.10 +from dataclasses import dataclass, InitVar, KW_ONLY +from typing import ClassVar + +@dataclass +class X: + a: int + b: str = "hello" + c: ClassVar + d: ClassVar = 200 + f: list[int] = field(init=False, default_factory=list) + g: int = field(default=2, kw_only=True) + _: KW_ONLY + h: int = 1 + i: InitVar[str] + j: InitVar = 100 + non_field = None + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... + +[out] +from dataclasses import InitVar, KW_ONLY, dataclass +from typing import ClassVar + +@dataclass +class X: + a: int + b: str = ... + c: ClassVar + d: ClassVar = ... + f: list[int] = ... + g: int = ... + _: KW_ONLY + h: int = ... + i: InitVar[str] + j: InitVar = ... + non_field = ... + def __init__(self, a, b, f, g, *, h, i, j) -> None: ... + +@dataclass(init=False, repr=False, frozen=True) +class Y: ... + +[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal] +from dataclasses import dataclass + +@dataclass +class X: + a: int + def __init__(self, a: int, b: str = ...) -> None: ... + def __post_init__(self) -> None: ... + +[out] +from dataclasses import dataclass + +@dataclass +class X: + a: int + def __init__(self, a: int, b: str = ...) -> None: ... + def __post_init__(self) -> None: ... + +[case testDataclassInheritsFromAny_semanal] +from dataclasses import dataclass +import missing + +@dataclass +class X(missing.Base): + a: int + +[out] +import missing +from dataclasses import dataclass + +@dataclass +class X(missing.Base): + a: int + def __init__(self, *selfa_, a, **selfa__) -> None: ...