diff --git a/mypy/checker.py b/mypy/checker.py index e53e306a7e5d..b0be94af5ad5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -9,7 +9,7 @@ Any, Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable, Sequence, Mapping, Generic, AbstractSet, Callable, overload ) -from typing_extensions import Final, TypeAlias as _TypeAlias +from typing_extensions import Final, TypeAlias as _TypeAlias, TypeGuard from mypy.backports import nullcontext from mypy.errors import Errors, report_internal_error @@ -4698,22 +4698,40 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM # narrow their types. (For example, we shouldn't try narrowing the # types of literal string or enum expressions). + def is_narrowable_literal(expr: Expression) -> bool: + return (literal(expr) == LITERAL_TYPE and + not is_literal_none(expr) and + not is_literal_enum(type_map, expr)) + + def is_len_of_narrowable_literal(expr: Expression) -> TypeGuard[CallExpr]: + return ( + isinstance(expr, CallExpr) and + refers_to_fullname(expr.callee, 'builtins.len') and + len(expr.args) == 1 and + is_narrowable_literal(collapse_walrus(expr.args[0])) + ) + operands = [collapse_walrus(x) for x in node.operands] operand_types = [] narrowable_operand_index_to_hash = {} + narrowable_len_operand_index_to_hash = {} for i, expr in enumerate(operands): if expr not in type_map: return {}, {} expr_type = type_map[expr] operand_types.append(expr_type) - if (literal(expr) == LITERAL_TYPE - and not is_literal_none(expr) - and not is_literal_enum(type_map, expr)): + if is_narrowable_literal(expr): h = literal_hash(expr) if h is not None: narrowable_operand_index_to_hash[i] = h + if is_len_of_narrowable_literal(expr): + len_arg = collapse_walrus(expr.args[0]) + h = literal_hash(len_arg) + if h is not None: + narrowable_len_operand_index_to_hash[i] = h + # Step 2: Group operands chained by either the 'is' or '==' operands # together. 
For all other operands, we keep them in groups of size 2. # So the expression: @@ -4725,118 +4743,141 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]), # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])] # - # We group identity/equality expressions so we can propagate information - # we discover about one operand across the entire chain. We don't bother + # We group identity/equality expressions for type and len checks so we can propagate + # information we discover about one operand across the entire chain. We don't bother # handling 'is not' and '!=' chains in a special way: those are very rare # in practice. simplified_operator_list = group_comparison_operands( node.pairwise(), - narrowable_operand_index_to_hash, + {**narrowable_operand_index_to_hash, **narrowable_len_operand_index_to_hash}, {'==', 'is'}, ) # Step 3: Analyze each group and infer more precise type maps for each - # assignable operand, if possible. We combine these type maps together - # in the final step. + # assignable operand, if possible. partial_type_maps = [] for operator, expr_indices in simplified_operator_list: - if operator in {'is', 'is not', '==', '!='}: - # is_valid_target: - # Controls which types we're allowed to narrow exprs to. Note that - # we cannot use 'is_literal_type_like' in both cases since doing - # 'x = 10000 + 1; x is 10001' is not always True in all Python - # implementations. - # - # coerce_only_in_literal_context: - # If true, coerce types into literal types only if one or more of - # the provided exprs contains an explicit Literal type. This could - # technically be set to any arbitrary value, but it seems being liberal - # with narrowing when using 'is' and conservative when using '==' seems - # to break the least amount of real-world code. 
- # - # should_narrow_by_identity: - # Set to 'false' only if the user defines custom __eq__ or __ne__ methods - # that could cause identity-based narrowing to produce invalid results. - if operator in {'is', 'is not'}: - is_valid_target: Callable[[Type], bool] = is_singleton_type - coerce_only_in_literal_context = False - should_narrow_by_identity = True - else: - def is_exactly_literal_type(t: Type) -> bool: - return isinstance(get_proper_type(t), LiteralType) + is_len_check_expression = any( + ind in narrowable_len_operand_index_to_hash + for ind in expr_indices) + + if not is_len_check_expression: + # We make sure there are no len comparisons before starting + # isinstance narrowing + if operator in {'is', 'is not', '==', '!='}: + # is_valid_target: + # Controls which types we're allowed to narrow exprs to. Note that + # we cannot use 'is_literal_type_like' in both cases since doing + # 'x = 10000 + 1; x is 10001' is not always True in all Python + # implementations. + # + # coerce_only_in_literal_context: + # If true, coerce types into literal types only if one or more of + # the provided exprs contains an explicit Literal type. This could + # technically be set to any arbitrary value, but it seems being liberal + # with narrowing when using 'is' and conservative when using '==' seems + # to break the least amount of real-world code. + # + # should_narrow_by_identity: + # Set to 'false' only if the user defines custom __eq__ or __ne__ methods + # that could cause identity-based narrowing to produce invalid results. 
+ if operator in {'is', 'is not'}: + is_valid_target: Callable[[Type], bool] = is_singleton_type + coerce_only_in_literal_context = False + should_narrow_by_identity = True + else: + def is_exactly_literal_type(t: Type) -> bool: + return isinstance(get_proper_type(t), LiteralType) - def has_no_custom_eq_checks(t: Type) -> bool: - return (not custom_special_method(t, '__eq__', check_all=False) + def has_no_custom_eq_checks(t: Type) -> bool: + return ( + not custom_special_method(t, '__eq__', check_all=False) and not custom_special_method(t, '__ne__', check_all=False)) - is_valid_target = is_exactly_literal_type - coerce_only_in_literal_context = True - - expr_types = [operand_types[i] for i in expr_indices] - should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types)) + is_valid_target = is_exactly_literal_type + coerce_only_in_literal_context = True - if_map: TypeMap = {} - else_map: TypeMap = {} - if should_narrow_by_identity: - if_map, else_map = self.refine_identity_comparison_expression( - operands, - operand_types, - expr_indices, - narrowable_operand_index_to_hash.keys(), - is_valid_target, - coerce_only_in_literal_context, - ) + expr_types = [operand_types[i] for i in expr_indices] + should_narrow_by_identity = all( + map(has_no_custom_eq_checks, expr_types)) - # Strictly speaking, we should also skip this check if the objects in the expr - # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically) - # assume nobody would actually create a custom objects that considers itself - # equal to None. 
- if if_map == {} and else_map == {}: - if_map, else_map = self.refine_away_none_in_comparison( + if_map: TypeMap = {} + else_map: TypeMap = {} + if should_narrow_by_identity: + if_map, else_map = self.refine_identity_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + is_valid_target, + coerce_only_in_literal_context, + ) + + # Strictly speaking, we should also skip this check if the objects in + # the expr chain have custom __eq__ or __ne__ methods. + # But we (maybe optimistically) assume nobody would actually + # create a custom object that considers itself equal to None. + if if_map == {} and else_map == {}: + if_map, else_map = self.refine_away_none_in_comparison( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) + + # If we haven't been able to narrow types yet, we might be dealing with an + # explicit type(x) == some_type check + if if_map == {} and else_map == {}: + if_map, else_map = self.find_type_equals_check(node, expr_indices) + elif operator in {'in', 'not in'}: + assert len(expr_indices) == 2 + left_index, right_index = expr_indices + if left_index not in narrowable_operand_index_to_hash: + continue + + item_type = operand_types[left_index] + collection_type = operand_types[right_index] + + # We only try and narrow away 'None' for now + if not is_optional(item_type): + continue + + collection_item_type = get_proper_type(builtin_item_type(collection_type)) + if collection_item_type is None or is_optional(collection_item_type): + continue + if (isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == 'builtins.object'): + continue + if is_overlapping_erased_types(item_type, collection_item_type): + if_map, else_map = { + operands[left_index]: remove_optional(item_type)}, {} + else: + continue + else: + if_map = {} + else_map = {} + else: + # comparison expression with len + if operator in {'==', '!=', '>=', '<=', '<', 
'>'}: + if_map, else_map = self.refine_len_comparison_expression( + operator, operands, operand_types, expr_indices, - narrowable_operand_index_to_hash.keys(), + narrowable_len_operand_index_to_hash.keys(), ) - - # If we haven't been able to narrow types yet, we might be dealing with a - # explicit type(x) == some_type check - if if_map == {} and else_map == {}: - if_map, else_map = self.find_type_equals_check(node, expr_indices) - elif operator in {'in', 'not in'}: - assert len(expr_indices) == 2 - left_index, right_index = expr_indices - if left_index not in narrowable_operand_index_to_hash: - continue - - item_type = operand_types[left_index] - collection_type = operand_types[right_index] - - # We only try and narrow away 'None' for now - if not is_optional(item_type): - continue - - collection_item_type = get_proper_type(builtin_item_type(collection_type)) - if collection_item_type is None or is_optional(collection_item_type): - continue - if (isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == 'builtins.object'): - continue - if is_overlapping_erased_types(item_type, collection_item_type): - if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {} else: - continue - else: - if_map = {} - else_map = {} + if_map = {} + else_map = {} - if operator in {'is not', '!=', 'not in'}: + if operator in {'is not', '!=', 'not in', '<', '>'}: if_map, else_map = else_map, if_map partial_type_maps.append((if_map, else_map)) + # We combine these type maps together in the final step. 
return reduce_conditional_maps(partial_type_maps) elif isinstance(node, AssignmentExpr): if_map = {} @@ -5193,6 +5234,160 @@ def refine_identity_comparison_expression(self, return reduce_conditional_maps(partial_type_maps) + def refine_len_comparison_expression(self, + operator: str, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: + """Produce conditional type maps refining expressions by len operator comparison. + + The 'operands' and 'operand_types' lists should be the full list of operands used + in the overall comparison expression. The 'chain_indices' list is the list of indices + actually used within this comparison chain. + + So if we have the expression: + + a <= len(b) == len(c) == 3 <= e + + ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices' + would be the list [1, 2, 3]. + + The 'narrowable_operand_indices' parameter is the set of all indices we are allowed + to refine the types of: that is, the 'len(x)' operands whose argument 'x' may become + a part of the output TypeMaps. + + The length to compare against is taken from the operands in the chain whose type is, + or can be coerced to, an int Literal; if there is none we infer nothing. Two different + literal lengths can never be equal, so for '==' and '!=' the first map is then None + (unreachable), while for the ordering operators we infer nothing. + + The maps are computed as if 'operator' were not negated: for '!=', '<' and '>' they + describe the opposite condition, and the caller is expected to swap the two maps + before using them. 
+ """ + + target: Optional[int] = None + target_index: Optional[int] = None + for i in chain_indices: + expr_type = operand_types[i] + expr_type = coerce_to_literal(expr_type) + proper_type = get_proper_type(expr_type) + # Only an int literal can be a length; other literals (e.g. str) are not targets. + if not (isinstance(proper_type, LiteralType) and isinstance(proper_type.value, int)): + continue + # Compare with None explicitly: a previously seen length of 0 is falsy. + if target is not None and target != proper_type.value: + if operator in {'==', '!='}: + # We have multiple different target values. So the 'if' branch + # must be unreachable. + return None, {} + else: + # Other operators can go either way + return {}, {} + + target = proper_type.value + target_index = i + + # There's nothing we can currently infer if none of the operands are valid targets, + # so we end early and infer nothing. + if target is None: + return {}, {} + + partial_type_maps = [] + for i in chain_indices: + # Naturally, we can't refine operands which are not permitted to be refined. + if i not in narrowable_operand_indices: + continue + + # we already checked that operand[i] is CallExpr since it is narrowable + expr = operands[i].args[0] # type: ignore[attr-defined] + expr_type = self.type_map[expr] + + # Narrow the argument of 'len(...)' rather than the call expression itself: + # 'conditional_len_map' keeps the items of its (possibly union) type whose + # statically known length is compatible with 'target'. 
+ partial_type_maps.append( + self.conditional_len_map(operator, expr, expr_type, i, target, target_index)) + + return reduce_conditional_maps(partial_type_maps) + + def narrow_type_by_length(self, operator: str, typ: Type, length: int) -> Type: + """Expand a 'builtins.tuple' instance to a 'length'-item TupleType for '=='/'!='.""" + if operator in {"==", "!="} and length >= 0: + candidate = get_proper_type(typ) + if isinstance(candidate, Instance) and candidate.type.fullname == "builtins.tuple": + item_types = [candidate.args[0]] * length + return TupleType(item_types, self.named_type('builtins.tuple')) + return typ + + def conditional_len_map(self, + operator: str, + expr: Expression, + current_type: Optional[Type], + expr_index: int, + length: Optional[int], + target_index: Optional[int], + ) -> Tuple[TypeMap, TypeMap]: + """Takes in an expression, the current type of the expression, and a + proposed length of that expression. + + Returns a 2-tuple: The first element is a map from the expression to + the proposed type, if the expression can be the proposed length. The + second element is a map from the expression to the type it would hold + if it was not the proposed length, if any. None means bot, {} means top""" + if length is not None and current_type is not None and target_index is not None: + proper_type = get_proper_type(current_type) + if isinstance(proper_type, AnyType): + # We don't really know much about the proposed type, so we shouldn't + # attempt to narrow anything. 
Instead, we broaden the expr to Any to + # avoid false positives + return {expr: current_type}, {} + else: + possible_types = union_items(current_type) + len_of_types = [len_of_type(typ) for typ in possible_types] + + if operator in {'>=', '<=', '<', '>'} and target_index < expr_index: + if operator == '>=': + operator = '<=' + elif operator == '>': + operator = '<' + elif operator == '<=': + operator = '>=' + else: + operator = '>' + + # '!=', '<' and '>' use the negated test; the caller swaps the maps for them + length_op_translator = { + '==': int.__eq__, + '!=': int.__eq__, + '>=': int.__ge__, + '<': int.__ge__, + '<=': int.__le__, + '>': int.__le__, + } + + assert operator in length_op_translator + length_op = length_op_translator[operator] + + proposed_type = make_simplified_union([ + self.narrow_type_by_length(operator, typ, length) + for typ, size in zip(possible_types, len_of_types) + if size is None or length_op(size, length)]) + remaining_type = make_simplified_union([ + typ for typ, size in zip(possible_types, len_of_types) + if size is None or not length_op(size, length)]) + if_map: TypeMap = ( + {} if is_same_type(proposed_type, current_type) + else {expr: proposed_type}) + else_map: TypeMap = ( + {} if is_same_type(remaining_type, current_type) + else {expr: remaining_type}) + return if_map, else_map + else: + return {}, {} + def refine_away_none_in_comparison(self, operands: List[Expression], operand_types: List[Type], @@ -5774,6 +5969,20 @@ def conditional_types_to_typemaps(expr: Expression, return cast(Tuple[TypeMap, TypeMap], tuple(maps)) +def len_of_type(typ: Type) -> Optional[int]: + """Takes a type and returns an int that represents the length + of instances of that type or None if not applicable or variant length""" + proper_type = get_proper_type(typ) + if isinstance(proper_type, TupleType): + return len(proper_type.items) + # A LiteralType also stores a 'str' value for bytes literals (in escaped form) and + # for enum members (the member name), so only trust it for genuine 'builtins.str'. + if (isinstance(proper_type, LiteralType) and isinstance(proper_type.value, str) + and proper_type.fallback.type.fullname == 'builtins.str'): + return len(proper_type.value) + return None + + def 
gen_unique_name(base: str, table: SymbolTable) -> str: """Generate a name that does not appear in table by appending numbers to base.""" if base not in table: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 23715b24d43e..a6f107ef2824 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1245,3 +1245,164 @@ def two_type_vars(x: Union[str, Dict[str, int], Dict[bool, object], int]) -> Non else: reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/dict.pyi] +[case testNarrowingLenItemAndLenCompare] +from typing import Tuple, Union, Any + +x: Any +if len(x) == x: + reveal_type(x) # N: Revealed type is "Any" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTuple] +from typing import Tuple, Union + +VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +a = b = c = 0 +if len(x) == 3: + a, b, c = x +else: + a, b = x + +if len(x) != 3: + a, b = x +else: + a, b, c = x +[builtins fixtures/len.pyi] + +[case testNarrowingLenVariantLengthTuple] +from typing import Tuple, Union + +def make_tuple() -> Tuple[int, ...]: + return (1, 1) + +x = make_tuple() + +if len(x) == 3: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +else: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" + +if len(x) != 3: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTypeUnaffected] +from typing import Tuple, Union, List, Any + +def make() -> Union[str, List[int]]: + return "" + +x = make() + +if len(x) == 3: + reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]" 
+[builtins fixtures/len.pyi] + +[case testNarrowingLenAnyListElseNotAffected] +from typing import Any +def f(self, value: Any) -> Any: + if isinstance(value, list) and len(value) == 0: + reveal_type(value) # N: Revealed type is "builtins.list[Any]" + return value + reveal_type(value) # N: Revealed type is "Any" + return None +[builtins fixtures/len.pyi] + +[case testNarrowingLenLiteral] +from typing import Tuple, Union +from typing_extensions import Literal + +def make() -> Literal['a', 'bb', 'cc', 'd']: + return "a" + +x = make() + +if len(x) == 2: + reveal_type(x) # N: Revealed type is "Union[Literal['bb'], Literal['cc']]" +else: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['d']]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenMultiple] +from typing import Tuple, Union + +VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +y = make_tuple() +if len(x) == len(y) == 3: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" + reveal_type(y) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenFinal] +from typing import Tuple, Union +from typing_extensions import Final + +VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +fin: Final = 3 +if len(x) == fin: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenBiggerThan] +from typing import Tuple, Union + +VarTuple = Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +if len(x) > 1: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int]" 
+ +if len(x) < 2: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" + +if len(x) >= 2: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int]" + +if len(x) <= 2: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenBiggerThanVariantTuple] +from typing import Tuple + +VarTuple = Tuple[int, ...] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +if len(x) < 3: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" +else: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" +[builtins fixtures/len.pyi] diff --git a/test-data/unit/fixtures/len.pyi b/test-data/unit/fixtures/len.pyi new file mode 100644 index 000000000000..8a8bdf864061 --- /dev/null +++ b/test-data/unit/fixtures/len.pyi @@ -0,0 +1,38 @@ +from typing import Tuple, TypeVar, Generic, Union, Any, Type, Sequence +from typing_extensions import Protocol + +T = TypeVar('T') + +class object: + def __init__(self) -> None: pass + +class type: + def __init__(self, x) -> None: pass + +class tuple(Generic[T]): + def __len__(self) -> int: pass + +class list(Sequence[T]): pass + +class function: pass + +class Sized(Protocol): + def __len__(self) -> int: pass + +def len(__obj: Sized) -> int: ... 
+def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass + +class int: + def __add__(self, other: 'int') -> 'int': pass + def __eq__(self, other: 'int') -> 'bool': pass + def __ne__(self, other: 'int') -> 'bool': pass + def __lt__(self, n: 'int') -> 'bool': pass + def __gt__(self, n: 'int') -> 'bool': pass + def __le__(self, n: 'int') -> 'bool': pass + def __ge__(self, n: 'int') -> 'bool': pass +class float: pass +class bool(int): pass +class str: + def __add__(self, other: 'str') -> 'str': pass + def __len__(self) -> int: pass +class ellipsis: pass diff --git a/test-data/unit/lib-stub/typing.pyi b/test-data/unit/lib-stub/typing.pyi index 57563fc9d2f6..873f5d6372a3 100644 --- a/test-data/unit/lib-stub/typing.pyi +++ b/test-data/unit/lib-stub/typing.pyi @@ -43,6 +43,7 @@ class Generator(Iterator[T], Generic[T, U, V]): class Sequence(Iterable[T_co]): def __getitem__(self, n: Any) -> T_co: pass + def __len__(self) -> int: pass # Mapping type is oversimplified intentionally. class Mapping(Iterable[T], Generic[T, T_co]): pass