Skip to content

Commit

Permalink
Allow class variables in subsubclasses to only match base class type
Browse files Browse the repository at this point in the history
Related python#10375 and python#10506.

For an unhinted assignment to a class variable defined in a base class,
this allows subderived classes to only match the type in the base class,
rather than the one inferred from the assignment value.
  • Loading branch information
wrwrwr committed Jan 21, 2023
1 parent cc1bcc9 commit c69803a
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 34 deletions.
71 changes: 40 additions & 31 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2952,39 +2952,49 @@ def check_compatibility_all_supers(
# Show only one error per variable
break

direct_bases = lvalue_node.info.direct_base_classes()
last_immediate_base = direct_bases[-1] if direct_bases else None
if is_private(lvalue_node.name):
return False

for base in lvalue_node.info.mro[1:]:
# The type of "__slots__" and some other attributes usually doesn't need to
# be compatible with a base class. We'll still check the type of "__slots__"
# against "object" as an exception.
if lvalue_node.allow_incompatible_override and not (
lvalue_node.name == "__slots__" and base.fullname == "builtins.object"
for base, base_type, base_node in self.classvar_base_types(lvalue_node):
if not self.check_compatibility_super(
lvalue, lvalue_type, rvalue, base, base_type, base_node
):
continue

if is_private(lvalue_node.name):
continue

base_type, base_node = self.lvalue_type_from_base(lvalue_node, base)
if isinstance(base_type, PartialType):
base_type = None

if base_type:
assert base_node is not None
if not self.check_compatibility_super(
lvalue, lvalue_type, rvalue, base, base_type, base_node
):
# Only show one error per variable; even if other
# base classes are also incompatible
return True
if base is last_immediate_base:
# At this point, the attribute was found to be compatible with all
# immediate parents.
break
# Only show one error per variable; even if other
# base classes are also incompatible
return True
return False

def classvar_base_types(self, node: Var) -> list[tuple[TypeInfo, Type, Node]]:
"""Determine base classes that a class variable should be checked against."""
base_types = []
direct_bases = node.info.direct_base_classes()
last_immediate_base = direct_bases[-1] if direct_bases else None
for base in node.info.mro[1:]:
# The type of "__slots__" and some other attributes usually doesn't need to
# be compatible with a base class. We'll still check the type of "__slots__"
# against "object" as an exception.
if node.allow_incompatible_override and not (
node.name == "__slots__" and base.fullname == "builtins.object"
):
continue
base_type, base_node = self.lvalue_type_from_base(node, base)
if not base_type or isinstance(base_type, PartialType):
continue
assert base_node is not None
if isinstance(base_node, Var) and base_node.is_inferred:
# Skip the type inferred from the value if there is a superclass
# with an annotation or a (possibly more general) inferred type.
base_node_base_types = self.classvar_base_types(base_node)
if base_node_base_types:
base_types.extend(base_node_base_types)
else:
base_types.append((base, base_type, base_node))
else:
base_types.append((base, base_type, base_node))
if base is last_immediate_base:
break
return base_types

def check_compatibility_super(
self,
lvalue: RefExpr,
Expand Down Expand Up @@ -3053,8 +3063,7 @@ def check_compatibility_super(
def lvalue_type_from_base(
self, expr_node: Var, base: TypeInfo
) -> tuple[Type | None, Node | None]:
"""For a NameExpr that is part of a class, walk all base classes and try
to find the first class that defines a Type for the same name."""
"""Get a variable type with respect to a base class."""
expr_name = expr_node.name
base_var = base.names.get(expr_name)

Expand Down
4 changes: 1 addition & 3 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -4107,7 +4107,7 @@ from typing import Union
class A:
a = None # type: Union[int, str]
class B(A):
a = 1
a = 1 # type: int
class C(B):
a = "str"
class D(A):
Expand Down Expand Up @@ -4344,8 +4344,6 @@ class B(A):
x = 1
class C(B):
x = ''
[out]
main:6: error: Incompatible types in assignment (expression has type "str", base class "B" defined the type as "int")

[case testSlots]
class A:
Expand Down
41 changes: 41 additions & 0 deletions test-data/unit/check-classvar.test
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,44 @@ class C:
c:C
c.foo() # E: Too few arguments \
# N: "foo" is considered instance variable, to make it class variable use ClassVar[...]

[case testClassVarVariableLengthTuple]
from typing import ClassVar, Tuple
class A:
x: ClassVar[Tuple[int, ...]]
class B(A):
x = (1,)
class C(B):
x = (2, 3)
class D(B):
x = ("a",) # E: Incompatible types in assignment (expression has type "Tuple[str]", base class "A" defined the type as "Tuple[int, ...]")
[builtins fixtures/tuple.pyi]

[case testClassVarVariableLengthTupleTwoBases]
from typing import ClassVar, Tuple, Union
class A:
x: ClassVar[Tuple[Union[int, str], ...]]
class B:
x: ClassVar[Tuple[int, ...]]
class C(A, B):
x = (1,)
class D(C):
x = (2, 3)
class E(A, B):
x = ("a",) # E: Incompatible types in assignment (expression has type "Tuple[str]", base class "B" defined the type as "Tuple[int, ...]")
[builtins fixtures/tuple.pyi]

[case testClassVarVariableLengthTupleGeneric]
from typing import ClassVar, Generic, Tuple, TypeVar
T = TypeVar('T')
class A(Generic[T]):
x: ClassVar[Tuple[T, ...]] # E: ClassVar cannot contain type variables
class B(A[int]):
pass
class C(B):
x = (1,)
class D(C):
x = (2, 3)
class E(B):
x = ("a",) # E: Incompatible types in assignment (expression has type "Tuple[str]", base class "A" defined the type as "Tuple[int, ...]")
[builtins fixtures/tuple.pyi]
24 changes: 24 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3302,6 +3302,30 @@ class C(P):
x = ['a'] # E: List item 0 has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]

[case testUseSupertypeForUntypedList]
from typing import List
class A:
x: List[int] = []
class B(A):
x = []
class C(B):
x: List[str] = [] # E: Incompatible types in assignment (expression has type "List[str]", base class "A" defined the type as "List[int]")
class D(B):
x = ["a"] # E: List item 0 has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]

[case testUseSupertypeForUntypedTuple]
from typing import Tuple
class A:
x = (1,)
class B(A):
x = (2,)
class C(A):
x = (2, 3) # E: Incompatible types in assignment (expression has type "Tuple[int, int]", base class "A" defined the type as "Tuple[int]")
class D(B):
x = (2, 3) # E: Incompatible types in assignment (expression has type "Tuple[int, int]", base class "A" defined the type as "Tuple[int]")
[builtins fixtures/tuple.pyi]

[case testUseSupertypeAsInferenceContextPartial]
from typing import List

Expand Down

0 comments on commit c69803a

Please sign in to comment.