diff --git a/docs/source/config_file.rst b/docs/source/config_file.rst index 5cfc5f86e37f1..22893ff069d5f 100644 --- a/docs/source/config_file.rst +++ b/docs/source/config_file.rst @@ -676,6 +676,13 @@ section of the command line docs. from foo import bar __all__ = ['bar'] +.. confval:: strict_concatenate + + :type: boolean + :default: False + + Make arguments prepended via ``Concatenate`` be truly positional-only. + .. confval:: strict_equality :type: boolean diff --git a/mypy/applytype.py b/mypy/applytype.py index 5b803a4aaa0b4..a967d834f1a2e 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -5,7 +5,7 @@ from mypy.expandtype import expand_type from mypy.types import ( Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types, - TypeVarLikeType, ProperType, ParamSpecType, get_proper_type + TypeVarLikeType, ProperType, ParamSpecType, Parameters, get_proper_type ) from mypy.nodes import Context @@ -94,7 +94,7 @@ def apply_generic_arguments( nt = id_to_type.get(param_spec.id) if nt is not None: nt = get_proper_type(nt) - if isinstance(nt, CallableType): + if isinstance(nt, CallableType) or isinstance(nt, Parameters): callable = callable.expand_param_spec(nt) # Apply arguments to argument types. diff --git a/mypy/checker.py b/mypy/checker.py index a02a877a808fe..d0f6922fdf5e2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5224,7 +5224,7 @@ def check_subtype(self, code: Optional[ErrorCode] = None, outer_context: Optional[Context] = None) -> bool: """Generate an error if the subtype is not compatible with supertype.""" - if is_subtype(subtype, supertype): + if is_subtype(subtype, supertype, options=self.options): return True if isinstance(msg, ErrorMessage): @@ -5260,6 +5260,7 @@ def check_subtype(self, self.msg.note(note, context, code=code) if note_msg: self.note(note_msg, context, code=code) + self.msg.maybe_note_concatenate_pos_args(subtype, supertype, context, code=code) if (isinstance(supertype, Instance) and supertype.type.is_protocol and isinstance(subtype, (Instance, TupleType, TypedDictType))): self.msg.report_protocol_problems(subtype, supertype, context, code=code) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2c0bf9656d060..45d5818d4eeba 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1556,7 +1556,7 @@ def check_arg(self, isinstance(callee_type.item, Instance) and (callee_type.item.type.is_abstract or callee_type.item.type.is_protocol)): self.msg.concrete_only_call(callee_type, context) - elif not is_subtype(caller_type, callee_type): + elif not is_subtype(caller_type, callee_type, options=self.chk.options): if self.chk.should_suppress_optional_error([caller_type, callee_type]): return code = messages.incompatible_argument(n, diff --git a/mypy/constraints.py b/mypy/constraints.py index 8a05b527b6bd2..0b2217b21ae0c 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -8,7 +8,7 @@ TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any, - UnpackType, callable_with_ellipsis, TUPLE_LIKE_INSTANCE_NAMES, + UnpackType, callable_with_ellipsis, Parameters, TUPLE_LIKE_INSTANCE_NAMES, ) from mypy.maptype import map_instance_to_supertype import mypy.subtypes @@ -406,6 +406,9 @@ def visit_param_spec(self, template: ParamSpecType) -> List[Constraint]: def visit_unpack_type(self, template: UnpackType) -> List[Constraint]: raise NotImplementedError + def visit_parameters(self, template: Parameters) -> List[Constraint]: + raise RuntimeError("Parameters cannot be constrained to") + # Non-leaf types def visit_instance(self, template: Instance) -> List[Constraint]: @@ -446,7 +449,7 @@ def visit_instance(self, template: Instance) -> List[Constraint]: # N.B: We use zip instead of indexing because the lengths might have # mismatches during daemon reprocessing. for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args): - # TODO: ParamSpecType + # TODO(PEP612): More ParamSpec work (or is Parameters the only thing accepted) if isinstance(tvar, TypeVarType): # The constraints for generic type parameters depend on variance. # Include constraints from both directions if invariant. @@ -456,6 +459,27 @@ def visit_instance(self, template: Instance) -> List[Constraint]: if tvar.variance != COVARIANT: res.extend(infer_constraints( mapped_arg, instance_arg, neg_op(self.direction))) + elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType): + suffix = get_proper_type(instance_arg) + + if isinstance(suffix, CallableType): + prefix = mapped_arg.prefix + from_concat = bool(prefix.arg_types) or suffix.from_concatenate + suffix = suffix.copy_modified(from_concatenate=from_concat) + + if isinstance(suffix, Parameters) or isinstance(suffix, CallableType): + # no such thing as variance for ParamSpecs + # TODO: is there a case I am missing? + # TODO: constraints between prefixes + prefix = mapped_arg.prefix + suffix = suffix.copy_modified( + suffix.arg_types[len(prefix.arg_types):], + suffix.arg_kinds[len(prefix.arg_kinds):], + suffix.arg_names[len(prefix.arg_names):]) + res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix)) + elif isinstance(suffix, ParamSpecType): + res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix)) + return res elif (self.direction == SUPERTYPE_OF and instance.type.has_base(template.type.fullname)): @@ -464,7 +488,6 @@ def visit_instance(self, template: Instance) -> List[Constraint]: # N.B: We use zip instead of indexing because the lengths might have # mismatches during daemon reprocessing. for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args): - # TODO: ParamSpecType if isinstance(tvar, TypeVarType): # The constraints for generic type parameters depend on variance. # Include constraints from both directions if invariant. @@ -474,6 +497,28 @@ def visit_instance(self, template: Instance) -> List[Constraint]: if tvar.variance != COVARIANT: res.extend(infer_constraints( template_arg, mapped_arg, neg_op(self.direction))) + elif (isinstance(tvar, ParamSpecType) and + isinstance(template_arg, ParamSpecType)): + suffix = get_proper_type(mapped_arg) + + if isinstance(suffix, CallableType): + prefix = template_arg.prefix + from_concat = bool(prefix.arg_types) or suffix.from_concatenate + suffix = suffix.copy_modified(from_concatenate=from_concat) + + if isinstance(suffix, Parameters) or isinstance(suffix, CallableType): + # no such thing as variance for ParamSpecs + # TODO: is there a case I am missing? + # TODO: constraints between prefixes + prefix = template_arg.prefix + + suffix = suffix.copy_modified( + suffix.arg_types[len(prefix.arg_types):], + suffix.arg_kinds[len(prefix.arg_kinds):], + suffix.arg_names[len(prefix.arg_names):]) + res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix)) + elif isinstance(suffix, ParamSpecType): + res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix)) return res if (template.type.is_protocol and self.direction == SUPERTYPE_OF and # We avoid infinite recursion for structural subtypes by checking @@ -564,11 +609,34 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: # Negate direction due to function argument type contravariance. res.extend(infer_constraints(t, a, neg_op(self.direction))) else: + # sometimes, it appears we try to get constraints between two paramspec callables? # TODO: Direction - # TODO: Deal with arguments that come before param spec ones? - res.append(Constraint(param_spec.id, - SUBTYPE_OF, - cactual.copy_modified(ret_type=NoneType()))) + # TODO: check the prefixes match + prefix = param_spec.prefix + prefix_len = len(prefix.arg_types) + cactual_ps = cactual.param_spec() + + if not cactual_ps: + res.append(Constraint(param_spec.id, + SUBTYPE_OF, + cactual.copy_modified( + arg_types=cactual.arg_types[prefix_len:], + arg_kinds=cactual.arg_kinds[prefix_len:], + arg_names=cactual.arg_names[prefix_len:], + ret_type=NoneType()))) + else: + res.append(Constraint(param_spec.id, SUBTYPE_OF, cactual_ps)) + + # compare prefixes + cactual_prefix = cactual.copy_modified( + arg_types=cactual.arg_types[:prefix_len], + arg_kinds=cactual.arg_kinds[:prefix_len], + arg_names=cactual.arg_names[:prefix_len]) + + # TODO: see above "FIX" comments for param_spec is None case + # TODO: this assume positional arguments + for t, a in zip(prefix.arg_types, cactual_prefix.arg_types): + res.extend(infer_constraints(t, a, neg_op(self.direction))) template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type if template.type_guard is not None: diff --git a/mypy/erasetype.py b/mypy/erasetype.py index e1a0becc8447b..ff0ef6c0784ee 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -4,7 +4,7 @@ Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, - get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, UnpackType + get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, Parameters, UnpackType ) from mypy.nodes import ARG_STAR, ARG_STAR2 @@ -59,6 +59,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType: def visit_param_spec(self, t: ParamSpecType) -> ProperType: return AnyType(TypeOfAny.special_form) + def visit_parameters(self, t: Parameters) -> ProperType: + raise RuntimeError("Parameters should have been bound to a class") + def visit_unpack_type(self, t: UnpackType) -> ProperType: raise NotImplementedError diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 36cc83f439dd6..39606c263f6b7 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -5,7 +5,8 @@ NoneType, Overloaded, TupleType, TypedDictType, UnionType, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType, - TypeAliasType, ParamSpecType, TypeVarLikeType, UnpackType + TypeAliasType, ParamSpecType, TypeVarLikeType, Parameters, ParamSpecFlavor, + UnpackType ) @@ -101,15 +102,41 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: repl = get_proper_type(self.variables.get(t.id, t)) if isinstance(repl, Instance): inst = repl + # Return copy of instance with type erasure flag on. + # TODO: what does prefix mean in this case? + # TODO: why does this case even happen? Instances aren't plural. return Instance(inst.type, inst.args, line=inst.line, column=inst.column) elif isinstance(repl, ParamSpecType): - return repl.with_flavor(t.flavor) + return repl.copy_modified(flavor=t.flavor, prefix=t.prefix.copy_modified( + arg_types=t.prefix.arg_types + repl.prefix.arg_types, + arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds, + arg_names=t.prefix.arg_names + repl.prefix.arg_names, + )) + elif isinstance(repl, Parameters) or isinstance(repl, CallableType): + # if the paramspec is *P.args or **P.kwargs: + if t.flavor != ParamSpecFlavor.BARE: + assert isinstance(repl, CallableType), "Should not be able to get here." + # Is this always the right thing to do? + param_spec = repl.param_spec() + if param_spec: + return param_spec.with_flavor(t.flavor) + else: + return repl + else: + return Parameters(t.prefix.arg_types + repl.arg_types, + t.prefix.arg_kinds + repl.arg_kinds, + t.prefix.arg_names + repl.arg_names, + variables=[*t.prefix.variables, *repl.variables]) else: + # TODO: should this branch be removed? better not to fail silently return repl def visit_unpack_type(self, t: UnpackType) -> Type: raise NotImplementedError + def visit_parameters(self, t: Parameters) -> Type: + return t.copy_modified(arg_types=self.expand_types(t.arg_types)) + def visit_callable_type(self, t: CallableType) -> Type: param_spec = t.param_spec() if param_spec is not None: @@ -121,13 +148,18 @@ def visit_callable_type(self, t: CallableType) -> Type: # must expand both of them with all the argument types, # kinds and names in the replacement. The return type in # the replacement is ignored. - if isinstance(repl, CallableType): + if isinstance(repl, CallableType) or isinstance(repl, Parameters): # Substitute *args: P.args, **kwargs: P.kwargs - t = t.expand_param_spec(repl) - # TODO: Substitute remaining arg types - return t.copy_modified(ret_type=t.ret_type.accept(self), - type_guard=(t.type_guard.accept(self) - if t.type_guard is not None else None)) + prefix = param_spec.prefix + # we need to expand the types in the prefix, so might as well + # not get them in the first place + t = t.expand_param_spec(repl, no_prefix=True) + return t.copy_modified( + arg_types=self.expand_types(prefix.arg_types) + t.arg_types, + arg_kinds=prefix.arg_kinds + t.arg_kinds, + arg_names=prefix.arg_names + t.arg_names, + ret_type=t.ret_type.accept(self), + type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None)) return t.copy_modified(arg_types=self.expand_types(t.arg_types), ret_type=t.ret_type.accept(self), diff --git a/mypy/fixup.py b/mypy/fixup.py index cd10ae9156116..302bd38097b30 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -11,7 +11,7 @@ CallableType, Instance, Overloaded, TupleType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, ParamSpecType, - UnpackType, + Parameters, UnpackType, ) from mypy.visitor import NodeVisitor from mypy.lookup import lookup_fully_qualified @@ -255,6 +255,11 @@ def visit_param_spec(self, p: ParamSpecType) -> None: def visit_unpack_type(self, u: UnpackType) -> None: u.type.accept(self) + def visit_parameters(self, p: Parameters) -> None: + for argt in p.arg_types: + if argt is not None: + argt.accept(self) + def visit_unbound_type(self, o: UnboundType) -> None: for a in o.args: a.accept(self) diff --git a/mypy/indirection.py b/mypy/indirection.py index 9c9959c5f6566..0888c2afad209 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -70,6 +70,9 @@ def visit_param_spec(self, t: types.ParamSpecType) -> Set[str]: def visit_unpack_type(self, t: types.UnpackType) -> Set[str]: return t.type.accept(self) + def visit_parameters(self, t: types.Parameters) -> Set[str]: + return self._visit(t.arg_types) + def visit_instance(self, t: types.Instance) -> Set[str]: out = self._visit(t.args) if t.type: diff --git a/mypy/join.py b/mypy/join.py index e11cccb5fa447..a184efcc4bb43 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -7,7 +7,7 @@ Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType, TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type, - ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType, + ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType, Parameters, UnpackType ) from mypy.maptype import map_instance_to_supertype @@ -260,6 +260,12 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType: def visit_unpack_type(self, t: UnpackType) -> UnpackType: raise NotImplementedError + def visit_parameters(self, t: Parameters) -> ProperType: + if self.s == t: + return t + else: + return self.default(self.s) + def visit_instance(self, t: Instance) -> ProperType: if isinstance(self.s, Instance): if self.instance_joiner is None: diff --git a/mypy/main.py b/mypy/main.py index 3d98365872509..c4548ea9b6250 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -679,6 +679,10 @@ def add_invertible_flag(flag: str, " non-overlapping types", group=strictness_group) + add_invertible_flag('--strict-concatenate', default=False, strict_flag=True, + help="Make arguments prepended via Concatenate be truly positional-only", + group=strictness_group) + strict_help = "Strict mode; enables the following flags: {}".format( ", ".join(strict_flag_names)) strictness_group.add_argument( diff --git a/mypy/meet.py b/mypy/meet.py index e6b62ff13ad87..8a996146a344e 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -6,7 +6,7 @@ TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType, - ParamSpecType, UnpackType, + ParamSpecType, Parameters, UnpackType, ) from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype from mypy.erasetype import erase_type @@ -509,6 +509,17 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType: def visit_unpack_type(self, t: UnpackType) -> ProperType: raise NotImplementedError + def visit_parameters(self, t: Parameters) -> ProperType: + # TODO: is this the right variance? + if isinstance(self.s, Parameters) or isinstance(self.s, CallableType): + if len(t.arg_types) != len(self.s.arg_types): + return self.default(self.s) + return t.copy_modified( + arg_types=[meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)] + ) + else: + return self.default(self.s) + def visit_instance(self, t: Instance) -> ProperType: if isinstance(self.s, Instance): if t.type == self.s.type: diff --git a/mypy/messages.py b/mypy/messages.py index f067e7e06bd4b..0e9a59ea40160 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -15,7 +15,9 @@ import difflib from textwrap import dedent -from typing import cast, List, Dict, Any, Sequence, Iterable, Iterator, Tuple, Set, Optional, Union +from typing import ( + cast, List, Dict, Any, Sequence, Iterable, Iterator, Tuple, Set, Optional, Union, Callable +) from typing_extensions import Final from mypy.erasetype import erase_type @@ -24,7 +26,7 @@ Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType, UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, UninhabitedType, TypeOfAny, UnboundType, PartialType, get_proper_type, ProperType, - ParamSpecType, get_proper_types + ParamSpecType, Parameters, get_proper_types ) from mypy.typetraverser import TypeTraverserVisitor from mypy.nodes import ( @@ -624,6 +626,32 @@ def incompatible_argument_note(self, if call: self.note_call(original_caller_type, call, context, code=code) + self.maybe_note_concatenate_pos_args(original_caller_type, callee_type, context, code) + + def maybe_note_concatenate_pos_args(self, + original_caller_type: ProperType, + callee_type: ProperType, + context: Context, + code: Optional[ErrorCode] = None) -> None: + # pos-only vs positional can be confusing, with Concatenate + if (isinstance(callee_type, CallableType) and + isinstance(original_caller_type, CallableType) and + (original_caller_type.from_concatenate or callee_type.from_concatenate)): + names: List[str] = [] + for c, o in zip( + callee_type.formal_arguments(), + original_caller_type.formal_arguments()): + if None in (c.pos, o.pos): + # non-positional + continue + if c.name != o.name and c.name is None and o.name is not None: + names.append(o.name) + + if names: + missing_arguments = '"' + '", "'.join(names) + '"' + self.note(f'This may be because "{original_caller_type.name}" has arguments ' + f'named: {missing_arguments}', context, code=code) + def invalid_index_type(self, index_type: Type, expected_type: Type, base_str: str, context: Context, *, code: ErrorCode) -> None: index_str, expected_str = format_type_distinctly(index_type, expected_type) @@ -1652,6 +1680,32 @@ def quote_type_string(type_string: str) -> str: return '"{}"'.format(type_string) +def format_callable_args(arg_types: List[Type], arg_kinds: List[ArgKind], + arg_names: List[Optional[str]], format: Callable[[Type], str], + verbosity: int) -> str: + """Format a bunch of Callable arguments into a string""" + arg_strings = [] + for arg_name, arg_type, arg_kind in zip( + arg_names, arg_types, arg_kinds): + if (arg_kind == ARG_POS and arg_name is None + or verbosity == 0 and arg_kind.is_positional()): + + arg_strings.append(format(arg_type)) + else: + constructor = ARG_CONSTRUCTOR_NAMES[arg_kind] + if arg_kind.is_star() or arg_name is None: + arg_strings.append("{}({})".format( + constructor, + format(arg_type))) + else: + arg_strings.append("{}({}, {})".format( + constructor, + format(arg_type), + repr(arg_name))) + + return ", ".join(arg_strings) + + def format_type_inner(typ: Type, verbosity: int, fullnames: Optional[Set[str]]) -> str: @@ -1705,7 +1759,18 @@ def format_literal_value(typ: LiteralType) -> str: # This is similar to non-generic instance types. return typ.name elif isinstance(typ, ParamSpecType): - return typ.name_with_suffix() + # Concatenate[..., P] + if typ.prefix.arg_types: + args = format_callable_args( + typ.prefix.arg_types, + typ.prefix.arg_kinds, + typ.prefix.arg_names, + format, + verbosity) + + return f'[{args}, **{typ.name_with_suffix()}]' + else: + return typ.name_with_suffix() elif isinstance(typ, TupleType): # Prefer the name of the fallback class (if not tuple), as it's more informative. if typ.partial_fallback.type.fullname != 'builtins.tuple': @@ -1782,27 +1847,14 @@ def format_literal_value(typ: LiteralType) -> str: return 'Callable[..., {}]'.format(return_type) param_spec = func.param_spec() if param_spec is not None: - return f'Callable[{param_spec.name}, {return_type}]' - arg_strings = [] - for arg_name, arg_type, arg_kind in zip( - func.arg_names, func.arg_types, func.arg_kinds): - if (arg_kind == ARG_POS and arg_name is None - or verbosity == 0 and arg_kind.is_positional()): - - arg_strings.append(format(arg_type)) - else: - constructor = ARG_CONSTRUCTOR_NAMES[arg_kind] - if arg_kind.is_star() or arg_name is None: - arg_strings.append("{}({})".format( - constructor, - format(arg_type))) - else: - arg_strings.append("{}({}, {})".format( - constructor, - format(arg_type), - repr(arg_name))) - - return 'Callable[[{}], {}]'.format(", ".join(arg_strings), return_type) + return f'Callable[{format(param_spec)}, {return_type}]' + args = format_callable_args( + func.arg_types, + func.arg_kinds, + func.arg_names, + format, + verbosity) + return 'Callable[[{}], {}]'.format(args, return_type) else: # Use a simple representation for function types; proper # function types may result in long and difficult-to-read @@ -1810,6 +1862,14 @@ def format_literal_value(typ: LiteralType) -> str: return 'overloaded function' elif isinstance(typ, UnboundType): return str(typ) + elif isinstance(typ, Parameters): + args = format_callable_args( + typ.arg_types, + typ.arg_kinds, + typ.arg_names, + format, + verbosity) + return f'[{args}]' elif typ is None: raise RuntimeError('Type is None') else: diff --git a/mypy/nodes.py b/mypy/nodes.py index 98402ab9b71b0..7fcf5d85673cf 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1037,7 +1037,9 @@ def deserialize(self, data: JsonDict) -> 'ClassDef': assert data['.class'] == 'ClassDef' res = ClassDef(data['name'], Block([]), - [mypy.types.TypeVarType.deserialize(v) for v in data['type_vars']], + # https://github.com/python/mypy/issues/12257 + [cast(mypy.types.TypeVarLikeType, mypy.types.deserialize_type(v)) + for v in data['type_vars']], ) res.fullname = data['fullname'] return res @@ -2507,8 +2509,8 @@ class is generic then it will be a type constructor of higher kind. 'declared_metaclass', 'metaclass_type', 'names', 'is_abstract', 'is_protocol', 'runtime_protocol', 'abstract_attributes', 'deletable_attributes', 'slots', 'assuming', 'assuming_proper', - 'inferring', 'is_enum', 'fallback_to_any', 'type_vars', 'bases', - '_promote', 'tuple_type', 'is_named_tuple', 'typeddict_type', + 'inferring', 'is_enum', 'fallback_to_any', 'type_vars', 'has_param_spec_type', + 'bases', '_promote', 'tuple_type', 'is_named_tuple', 'typeddict_type', 'is_newtype', 'is_intersection', 'metadata', ) @@ -2591,6 +2593,9 @@ class is generic then it will be a type constructor of higher kind. # Generic type variable names (full names) type_vars: List[str] + # Whether this class has a ParamSpec type variable + has_param_spec_type: bool + # Direct base classes. bases: List["mypy.types.Instance"] @@ -2638,6 +2643,7 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> No self.defn = defn self.module_name = module_name self.type_vars = [] + self.has_param_spec_type = False self.bases = [] self.mro = [] self._mro_refs = None @@ -2668,7 +2674,9 @@ def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> No def add_type_vars(self) -> None: if self.defn.type_vars: for vd in self.defn.type_vars: - self.type_vars.append(vd.fullname) + if isinstance(vd, mypy.types.ParamSpecType): + self.has_param_spec_type = True + self.type_vars.append(vd.name) @property def name(self) -> str: @@ -2832,6 +2840,7 @@ def serialize(self) -> JsonDict: 'defn': self.defn.serialize(), 'abstract_attributes': self.abstract_attributes, 'type_vars': self.type_vars, + 'has_param_spec_type': self.has_param_spec_type, 'bases': [b.serialize() for b in self.bases], 'mro': [c.fullname for c in self.mro], '_promote': None if self._promote is None else self._promote.serialize(), @@ -2857,6 +2866,7 @@ def deserialize(cls, data: JsonDict) -> 'TypeInfo': # TODO: Is there a reason to reconstruct ti.subtypes? ti.abstract_attributes = data['abstract_attributes'] ti.type_vars = data['type_vars'] + ti.has_param_spec_type = data['has_param_spec_type'] ti.bases = [mypy.types.Instance.deserialize(b) for b in data['bases']] ti._promote = (None if data['_promote'] is None else mypy.types.deserialize_type(data['_promote'])) diff --git a/mypy/options.py b/mypy/options.py index b0dead4146b58..8e56d67bbeb8b 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -46,6 +46,7 @@ class BuildType: "mypyc", "no_implicit_optional", "show_none_errors", + "strict_concatenate", "strict_equality", "strict_optional", "strict_optional_whitelist", @@ -183,6 +184,9 @@ def __init__(self) -> None: # This makes 1 == '1', 1 in ['1'], and 1 is '1' errors. self.strict_equality = False + # Make arguments prepended via Concatenate be truly positional-only. + self.strict_concatenate = False + # Report an error for any branches inferred to be unreachable as a result of # type analysis. self.warn_unreachable = False diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 70cd216031797..46798018410d2 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -4,7 +4,7 @@ Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, - ProperType, get_proper_type, TypeAliasType, ParamSpecType, UnpackType + ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType ) from mypy.typeops import tuple_fallback, make_simplified_union @@ -106,6 +106,12 @@ def visit_unpack_type(self, left: UnpackType) -> bool: return (isinstance(self.right, UnpackType) and is_same_type(left.type, self.right.type)) + def visit_parameters(self, left: Parameters) -> bool: + return (isinstance(self.right, Parameters) and + left.arg_names == self.right.arg_names and + is_same_types(left.arg_types, self.right.arg_types) and + left.arg_kinds == self.right.arg_kinds) + def visit_callable_type(self, left: CallableType) -> bool: # FIX generics if isinstance(self.right, CallableType): diff --git a/mypy/semanal.py b/mypy/semanal.py index 8872e6f4aa8ad..44ece06747321 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -97,7 +97,7 @@ NEVER_NAMES, FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType, CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue, TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType, - get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, + get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType, PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES, is_named_instance, ) @@ -1184,7 +1184,9 @@ def analyze_class(self, defn: ClassDef) -> None: self.prepare_class_def(defn) defn.type_vars = tvar_defs - defn.info.type_vars = [tvar.name for tvar in tvar_defs] + defn.info.type_vars = [] + # we want to make sure any additional logic in add_type_vars gets run + defn.info.add_type_vars() if base_error: defn.info.fallback_to_any = True @@ -4138,6 +4140,27 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]] items = items[:-1] else: items = [index] + + # whether param spec literals be allowed here + # TODO: should this be computed once and passed in? + # or is there a better way to do this? + base = expr.base + if isinstance(base, RefExpr) and isinstance(base.node, TypeAlias): + alias = base.node + target = get_proper_type(alias.target) + if isinstance(target, Instance): + has_param_spec = target.type.has_param_spec_type + num_args = len(target.type.type_vars) + else: + has_param_spec = False + num_args = -1 + elif isinstance(base, NameExpr) and isinstance(base.node, TypeInfo): + has_param_spec = base.node.has_param_spec_type + num_args = len(base.node.type_vars) + else: + has_param_spec = False + num_args = -1 + for item in items: try: typearg = self.expr_to_unanalyzed_type(item) @@ -4148,10 +4171,19 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]] # may be analysing a type alias definition rvalue. The error will be # reported elsewhere if it is not the case. analyzed = self.anal_type(typearg, allow_unbound_tvars=True, - allow_placeholder=True) + allow_placeholder=True, + allow_param_spec_literals=has_param_spec) if analyzed is None: return None types.append(analyzed) + + if has_param_spec and num_args == 1 and len(types) > 0: + first_arg = get_proper_type(types[0]) + if not (len(types) == 1 and (isinstance(first_arg, Parameters) or + isinstance(first_arg, ParamSpecType) or + isinstance(first_arg, AnyType))): + types = [Parameters(types, [ARG_POS] * len(types), [None] * len(types))] + return types def visit_slice_expr(self, expr: SliceExpr) -> None: @@ -5288,6 +5320,7 @@ def type_analyzer(self, *, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, allow_required: bool = False, + allow_param_spec_literals: bool = False, report_invalid_types: bool = True) -> TypeAnalyser: if tvar_scope is None: tvar_scope = self.tvar_scope @@ -5300,7 +5333,8 @@ def type_analyzer(self, *, allow_tuple_literal=allow_tuple_literal, report_invalid_types=report_invalid_types, allow_placeholder=allow_placeholder, - allow_required=allow_required) + allow_required=allow_required, + allow_param_spec_literals=allow_param_spec_literals) tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic()) tpan.global_scope = not self.type and not self.function_stack return tpan @@ -5315,6 +5349,7 @@ def anal_type(self, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, allow_required: bool = False, + allow_param_spec_literals: bool = False, report_invalid_types: bool = True, third_pass: bool = False) -> Optional[Type]: """Semantically analyze a type. @@ -5342,6 +5377,7 @@ def anal_type(self, allow_tuple_literal=allow_tuple_literal, allow_placeholder=allow_placeholder, allow_required=allow_required, + allow_param_spec_literals=allow_param_spec_literals, report_invalid_types=report_invalid_types) tag = self.track_incomplete_refs() typ = typ.accept(a) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 28d9423420b8f..437cb777c8d53 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -60,7 +60,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType, ParamSpecType, - UnpackType, + Parameters, UnpackType, ) from mypy.util import get_prefix @@ -321,6 +321,12 @@ def visit_param_spec(self, typ: ParamSpecType) -> SnapshotItem: def visit_unpack_type(self, typ: UnpackType) -> SnapshotItem: return ('UnpackType', snapshot_type(typ.type)) + def visit_parameters(self, typ: Parameters) -> SnapshotItem: + return ('Parameters', + snapshot_types(typ.arg_types), + tuple(encode_optional_str(name) for name in typ.arg_names), + tuple(typ.arg_kinds)) + def visit_callable_type(self, typ: CallableType) -> SnapshotItem: # FIX generics return ('CallableType', diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 24e771fc868a1..deaf7a6e21b71 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -59,7 +59,7 @@ Type, SyntheticTypeVisitor, Instance, AnyType, NoneType, CallableType, ErasedType, DeletedType, TupleType, TypeType, TypedDictType, UnboundType, UninhabitedType, UnionType, Overloaded, TypeVarType, TypeList, CallableArgument, EllipsisType, StarType, LiteralType, - RawExpressionType, PartialType, PlaceholderType, TypeAliasType, ParamSpecType, + RawExpressionType, PartialType, PlaceholderType, TypeAliasType, ParamSpecType, Parameters, UnpackType ) from mypy.util import get_prefix, replace_object_state @@ -415,6 +415,10 @@ def visit_param_spec(self, typ: ParamSpecType) -> None: def visit_unpack_type(self, typ: UnpackType) -> None: typ.type.accept(self) + def visit_parameters(self, typ: Parameters) -> None: + for arg in typ.arg_types: + arg.accept(self) + def visit_typeddict_type(self, typ: TypedDictType) -> None: for value_type in typ.items.values(): value_type.accept(self) diff --git a/mypy/server/deps.py b/mypy/server/deps.py index ebd808de6a6df..646a024340482 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -99,7 +99,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType, - TypeAliasType, ParamSpecType, UnpackType + TypeAliasType, ParamSpecType, Parameters, UnpackType ) from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.util import correct_relative_import @@ -964,6 +964,12 @@ def visit_param_spec(self, typ: ParamSpecType) -> List[str]: def visit_unpack_type(self, typ: UnpackType) -> List[str]: return typ.type.accept(self) + def visit_parameters(self, typ: Parameters) -> List[str]: + triggers = [] + for arg in typ.arg_types: + triggers.extend(self.get_type_triggers(arg)) + return triggers + def visit_typeddict_type(self, typ: TypedDictType) -> List[str]: triggers = [] for item in typ.items.values(): diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 17caea9db6f5a..314c260f293d3 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,7 +8,7 @@ Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType, ParamSpecType, - UnpackType, TUPLE_LIKE_INSTANCE_NAMES, + Parameters, UnpackType, TUPLE_LIKE_INSTANCE_NAMES, ) import mypy.applytype import mypy.constraints @@ -24,6 +24,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance from mypy.typestate import TypeState, SubtypeKind +from mypy.options import Options from mypy import state # Flags for detected protocol members @@ -52,7 +53,8 @@ def is_subtype(left: Type, right: Type, ignore_type_params: bool = False, ignore_pos_arg_names: bool = False, ignore_declared_variance: bool = False, - ignore_promotions: bool = False) -> bool: + ignore_promotions: bool = False, + options: Optional[Options] = None) -> bool: """Is 'left' subtype of 'right'? Also consider Any to be a subtype of any type, and vice versa. This @@ -90,12 +92,14 @@ def is_subtype(left: Type, right: Type, ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions) + ignore_promotions=ignore_promotions, + options=options) return _is_subtype(left, right, ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions) + ignore_promotions=ignore_promotions, + options=options) def _is_subtype(left: Type, right: Type, @@ -103,7 +107,8 @@ def _is_subtype(left: Type, right: Type, ignore_type_params: bool = False, ignore_pos_arg_names: bool = False, ignore_declared_variance: bool = False, - ignore_promotions: bool = False) -> bool: + ignore_promotions: bool = False, + options: Optional[Options] = None) -> bool: orig_right = right orig_left = left left = get_proper_type(left) @@ -120,7 +125,8 @@ def _is_subtype(left: Type, right: Type, ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions) + ignore_promotions=ignore_promotions, + options=options) for item in right.items) # Recombine rhs literal types, to make an enum type a subtype # of a union of all enum items as literal types. Only do it if @@ -135,7 +141,8 @@ def _is_subtype(left: Type, right: Type, ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions) + ignore_promotions=ignore_promotions, + options=options) for item in right.items) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be @@ -152,19 +159,21 @@ def _is_subtype(left: Type, right: Type, ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions)) + ignore_promotions=ignore_promotions, + options=options)) def is_equivalent(a: Type, b: Type, *, ignore_type_params: bool = False, - ignore_pos_arg_names: bool = False + ignore_pos_arg_names: bool = False, + options: Optional[Options] = None ) -> bool: return ( is_subtype(a, b, ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names) + ignore_pos_arg_names=ignore_pos_arg_names, options=options) and is_subtype(b, a, ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names)) + ignore_pos_arg_names=ignore_pos_arg_names, options=options)) class SubtypeVisitor(TypeVisitor[bool]): @@ -174,7 +183,8 @@ def __init__(self, right: Type, ignore_type_params: bool, ignore_pos_arg_names: bool = False, ignore_declared_variance: bool = False, - ignore_promotions: bool = False) -> None: + ignore_promotions: bool = False, + options: Optional[Options] = None) -> None: self.right = get_proper_type(right) self.orig_right = right self.ignore_type_params = ignore_type_params @@ -183,6 +193,7 @@ def __init__(self, right: Type, self.ignore_promotions = ignore_promotions self.check_type_parameter = (ignore_type_parameter if ignore_type_params else check_type_parameter) + self.options = options self._subtype_kind = SubtypeVisitor.build_subtype_kind( ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, @@ -206,7 +217,8 @@ def _is_subtype(self, left: Type, right: Type) -> bool: ignore_type_params=self.ignore_type_params, ignore_pos_arg_names=self.ignore_pos_arg_names, ignore_declared_variance=self.ignore_declared_variance, - ignore_promotions=self.ignore_promotions) + ignore_promotions=self.ignore_promotions, + options=self.options) # visit_x(left) means: is left (which is an instance of X) a subtype of # right? @@ -278,7 +290,7 @@ def visit_instance(self, left: Instance) -> bool: if not self.check_type_parameter(lefta, righta, tvar.variance): nominal = False else: - if not is_equivalent(lefta, righta): + if not self.check_type_parameter(lefta, righta, COVARIANT): nominal = False if nominal: TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) @@ -330,6 +342,16 @@ def visit_param_spec(self, left: ParamSpecType) -> bool: def visit_unpack_type(self, left: UnpackType) -> bool: raise NotImplementedError + def visit_parameters(self, left: Parameters) -> bool: + right = self.right + if isinstance(right, Parameters) or isinstance(right, CallableType): + return are_parameters_compatible( + left, right, + is_compat=self._is_subtype, + ignore_pos_arg_names=self.ignore_pos_arg_names) + else: + return False + def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): @@ -343,7 +365,8 @@ def visit_callable_type(self, left: CallableType) -> bool: return is_callable_compatible( left, right, is_compat=self._is_subtype, - ignore_pos_arg_names=self.ignore_pos_arg_names) + ignore_pos_arg_names=self.ignore_pos_arg_names, + strict_concatenate=self.options.strict_concatenate if self.options else True) elif isinstance(right, Overloaded): return all(self._is_subtype(left, item) for item in right.items) elif isinstance(right, Instance): @@ -358,6 +381,12 @@ def visit_callable_type(self, left: CallableType) -> bool: elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. return left.is_type_obj() and self._is_subtype(left.ret_type, right.item) + elif isinstance(right, Parameters): + # this doesn't check return types.... but is needed for is_equivalent + return are_parameters_compatible( + left, right, + is_compat=self._is_subtype, + ignore_pos_arg_names=self.ignore_pos_arg_names) else: return False @@ -404,7 +433,8 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: return False for name, l, r in left.zip(right): if not is_equivalent(l, r, - ignore_type_params=self.ignore_type_params): + ignore_type_params=self.ignore_type_params, + options=self.options): return False # Non-required key is not compatible with a required key since # indexing may fail unexpectedly if a required key is missing. @@ -471,12 +501,15 @@ def visit_overloaded(self, left: Overloaded) -> bool: else: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. + strict_concat = self.options.strict_concatenate if self.options else True if (is_callable_compatible(left_item, right_item, is_compat=self._is_subtype, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names) or + ignore_pos_arg_names=self.ignore_pos_arg_names, + strict_concatenate=strict_concat) or is_callable_compatible(right_item, left_item, is_compat=self._is_subtype, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names)): + ignore_pos_arg_names=self.ignore_pos_arg_names, + strict_concatenate=strict_concat)): # If this is an overload that's already been matched, there's no # problem. if left_item not in matched_overloads: @@ -778,7 +811,8 @@ def is_callable_compatible(left: CallableType, right: CallableType, ignore_return: bool = False, ignore_pos_arg_names: bool = False, check_args_covariantly: bool = False, - allow_partial_overlap: bool = False) -> bool: + allow_partial_overlap: bool = False, + strict_concatenate: bool = False) -> bool: """Is the left compatible with the right, using the provided compatibility check? is_compat: @@ -914,6 +948,27 @@ def g(x: int) -> int: ... if check_args_covariantly: is_compat = flip_compat_check(is_compat) + if not strict_concatenate and (left.from_concatenate or right.from_concatenate): + strict_concatenate_check = False + else: + strict_concatenate_check = True + + return are_parameters_compatible(left, right, is_compat=is_compat, + ignore_pos_arg_names=ignore_pos_arg_names, + check_args_covariantly=check_args_covariantly, + allow_partial_overlap=allow_partial_overlap, + strict_concatenate_check=strict_concatenate_check) + + +def are_parameters_compatible(left: Union[Parameters, CallableType], + right: Union[Parameters, CallableType], + *, + is_compat: Callable[[Type, Type], bool], + ignore_pos_arg_names: bool = False, + check_args_covariantly: bool = False, + allow_partial_overlap: bool = False, + strict_concatenate_check: bool = True) -> bool: + """Helper function for is_callable_compatible, used for Parameter compatibility""" if right.is_ellipsis_args: return True @@ -1001,7 +1056,9 @@ def _incompatible(left_arg: Optional[FormalArgument], right_names = {name for name in right.arg_names if name is not None} left_only_names = set() for name, kind in zip(left.arg_names, left.arg_kinds): - if name is None or kind.is_star() or name in right_names: + if (name is None or kind.is_star() + or name in right_names + or not strict_concatenate_check): continue left_only_names.add(name) @@ -1037,7 +1094,8 @@ def _incompatible(left_arg: Optional[FormalArgument], if (right_by_name is not None and right_by_pos is not None and right_by_name != right_by_pos - and (right_by_pos.required or right_by_name.required)): + and (right_by_pos.required or right_by_name.required) + and strict_concatenate_check): return False # All *required* left-hand arguments must have a corresponding @@ -1363,6 +1421,13 @@ def visit_param_spec(self, left: ParamSpecType) -> bool: def visit_unpack_type(self, left: UnpackType) -> bool: raise NotImplementedError + def visit_parameters(self, left: Parameters) -> bool: + right = self.right + if isinstance(right, Parameters) or isinstance(right, CallableType): + return are_parameters_compatible(left, right, is_compat=self._is_proper_subtype) + else: + return False + def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 2adc73c009387..05688a1e5071e 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -20,7 +20,7 @@ from mypy.types import ( Type, AnyType, CallableType, Overloaded, TupleType, TypedDictType, LiteralType, - RawExpressionType, Instance, NoneType, TypeType, + Parameters, RawExpressionType, Instance, NoneType, TypeType, UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeType, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, PlaceholderType, TypeAliasType, ParamSpecType, UnpackType, get_proper_type @@ -67,6 +67,10 @@ def visit_type_var(self, t: TypeVarType) -> T: def visit_param_spec(self, t: ParamSpecType) -> T: pass + @abstractmethod + def visit_parameters(self, t: Parameters) -> T: + pass + @abstractmethod def visit_instance(self, t: Instance) -> T: pass @@ -190,6 +194,9 @@ def visit_type_var(self, t: TypeVarType) -> Type: def visit_param_spec(self, t: ParamSpecType) -> Type: return t + def visit_parameters(self, t: Parameters) -> Type: + return t.copy_modified(arg_types=self.translate_types(t.arg_types)) + def visit_partial_type(self, t: PartialType) -> Type: return t @@ -311,6 +318,9 @@ def visit_param_spec(self, t: ParamSpecType) -> T: def visit_unpack_type(self, t: UnpackType) -> T: return self.query_types([t.type]) + def visit_parameters(self, t: Parameters) -> T: + return self.query_types(t.arg_types) + def visit_partial_type(self, t: PartialType) -> T: return self.strategy([]) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index b40bec6ca3379..276e46df03ee4 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -15,7 +15,7 @@ NEVER_NAMES, Type, UnboundType, TupleType, TypedDictType, UnionType, Instance, AnyType, CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarType, SyntheticTypeVisitor, StarType, PartialType, EllipsisType, UninhabitedType, TypeType, CallableArgument, - TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, + Parameters, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, PlaceholderType, Overloaded, get_proper_type, TypeAliasType, RequiredType, TypeVarLikeType, ParamSpecType, ParamSpecFlavor, UnpackType, callable_with_ellipsis, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, @@ -129,6 +129,7 @@ def __init__(self, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, allow_required: bool = False, + allow_param_spec_literals: bool = False, report_invalid_types: bool = True) -> None: self.api = api self.lookup_qualified = api.lookup_qualified @@ -153,6 +154,8 @@ def __init__(self, self.allow_placeholder = allow_placeholder # Are we in a context where Required[] is allowed? self.allow_required = allow_required + # Are we in a context where ParamSpec literals are allowed? + self.allow_param_spec_literals = allow_param_spec_literals # Should we report an error whenever we encounter a RawExpressionType outside # of a Literal context: e.g. whenever we encounter an invalid type? Normally, # we want to report an error, but the caller may want to do more specialized @@ -264,6 +267,10 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) return self.analyze_type_with_type_info(node, t.args, t) elif node.fullname in TYPE_ALIAS_NAMES: return AnyType(TypeOfAny.special_form) + # Concatenate is an operator, no need for a proper type + elif node.fullname in ('typing_extensions.Concatenate', 'typing.Concatenate'): + # We check the return type further up the stack for valid use locations + return self.apply_concatenate_operator(t) else: return self.analyze_unbound_type_without_type_info(t, sym, defining_literal) else: # sym is None @@ -277,6 +284,33 @@ def cannot_resolve_type(self, t: UnboundType) -> None: 'Cannot resolve name "{}" (possible cyclic definition)'.format(t.name), t) + def apply_concatenate_operator(self, t: UnboundType) -> Type: + if len(t.args) == 0: + self.api.fail('Concatenate needs type arguments', t) + return AnyType(TypeOfAny.from_error) + + # last argument has to be ParamSpec + ps = self.anal_type(t.args[-1], allow_param_spec=True) + if not isinstance(ps, ParamSpecType): + self.api.fail('The last parameter to Concatenate needs to be a ParamSpec', t) + return AnyType(TypeOfAny.from_error) + + # TODO: this may not work well with aliases, if those worked. + # Those should be special-cased. + elif ps.prefix.arg_types: + self.api.fail('Nested Concatenates are invalid', t) + + args = self.anal_array(t.args[:-1]) + pre = ps.prefix + + # mypy can't infer this :( + names: List[Optional[str]] = [None] * len(args) + + pre = Parameters(args + pre.arg_types, + [ARG_POS] * len(args) + pre.arg_kinds, + names + pre.arg_names) + return ps.copy_modified(prefix=pre) + def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Optional[Type]: """Bind special type that is recognized through magic name such as 'typing.Any'. @@ -403,13 +437,32 @@ def analyze_type_with_type_info( if len(args) > 0 and info.fullname == 'builtins.tuple': fallback = Instance(info, [AnyType(TypeOfAny.special_form)], ctx.line) return TupleType(self.anal_array(args), fallback, ctx.line) - # Analyze arguments and (usually) construct Instance type. The - # number of type arguments and their values are - # checked only later, since we do not always know the - # valid count at this point. Thus we may construct an - # Instance with an invalid number of type arguments. - instance = Instance(info, self.anal_array(args, allow_param_spec=True), - ctx.line, ctx.column) + + # This is a heuristic: it will be checked later anyways but the error + # message may be worse. + with self.set_allow_param_spec_literals(info.has_param_spec_type): + # Analyze arguments and (usually) construct Instance type. The + # number of type arguments and their values are + # checked only later, since we do not always know the + # valid count at this point. Thus we may construct an + # Instance with an invalid number of type arguments. + instance = Instance(info, self.anal_array(args, allow_param_spec=True), + ctx.line, ctx.column) + + # "aesthetic" paramspec literals + # these do not support mypy_extensions VarArgs, etc. as they were already analyzed + # TODO: should these be re-analyzed to get rid of this inconsistency? + # another inconsistency is with empty type args (Z[] is more possibly an error imo) + if len(info.type_vars) == 1 and info.has_param_spec_type and len(instance.args) > 0: + first_arg = get_proper_type(instance.args[0]) + + # TODO: can I use tuple syntax to isinstance multiple in 3.6? + if not (len(instance.args) == 1 and (isinstance(first_arg, Parameters) or + isinstance(first_arg, ParamSpecType) or + isinstance(first_arg, AnyType))): + args = instance.args + instance.args = (Parameters(args, [ARG_POS] * len(args), [None] * len(args)),) + # Check type argument count. if len(instance.args) != len(info.type_vars) and not self.defining_alias: fix_instance(instance, self.fail, self.note, @@ -546,9 +599,19 @@ def visit_deleted_type(self, t: DeletedType) -> Type: return t def visit_type_list(self, t: TypeList) -> Type: - self.fail('Bracketed expression "[...]" is not valid as a type', t) - self.note('Did you mean "List[...]"?', t) - return AnyType(TypeOfAny.from_error) + # paramspec literal (Z[[int, str, Whatever]]) + if self.allow_param_spec_literals: + params = self.analyze_callable_args(t) + if params: + ts, kinds, names = params + # bind these types + return Parameters(self.anal_array(ts), kinds, names) + else: + return AnyType(TypeOfAny.from_error) + else: + self.fail('Bracketed expression "[...]" is not valid as a type', t) + self.note('Did you mean "List[...]"?', t) + return AnyType(TypeOfAny.from_error) def visit_callable_argument(self, t: CallableArgument) -> Type: self.fail('Invalid type', t) @@ -570,6 +633,9 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: def visit_unpack_type(self, t: UnpackType) -> Type: raise NotImplementedError + def visit_parameters(self, t: Parameters) -> Type: + raise NotImplementedError("ParamSpec literals cannot have unbound TypeVars") + def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: # Every Callable can bind its own type variables, if they're not in the outer scope with self.tvar_scope_frame(): @@ -728,8 +794,15 @@ def visit_partial_type(self, t: PartialType) -> Type: assert False, "Internal error: Unexpected partial type" def visit_ellipsis_type(self, t: EllipsisType) -> Type: - self.fail('Unexpected "..."', t) - return AnyType(TypeOfAny.from_error) + if self.allow_param_spec_literals: + any_type = AnyType(TypeOfAny.explicit) + return Parameters([any_type, any_type], + [ARG_STAR, ARG_STAR2], + [None, None], + is_ellipsis_args=True) + else: + self.fail('Unexpected "..."', t) + return AnyType(TypeOfAny.from_error) def visit_type_type(self, t: TypeType) -> Type: return TypeType.make_normalized(self.anal_type(t.item), line=t.line) @@ -773,6 +846,48 @@ def analyze_callable_args_for_paramspec( fallback=fallback, ) + def analyze_callable_args_for_concatenate( + self, + callable_args: Type, + ret_type: Type, + fallback: Instance, + ) -> Optional[CallableType]: + """Construct a 'Callable[C, RET]', where C is Concatenate[..., P], returning None if we + cannot. + """ + if not isinstance(callable_args, UnboundType): + return None + sym = self.lookup_qualified(callable_args.name, callable_args) + if sym is None: + return None + if sym.node is None: + return None + if sym.node.fullname not in ('typing_extensions.Concatenate', 'typing.Concatenate'): + return None + + tvar_def = self.anal_type(callable_args, allow_param_spec=True) + if not isinstance(tvar_def, ParamSpecType): + return None + + # TODO: Use tuple[...] or Mapping[..] instead? + obj = self.named_type('builtins.object') + # ick, CallableType should take ParamSpecType + prefix = tvar_def.prefix + # we don't set the prefix here as generic arguments will get updated at some point + # in the future. CallableType.param_spec() accounts for this. + return CallableType( + [*prefix.arg_types, + ParamSpecType(tvar_def.name, tvar_def.fullname, tvar_def.id, ParamSpecFlavor.ARGS, + upper_bound=obj), + ParamSpecType(tvar_def.name, tvar_def.fullname, tvar_def.id, ParamSpecFlavor.KWARGS, + upper_bound=obj)], + [*prefix.arg_kinds, nodes.ARG_STAR, nodes.ARG_STAR2], + [*prefix.arg_names, None, None], + ret_type=ret_type, + fallback=fallback, + from_concatenate=True, + ) + def analyze_callable_type(self, t: UnboundType) -> Type: fallback = self.named_type('builtins.function') if len(t.args) == 0: @@ -804,6 +919,10 @@ def analyze_callable_type(self, t: UnboundType) -> Type: callable_args, ret_type, fallback + ) or self.analyze_callable_args_for_concatenate( + callable_args, + ret_type, + fallback ) if maybe_ret is None: # Callable[?, RET] (where ? is something invalid) @@ -1039,12 +1158,15 @@ def anal_type(self, t: Type, nested: bool = True, *, allow_param_spec: bool = Fa if (not allow_param_spec and isinstance(analyzed, ParamSpecType) and analyzed.flavor == ParamSpecFlavor.BARE): - self.fail('Invalid location for ParamSpec "{}"'.format(analyzed.name), t) - self.note( - 'You can use ParamSpec as the first argument to Callable, e.g., ' - "'Callable[{}, int]'".format(analyzed.name), - t - ) + if analyzed.prefix.arg_types: + self.fail('Invalid location for Concatenate', t) + else: + self.fail('Invalid location for ParamSpec "{}"'.format(analyzed.name), t) + self.note( + 'You can use ParamSpec as the first argument to Callable, e.g., ' + "'Callable[{}, int]'".format(analyzed.name), + t + ) return analyzed def anal_var_def(self, var_def: TypeVarLikeType) -> TypeVarLikeType: @@ -1089,6 +1211,15 @@ def tuple_type(self, items: List[Type]) -> TupleType: any_type = AnyType(TypeOfAny.special_form) return TupleType(items, fallback=self.named_type('builtins.tuple', [any_type])) + @contextmanager + def set_allow_param_spec_literals(self, to: bool) -> Iterator[None]: + old = self.allow_param_spec_literals + try: + self.allow_param_spec_literals = to + yield + finally: + self.allow_param_spec_literals = old + TypeVarLikeList = List[Tuple[str, TypeVarLikeExpr]] @@ -1280,7 +1411,7 @@ def __init__(self, def _seems_like_callable(self, type: UnboundType) -> bool: if not type.args: return False - if isinstance(type.args[0], (EllipsisType, TypeList)): + if isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType)): return True return False diff --git a/mypy/typeops.py b/mypy/typeops.py index fa05231109414..d97e9f7baf359 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,7 +5,7 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any +from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union from typing_extensions import Type as TypingType import itertools import sys @@ -14,7 +14,7 @@ TupleType, Instance, FunctionLike, Type, CallableType, TypeVarLikeType, Overloaded, TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, - copy_type, TypeAliasType, TypeQuery, ParamSpecType, + copy_type, TypeAliasType, TypeQuery, ParamSpecType, Parameters, ENUM_REMOVED_PROPS ) from mypy.nodes import ( @@ -272,7 +272,7 @@ def erase_to_bound(t: Type) -> Type: return t -def callable_corresponding_argument(typ: CallableType, +def callable_corresponding_argument(typ: Union[CallableType, Parameters], model: FormalArgument) -> Optional[FormalArgument]: """Return the argument a function that corresponds to `model`""" diff --git a/mypy/types.py b/mypy/types.py index 465af1c50e337..06cf3f9e9dff8 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -559,26 +559,44 @@ class ParamSpecType(TypeVarLikeType): always just 'object'). """ - __slots__ = ('flavor',) + __slots__ = ('flavor', 'prefix') flavor: int + prefix: 'Parameters' def __init__( self, name: str, fullname: str, id: Union[TypeVarId, int], flavor: int, - upper_bound: Type, *, line: int = -1, column: int = -1 + upper_bound: Type, *, line: int = -1, column: int = -1, + prefix: Optional['Parameters'] = None ) -> None: super().__init__(name, fullname, id, upper_bound, line=line, column=column) self.flavor = flavor + self.prefix = prefix or Parameters([], [], []) @staticmethod def new_unification_variable(old: 'ParamSpecType') -> 'ParamSpecType': new_id = TypeVarId.new(meta_level=1) return ParamSpecType(old.name, old.fullname, new_id, old.flavor, old.upper_bound, - line=old.line, column=old.column) + line=old.line, column=old.column, prefix=old.prefix) def with_flavor(self, flavor: int) -> 'ParamSpecType': return ParamSpecType(self.name, self.fullname, self.id, flavor, - upper_bound=self.upper_bound) + upper_bound=self.upper_bound, prefix=self.prefix) + + def copy_modified(self, *, + id: Bogus[Union[TypeVarId, int]] = _dummy, + flavor: Bogus[int] = _dummy, + prefix: Bogus['Parameters'] = _dummy) -> 'ParamSpecType': + return ParamSpecType( + self.name, + self.fullname, + id if id is not _dummy else self.id, + flavor if flavor is not _dummy else self.flavor, + self.upper_bound, + line=self.line, + column=self.column, + prefix=prefix if prefix is not _dummy else self.prefix, + ) def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_param_spec(self) @@ -609,6 +627,7 @@ def serialize(self) -> JsonDict: 'id': self.id.raw_id, 'flavor': self.flavor, 'upper_bound': self.upper_bound.serialize(), + 'prefix': self.prefix.serialize() } @classmethod @@ -620,6 +639,7 @@ def deserialize(cls, data: JsonDict) -> 'ParamSpecType': data['id'], data['flavor'], deserialize_type(data['upper_bound']), + prefix=Parameters.deserialize(data['prefix']) ) @@ -1183,6 +1203,183 @@ def get_name(self) -> Optional[str]: pass ('required', bool)]) +# TODO: should this take bound typevars too? what would this take? +# ex: class Z(Generic[P, T]): ...; Z[[V], V] +# What does a typevar even mean in this context? +class Parameters(ProperType): + """Type that represents the parameters to a function. + + Used for ParamSpec analysis.""" + __slots__ = ('arg_types', + 'arg_kinds', + 'arg_names', + 'min_args', + 'is_ellipsis_args', + 'variables') + + def __init__(self, + arg_types: Sequence[Type], + arg_kinds: List[ArgKind], + arg_names: Sequence[Optional[str]], + *, + variables: Optional[Sequence[TypeVarLikeType]] = None, + is_ellipsis_args: bool = False, + line: int = -1, + column: int = -1 + ) -> None: + super().__init__(line, column) + self.arg_types = list(arg_types) + self.arg_kinds = arg_kinds + self.arg_names = list(arg_names) + assert len(arg_types) == len(arg_kinds) == len(arg_names) + self.min_args = arg_kinds.count(ARG_POS) + self.is_ellipsis_args = is_ellipsis_args + self.variables = variables or [] + + def copy_modified(self, + arg_types: Bogus[Sequence[Type]] = _dummy, + arg_kinds: Bogus[List[ArgKind]] = _dummy, + arg_names: Bogus[Sequence[Optional[str]]] = _dummy, + *, + variables: Bogus[Sequence[TypeVarLikeType]] = _dummy, + is_ellipsis_args: Bogus[bool] = _dummy + ) -> 'Parameters': + return Parameters( + arg_types=arg_types if arg_types is not _dummy else self.arg_types, + arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds, + arg_names=arg_names if arg_names is not _dummy else self.arg_names, + is_ellipsis_args=(is_ellipsis_args if is_ellipsis_args is not _dummy + else self.is_ellipsis_args), + variables=variables if variables is not _dummy else self.variables + ) + + # the following are copied from CallableType. Is there a way to decrease code duplication? + def var_arg(self) -> Optional[FormalArgument]: + """The formal argument for *args.""" + for position, (type, kind) in enumerate(zip(self.arg_types, self.arg_kinds)): + if kind == ARG_STAR: + return FormalArgument(None, position, type, False) + return None + + def kw_arg(self) -> Optional[FormalArgument]: + """The formal argument for **kwargs.""" + for position, (type, kind) in enumerate(zip(self.arg_types, self.arg_kinds)): + if kind == ARG_STAR2: + return FormalArgument(None, position, type, False) + return None + + def formal_arguments(self, include_star_args: bool = False) -> List[FormalArgument]: + """Yields the formal arguments corresponding to this callable, ignoring *arg and **kwargs. + + To handle *args and **kwargs, use the 'callable.var_args' and 'callable.kw_args' fields, + if they are not None. + + If you really want to include star args in the yielded output, set the + 'include_star_args' parameter to 'True'.""" + args = [] + done_with_positional = False + for i in range(len(self.arg_types)): + kind = self.arg_kinds[i] + if kind.is_named() or kind.is_star(): + done_with_positional = True + if not include_star_args and kind.is_star(): + continue + + required = kind.is_required() + pos = None if done_with_positional else i + arg = FormalArgument( + self.arg_names[i], + pos, + self.arg_types[i], + required + ) + args.append(arg) + return args + + def argument_by_name(self, name: Optional[str]) -> Optional[FormalArgument]: + if name is None: + return None + seen_star = False + for i, (arg_name, kind, typ) in enumerate( + zip(self.arg_names, self.arg_kinds, self.arg_types)): + # No more positional arguments after these. + if kind.is_named() or kind.is_star(): + seen_star = True + if kind.is_star(): + continue + if arg_name == name: + position = None if seen_star else i + return FormalArgument(name, position, typ, kind.is_required()) + return self.try_synthesizing_arg_from_kwarg(name) + + def argument_by_position(self, position: Optional[int]) -> Optional[FormalArgument]: + if position is None: + return None + if position >= len(self.arg_names): + return self.try_synthesizing_arg_from_vararg(position) + name, kind, typ = ( + self.arg_names[position], + self.arg_kinds[position], + self.arg_types[position], + ) + if kind.is_positional(): + return FormalArgument(name, position, typ, kind == ARG_POS) + else: + return self.try_synthesizing_arg_from_vararg(position) + + def try_synthesizing_arg_from_kwarg(self, + name: Optional[str]) -> Optional[FormalArgument]: + kw_arg = self.kw_arg() + if kw_arg is not None: + return FormalArgument(name, None, kw_arg.typ, False) + else: + return None + + def try_synthesizing_arg_from_vararg(self, + position: Optional[int]) -> Optional[FormalArgument]: + var_arg = self.var_arg() + if var_arg is not None: + return FormalArgument(None, position, var_arg.typ, False) + else: + return None + + def accept(self, visitor: 'TypeVisitor[T]') -> T: + return visitor.visit_parameters(self) + + def serialize(self) -> JsonDict: + return {'.class': 'Parameters', + 'arg_types': [t.serialize() for t in self.arg_types], + 'arg_kinds': [int(x.value) for x in self.arg_kinds], + 'arg_names': self.arg_names, + 'variables': [tv.serialize() for tv in self.variables], + } + + @classmethod + def deserialize(cls, data: JsonDict) -> 'Parameters': + assert data['.class'] == 'Parameters' + return Parameters( + [deserialize_type(t) for t in data['arg_types']], + [ArgKind(x) for x in data['arg_kinds']], + data['arg_names'], + variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data['variables']], + ) + + def __hash__(self) -> int: + return hash((self.is_ellipsis_args, tuple(self.arg_types), + tuple(self.arg_names), tuple(self.arg_kinds))) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Parameters) or isinstance(other, CallableType): + return ( + self.arg_types == other.arg_types and + self.arg_names == other.arg_names and + self.arg_kinds == other.arg_kinds and + self.is_ellipsis_args == other.is_ellipsis_args + ) + else: + return NotImplemented + + class CallableType(FunctionLike): """Type of a non-overloaded callable object (such as function).""" @@ -1209,9 +1406,12 @@ class CallableType(FunctionLike): 'def_extras', # Information about original definition we want to serialize. # This is used for more detailed error messages. 'type_guard', # T, if -> TypeGuard[T] (ret_type is bool in this case). + 'from_concatenate', # whether this callable is from a concatenate object + # (this is used for error messages) ) def __init__(self, + # maybe this should be refactored to take a Parameters object arg_types: Sequence[Type], arg_kinds: List[ArgKind], arg_names: Sequence[Optional[str]], @@ -1229,6 +1429,7 @@ def __init__(self, bound_args: Sequence[Optional[Type]] = (), def_extras: Optional[Dict[str, Any]] = None, type_guard: Optional[Type] = None, + from_concatenate: bool = False ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1248,6 +1449,7 @@ def __init__(self, self.implicit = implicit self.special_sig = special_sig self.from_type_type = from_type_type + self.from_concatenate = from_concatenate if not bound_args: bound_args = () self.bound_args = bound_args @@ -1290,6 +1492,7 @@ def copy_modified(self, bound_args: Bogus[List[Optional[Type]]] = _dummy, def_extras: Bogus[Dict[str, Any]] = _dummy, type_guard: Bogus[Optional[Type]] = _dummy, + from_concatenate: Bogus[bool] = _dummy, ) -> 'CallableType': return CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -1310,6 +1513,8 @@ def copy_modified(self, bound_args=bound_args if bound_args is not _dummy else self.bound_args, def_extras=def_extras if def_extras is not _dummy else dict(self.def_extras), type_guard=type_guard if type_guard is not _dummy else self.type_guard, + from_concatenate=(from_concatenate if from_concatenate is not _dummy + else self.from_concatenate), ) def var_arg(self) -> Optional[FormalArgument]: @@ -1468,13 +1673,32 @@ def param_spec(self) -> Optional[ParamSpecType]: arg_type = self.arg_types[-2] if not isinstance(arg_type, ParamSpecType): return None + # sometimes paramspectypes are analyzed in from mysterious places, + # e.g. def f(prefix..., *args: P.args, **kwargs: P.kwargs) -> ...: ... + prefix = arg_type.prefix + if not prefix.arg_types: + # TODO: confirm that all arg kinds are positional + prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2]) return ParamSpecType(arg_type.name, arg_type.fullname, arg_type.id, ParamSpecFlavor.BARE, - arg_type.upper_bound) - - def expand_param_spec(self, c: 'CallableType') -> 'CallableType': - return self.copy_modified(arg_types=self.arg_types[:-2] + c.arg_types, - arg_kinds=self.arg_kinds[:-2] + c.arg_kinds, - arg_names=self.arg_names[:-2] + c.arg_names) + arg_type.upper_bound, prefix=prefix) + + def expand_param_spec(self, + c: Union['CallableType', Parameters], + no_prefix: bool = False) -> 'CallableType': + variables = c.variables + + if no_prefix: + return self.copy_modified(arg_types=c.arg_types, + arg_kinds=c.arg_kinds, + arg_names=c.arg_names, + is_ellipsis_args=c.is_ellipsis_args, + variables=[*variables, *self.variables]) + else: + return self.copy_modified(arg_types=self.arg_types[:-2] + c.arg_types, + arg_kinds=self.arg_kinds[:-2] + c.arg_kinds, + arg_names=self.arg_names[:-2] + c.arg_names, + is_ellipsis_args=c.is_ellipsis_args, + variables=[*variables, *self.variables]) def __hash__(self) -> int: return hash((self.ret_type, self.is_type_obj(), @@ -1511,6 +1735,7 @@ def serialize(self) -> JsonDict: for t in self.bound_args], 'def_extras': dict(self.def_extras), 'type_guard': self.type_guard.serialize() if self.type_guard is not None else None, + 'from_concatenate': self.from_concatenate, } @classmethod @@ -1531,6 +1756,7 @@ def deserialize(cls, data: JsonDict) -> 'CallableType': def_extras=data['def_extras'], type_guard=(deserialize_type(data['type_guard']) if data['type_guard'] is not None else None), + from_concatenate=data['from_concatenate'], ) @@ -2362,14 +2588,49 @@ def visit_type_var(self, t: TypeVarType) -> str: return s def visit_param_spec(self, t: ParamSpecType) -> str: + # prefixes are displayed as Concatenate + s = '' + if t.prefix.arg_types: + s += f'[{self.list_str(t.prefix.arg_types)}, **' if t.name is None: # Anonymous type variable type (only numeric id). - s = f'`{t.id}' + s += f'`{t.id}' else: # Named type variable type. - s = f'{t.name_with_suffix()}`{t.id}' + s += f'{t.name_with_suffix()}`{t.id}' + if t.prefix.arg_types: + s += ']' return s + def visit_parameters(self, t: Parameters) -> str: + # This is copied from visit_callable -- is there a way to decrease duplication? + if t.is_ellipsis_args: + return '...' + + s = '' + bare_asterisk = False + for i in range(len(t.arg_types)): + if s != '': + s += ', ' + if t.arg_kinds[i].is_named() and not bare_asterisk: + s += '*, ' + bare_asterisk = True + if t.arg_kinds[i] == ARG_STAR: + s += '*' + if t.arg_kinds[i] == ARG_STAR2: + s += '**' + name = t.arg_names[i] + if name: + s += f'{name}: ' + r = t.arg_types[i].accept(self) + + s += r + + if t.arg_kinds[i].is_optional(): + s += ' =' + + return f'[{s}]' + def visit_callable_type(self, t: CallableType) -> str: param_spec = t.param_spec() if param_spec is not None: diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 7bfae5aed1c27..94eeee79be93c 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -6,7 +6,7 @@ Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType, TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, - PlaceholderType, PartialType, RawExpressionType, TypeAliasType, ParamSpecType, + PlaceholderType, PartialType, RawExpressionType, TypeAliasType, ParamSpecType, Parameters, UnpackType ) @@ -41,6 +41,9 @@ def visit_type_var(self, t: TypeVarType) -> None: def visit_param_spec(self, t: ParamSpecType) -> None: pass + def visit_parameters(self, t: Parameters) -> None: + self.traverse_types(t.arg_types) + def visit_literal_type(self, t: LiteralType) -> None: t.fallback.accept(self) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 7b41db2c94d63..fe2354612fbb9 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -25,8 +25,7 @@ def foo1(x: Callable[P, int]) -> Callable[P, str]: ... def foo2(x: P) -> P: ... # E: Invalid location for ParamSpec "P" \ # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' -# TODO(PEP612): uncomment once we have support for Concatenate -# def foo3(x: Concatenate[int, P]) -> int: ... $ E: Invalid location for Concatenate +def foo3(x: Concatenate[int, P]) -> int: ... # E: Invalid location for Concatenate def foo4(x: List[P]) -> None: ... # E: Invalid location for ParamSpec "P" \ # N: You can use ParamSpec as the first argument to Callable, e.g., 'Callable[P, int]' @@ -100,7 +99,7 @@ class C(Generic[P]): def f(x: int, y: str) -> None: ... -reveal_type(C(f)) # N: Revealed type is "__main__.C[def (x: builtins.int, y: builtins.str)]" +reveal_type(C(f)) # N: Revealed type is "__main__.C[[x: builtins.int, y: builtins.str]]" reveal_type(C(f).m) # N: Revealed type is "def (x: builtins.int, y: builtins.str) -> builtins.int" [builtins fixtures/dict.pyi] @@ -142,7 +141,7 @@ def dec() -> Callable[[Callable[P, R]], W[P, R]]: @dec() def f(a: int, b: str) -> None: ... -reveal_type(f) # N: Revealed type is "__main__.W[def (a: builtins.int, b: builtins.str), None]" +reveal_type(f) # N: Revealed type is "__main__.W[[a: builtins.int, b: builtins.str], None]" reveal_type(f(1, '')) # N: Revealed type is "None" reveal_type(f.x) # N: Revealed type is "builtins.int" @@ -416,3 +415,614 @@ with f() as x: pass [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] + +[case testParamSpecLiterals] +from typing_extensions import ParamSpec, TypeAlias +from typing import Generic, TypeVar + +P = ParamSpec("P") +T = TypeVar("T") + +class Z(Generic[P]): ... + +# literals can be applied +n: Z[[int]] + +# TODO: type aliases too +nt1 = Z[[int]] +nt2: TypeAlias = Z[[int]] + +unt1: nt1 +unt2: nt2 + +# literals actually keep types +reveal_type(n) # N: Revealed type is "__main__.Z[[builtins.int]]" +reveal_type(unt1) # N: Revealed type is "__main__.Z[[builtins.int]]" +reveal_type(unt2) # N: Revealed type is "__main__.Z[[builtins.int]]" + +# passing into a function keeps the type +def fT(a: T) -> T: ... +def fP(a: Z[P]) -> Z[P]: ... + +reveal_type(fT(n)) # N: Revealed type is "__main__.Z[[builtins.int]]" +reveal_type(fP(n)) # N: Revealed type is "__main__.Z[[builtins.int]]" + +# literals can be in function args and return type +def k(a: Z[[int]]) -> Z[[str]]: ... + +# functions work +reveal_type(k(n)) # N: Revealed type is "__main__.Z[[builtins.str]]" + +# literals can be matched in arguments +def kb(a: Z[[bytes]]) -> Z[[str]]: ... + +reveal_type(kb(n)) # N: Revealed type is "__main__.Z[[builtins.str]]" \ + # E: Argument 1 to "kb" has incompatible type "Z[[int]]"; expected "Z[[bytes]]" + + +n2: Z[bytes] + +reveal_type(kb(n2)) # N: Revealed type is "__main__.Z[[builtins.str]]" +[builtins fixtures/tuple.pyi] + +[case testParamSpecConcatenateFromPep] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, TypeVar, Generic + +P = ParamSpec("P") +R = TypeVar("R") + +# CASE 1 +class Request: + ... + +def with_request(f: Callable[Concatenate[Request, P], R]) -> Callable[P, R]: + def inner(*args: P.args, **kwargs: P.kwargs) -> R: + return f(Request(), *args, **kwargs) + return inner + +@with_request +def takes_int_str(request: Request, x: int, y: str) -> int: + # use request + return x + 7 + +reveal_type(takes_int_str) # N: Revealed type is "def (x: builtins.int, y: builtins.str) -> builtins.int" + +takes_int_str(1, "A") # Accepted +takes_int_str("B", 2) # E: Argument 1 to "takes_int_str" has incompatible type "str"; expected "int" \ + # E: Argument 2 to "takes_int_str" has incompatible type "int"; expected "str" + +# CASE 2 +T = TypeVar("T") +P_2 = ParamSpec("P_2") + +class X(Generic[T, P]): + f: Callable[P, int] + x: T + +def f1(x: X[int, P_2]) -> str: ... # Accepted +def f2(x: X[int, Concatenate[int, P_2]]) -> str: ... # Accepted +def f3(x: X[int, [int, bool]]) -> str: ... # Accepted +# ellipsis only show up here, but I can assume it works like Callable[..., R] +def f4(x: X[int, ...]) -> str: ... # Accepted +# TODO: this is not rejected: +# def f5(x: X[int, int]) -> str: ... # Rejected + +# CASE 3 +def bar(x: int, *args: bool) -> int: ... +def add(x: Callable[P, int]) -> Callable[Concatenate[str, P], bool]: ... + +reveal_type(add(bar)) # N: Revealed type is "def (builtins.str, x: builtins.int, *args: builtins.bool) -> builtins.bool" + +def remove(x: Callable[Concatenate[int, P], int]) -> Callable[P, bool]: ... + +reveal_type(remove(bar)) # N: Revealed type is "def (*args: builtins.bool) -> builtins.bool" + +def transform( + x: Callable[Concatenate[int, P], int] +) -> Callable[Concatenate[str, P], bool]: ... + +# In the PEP, "__a" appears. What is that? Autogenerated names? To what spec? +reveal_type(transform(bar)) # N: Revealed type is "def (builtins.str, *args: builtins.bool) -> builtins.bool" + +# CASE 4 +def expects_int_first(x: Callable[Concatenate[int, P], int]) -> None: ... + +@expects_int_first # E: Argument 1 to "expects_int_first" has incompatible type "Callable[[str], int]"; expected "Callable[[int], int]" \ + # N: This may be because "one" has arguments named: "x" +def one(x: str) -> int: ... + +@expects_int_first # E: Argument 1 to "expects_int_first" has incompatible type "Callable[[NamedArg(int, 'x')], int]"; expected "Callable[[int], int]" +def two(*, x: int) -> int: ... + +@expects_int_first # E: Argument 1 to "expects_int_first" has incompatible type "Callable[[KwArg(int)], int]"; expected "Callable[[int], int]" +def three(**kwargs: int) -> int: ... + +@expects_int_first # Accepted +def four(*args: int) -> int: ... +[builtins fixtures/tuple.pyi] +[builtins fixtures/dict.pyi] + +[case testParamSpecTwiceSolving] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") + +def f(one: Callable[Concatenate[int, P], R], two: Callable[Concatenate[str, P], R]) -> Callable[P, R]: ... + +a: Callable[[int, bytes], str] +b: Callable[[str, bytes], str] + +reveal_type(f(a, b)) # N: Revealed type is "def (builtins.bytes) -> builtins.str" +[builtins fixtures/tuple.pyi] + +[case testParamSpecConcatenateInReturn] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, Protocol + +P = ParamSpec("P") + +def f(i: Callable[Concatenate[int, P], str]) -> Callable[Concatenate[int, P], str]: ... + +n: Callable[[int, bytes], str] + +reveal_type(f(n)) # N: Revealed type is "def (builtins.int, builtins.bytes) -> builtins.str" +[builtins fixtures/tuple.pyi] + +[case testParamSpecConcatenateNamedArgs] +# flags: --strict-concatenate +# this is one noticeable deviation from PEP but I believe it is for the better +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") + +def f1(c: Callable[P, R]) -> Callable[Concatenate[int, P], R]: + def result(x: int, /, *args: P.args, **kwargs: P.kwargs) -> R: ... + + return result # Accepted + +def f2(c: Callable[P, R]) -> Callable[Concatenate[int, P], R]: + def result(x: int, *args: P.args, **kwargs: P.kwargs) -> R: ... + + return result # Rejected + +# reason for rejection: +f2(lambda x: 42)(42, x=42) +[builtins fixtures/tuple.pyi] +[out] +main:10: error: invalid syntax +[out version>=3.8] +main:17: error: Incompatible return value type (got "Callable[[Arg(int, 'x'), **P], R]", expected "Callable[[int, **P], R]") +main:17: note: This may be because "result" has arguments named: "x" + +[case testNonStrictParamSpecConcatenateNamedArgs] +# this is one noticeable deviation from PEP but I believe it is for the better +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") + +def f1(c: Callable[P, R]) -> Callable[Concatenate[int, P], R]: + def result(x: int, /, *args: P.args, **kwargs: P.kwargs) -> R: ... + + return result # Accepted + +def f2(c: Callable[P, R]) -> Callable[Concatenate[int, P], R]: + def result(x: int, *args: P.args, **kwargs: P.kwargs) -> R: ... + + return result # Rejected -> Accepted + +# reason for rejection: +f2(lambda x: 42)(42, x=42) +[builtins fixtures/tuple.pyi] +[out] +main:9: error: invalid syntax +[out version>=3.8] + +[case testParamSpecConcatenateWithTypeVar] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") +S = TypeVar("S") + +def f(c: Callable[Concatenate[S, P], R]) -> Callable[Concatenate[S, P], R]: ... + +def a(n: int) -> None: ... + +n = f(a) + +reveal_type(n) # N: Revealed type is "def (builtins.int)" +reveal_type(n(42)) # N: Revealed type is "None" +[builtins fixtures/tuple.pyi] + +[case testCallablesAsParameters] +# credits to https://github.com/microsoft/pyright/issues/2705 +from typing_extensions import ParamSpec, Concatenate +from typing import Generic, Callable, Any + +P = ParamSpec("P") + +class Foo(Generic[P]): + def __init__(self, func: Callable[P, Any]) -> None: ... +def bar(baz: Foo[Concatenate[int, P]]) -> Foo[P]: ... + +def test(a: int, /, b: str) -> str: ... + +abc = Foo(test) +reveal_type(abc) +bar(abc) +[builtins fixtures/tuple.pyi] +[out] +main:11: error: invalid syntax +[out version>=3.8] +main:14: note: Revealed type is "__main__.Foo[[builtins.int, b: builtins.str]]" + +[case testSolveParamSpecWithSelfType] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, Generic + +P = ParamSpec("P") + +class Foo(Generic[P]): + def foo(self: 'Foo[P]', other: Callable[P, None]) -> None: ... + +n: Foo[[int]] +def f(x: int) -> None: ... + +n.foo(f) +[builtins fixtures/tuple.pyi] + +[case testParamSpecLiteralsTypeApplication] +from typing_extensions import ParamSpec +from typing import Generic, Callable + +P = ParamSpec("P") + +class Z(Generic[P]): + def __init__(self, c: Callable[P, None]) -> None: + ... + +# it allows valid functions +reveal_type(Z[[int]](lambda x: None)) # N: Revealed type is "__main__.Z[[builtins.int]]" +reveal_type(Z[[]](lambda: None)) # N: Revealed type is "__main__.Z[[]]" +reveal_type(Z[bytes, str](lambda b, s: None)) # N: Revealed type is "__main__.Z[[builtins.bytes, builtins.str]]" + +# it disallows invalid functions +def f1(n: str) -> None: ... +def f2(b: bytes, i: int) -> None: ... + +Z[[int]](lambda one, two: None) # E: Cannot infer type of lambda \ + # E: Argument 1 to "Z" has incompatible type "Callable[[Any, Any], None]"; expected "Callable[[int], None]" +Z[[int]](f1) # E: Argument 1 to "Z" has incompatible type "Callable[[str], None]"; expected "Callable[[int], None]" + +Z[[]](lambda one: None) # E: Cannot infer type of lambda \ + # E: Argument 1 to "Z" has incompatible type "Callable[[Any], None]"; expected "Callable[[], None]" + +Z[bytes, str](lambda one: None) # E: Cannot infer type of lambda \ + # E: Argument 1 to "Z" has incompatible type "Callable[[Any], None]"; expected "Callable[[bytes, str], None]" +Z[bytes, str](f2) # E: Argument 1 to "Z" has incompatible type "Callable[[bytes, int], None]"; expected "Callable[[bytes, str], None]" + +[builtins fixtures/tuple.pyi] + +[case testParamSpecLiteralEllipsis] +from typing_extensions import ParamSpec +from typing import Generic, Callable + +P = ParamSpec("P") + +class Z(Generic[P]): + def __init__(self: 'Z[P]', c: Callable[P, None]) -> None: + ... + +def f1() -> None: ... +def f2(*args: int) -> None: ... +def f3(a: int, *, b: bytes) -> None: ... + +def f4(b: bytes) -> None: ... + +argh: Callable[..., None] = f4 + +# check it works +Z[...](f1) +Z[...](f2) +Z[...](f3) + +# check subtyping works +n: Z[...] +n = Z(f1) +n = Z(f2) +n = Z(f3) + +[builtins fixtures/tuple.pyi] + +[case testParamSpecApplyConcatenateTwice] +from typing_extensions import ParamSpec, Concatenate +from typing import Generic, Callable, Optional + +P = ParamSpec("P") + +class C(Generic[P]): + # think PhantomData from rust + phantom: Optional[Callable[P, None]] + + def add_str(self) -> C[Concatenate[int, P]]: + return C[Concatenate[int, P]]() + + def add_int(self) -> C[Concatenate[str, P]]: + return C[Concatenate[str, P]]() + +def f(c: C[P]) -> None: + reveal_type(c) # N: Revealed type is "__main__.C[P`-1]" + + n1 = c.add_str() + reveal_type(n1) # N: Revealed type is "__main__.C[[builtins.int, **P`-1]]" + n2 = n1.add_int() + reveal_type(n2) # N: Revealed type is "__main__.C[[builtins.str, builtins.int, **P`-1]]" + + p1 = c.add_int() + reveal_type(p1) # N: Revealed type is "__main__.C[[builtins.str, **P`-1]]" + p2 = p1.add_str() + reveal_type(p2) # N: Revealed type is "__main__.C[[builtins.int, builtins.str, **P`-1]]" +[builtins fixtures/tuple.pyi] + +[case testParamSpecLiteralJoin] +from typing import Generic, Callable, Union +from typing_extensions import ParamSpec + + +_P = ParamSpec("_P") + +class Job(Generic[_P]): + def __init__(self, target: Callable[_P, None]) -> None: + self.target = target + +def func( + action: Union[Job[int], Callable[[int], None]], +) -> None: + job = action if isinstance(action, Job) else Job(action) + reveal_type(job) # N: Revealed type is "__main__.Job[[builtins.int]]" +[builtins fixtures/tuple.pyi] + +[case testApplyParamSpecToParamSpecLiterals] +from typing import TypeVar, Generic, Callable +from typing_extensions import ParamSpec + +_P = ParamSpec("_P") +_R_co = TypeVar("_R_co", covariant=True) + +class Job(Generic[_P, _R_co]): + def __init__(self, target: Callable[_P, _R_co]) -> None: + self.target = target + +def run_job(job: Job[_P, None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run_job" defined here + ... + + +def func(job: Job[[int, str], None]) -> None: + run_job(job, 42, "Hello") + run_job(job, "Hello", 42) # E: Argument 2 to "run_job" has incompatible type "str"; expected "int" \ + # E: Argument 3 to "run_job" has incompatible type "int"; expected "str" + run_job(job, 42, msg="Hello") # E: Unexpected keyword argument "msg" for "run_job" + run_job(job, "Hello") # E: Too few arguments for "run_job" \ + # E: Argument 2 to "run_job" has incompatible type "str"; expected "int" + +def func2(job: Job[..., None]) -> None: + run_job(job, 42, "Hello") + run_job(job, "Hello", 42) + run_job(job, 42, msg="Hello") + run_job(job, x=42, msg="Hello") +[builtins fixtures/tuple.pyi] + +[case testExpandNonBareParamSpecAgainstCallable] +from typing import Callable, TypeVar, Any +from typing_extensions import ParamSpec + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) +_P = ParamSpec("_P") +_R = TypeVar("_R") + +def simple_decorator(callable: CallableT) -> CallableT: + # set some attribute on 'callable' + return callable + + +class A: + @simple_decorator + def func(self, action: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: + ... + +reveal_type(A.func) # N: Revealed type is "def [_P, _R] (self: __main__.A, action: def (*_P.args, **_P.kwargs) -> _R`-2, *_P.args, **_P.kwargs) -> _R`-2" + +# TODO: _R` keeps flip-flopping between 5 (?), 13, 14, 15. Spooky. +# reveal_type(A().func) $ N: Revealed type is "def [_P, _R] (action: def (*_P.args, **_P.kwargs) -> _R`13, *_P.args, **_P.kwargs) -> _R`13" + +def f(x: int) -> int: + ... + +reveal_type(A().func(f, 42)) # N: Revealed type is "builtins.int" + +# TODO: this should reveal `int` +reveal_type(A().func(lambda x: x + x, 42)) # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] + +[case testParamSpecConstraintOnOtherParamSpec] +from typing import Callable, TypeVar, Any, Generic +from typing_extensions import ParamSpec + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) +_P = ParamSpec("_P") +_R_co = TypeVar("_R_co", covariant=True) + +def simple_decorator(callable: CallableT) -> CallableT: + ... + +class Job(Generic[_P, _R_co]): + def __init__(self, target: Callable[_P, _R_co]) -> None: + ... + + +class A: + @simple_decorator + def func(self, action: Job[_P, None]) -> Job[_P, None]: + ... + +reveal_type(A.func) # N: Revealed type is "def [_P] (self: __main__.A, action: __main__.Job[_P`-1, None]) -> __main__.Job[_P`-1, None]" +# TODO: flakey, _P`4 alternates around. +# reveal_type(A().func) $ N: Revealed type is "def [_P] (action: __main__.Job[_P`4, None]) -> __main__.Job[_P`4, None]" +reveal_type(A().func(Job(lambda x: x))) # N: Revealed type is "__main__.Job[[x: Any], None]" + +def f(x: int, y: int) -> None: ... +reveal_type(A().func(Job(f))) # N: Revealed type is "__main__.Job[[x: builtins.int, y: builtins.int], None]" +[builtins fixtures/tuple.pyi] + +[case testConstraintBetweenParamSpecFunctions1] +from typing import Callable, TypeVar, Any, Generic +from typing_extensions import ParamSpec + +_P = ParamSpec("_P") +_R_co = TypeVar("_R_co", covariant=True) + +def simple_decorator(callable: Callable[_P, _R_co]) -> Callable[_P, _R_co]: ... +class Job(Generic[_P]): ... + + +@simple_decorator +def func(__action: Job[_P]) -> Callable[_P, None]: + ... + +reveal_type(func) # N: Revealed type is "def [_P] (__main__.Job[_P`-1]) -> def (*_P.args, **_P.kwargs)" +[builtins fixtures/tuple.pyi] + +[case testConstraintBetweenParamSpecFunctions2] +from typing import Callable, TypeVar, Any, Generic +from typing_extensions import ParamSpec + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) +_P = ParamSpec("_P") + +def simple_decorator(callable: CallableT) -> CallableT: ... +class Job(Generic[_P]): ... + + +@simple_decorator +def func(__action: Job[_P]) -> Callable[_P, None]: + ... + +reveal_type(func) # N: Revealed type is "def [_P] (__main__.Job[_P`-1]) -> def (*_P.args, **_P.kwargs)" +[builtins fixtures/tuple.pyi] + +[case testConstraintsBetweenConcatenatePrefixes] +from typing import Any, Callable, Generic, TypeVar +from typing_extensions import Concatenate, ParamSpec + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +class Awaitable(Generic[_T]): ... + +def adds_await() -> Callable[ + [Callable[Concatenate[_T, _P], None]], + Callable[Concatenate[_T, _P], Awaitable[None]], +]: + def decorator( + func: Callable[Concatenate[_T, _P], None], + ) -> Callable[Concatenate[_T, _P], Awaitable[None]]: + ... + + return decorator # we want `_T` and `_P` to refer to the same things. +[builtins fixtures/tuple.pyi] + +[case testParamSpecVariance] +from typing import Callable, Generic +from typing_extensions import ParamSpec + +_P = ParamSpec("_P") + +class Job(Generic[_P]): + def __init__(self, target: Callable[_P, None]) -> None: ... + def into_callable(self) -> Callable[_P, None]: ... + +class A: + def func(self, var: int) -> None: ... + def other_func(self, job: Job[[int]]) -> None: ... + + +job = Job(A().func) +reveal_type(job) # N: Revealed type is "__main__.Job[[var: builtins.int]]" +A().other_func(job) # This should NOT error (despite the keyword) + +# and yet the keyword should remain +job.into_callable()(var=42) +job.into_callable()(x=42) # E: Unexpected keyword argument "x" + +# similar for other functions +def f1(n: object) -> None: ... +def f2(n: int) -> None: ... +def f3(n: bool) -> None: ... + +# just like how this is legal... +a1: Callable[[bool], None] +a1 = f3 +a1 = f2 +a1 = f1 + +# ... this is also legal +a2: Job[[bool]] +a2 = Job(f3) +a2 = Job(f2) +a2 = Job(f1) + +# and this is not legal +def f4(n: bytes) -> None: ... +a1 = f4 # E: Incompatible types in assignment (expression has type "Callable[[bytes], None]", variable has type "Callable[[bool], None]") +a2 = Job(f4) # E: Argument 1 to "Job" has incompatible type "Callable[[bytes], None]"; expected "Callable[[bool], None]" + +# nor is this: +a4: Job[[int]] +a4 = Job(f3) # E: Argument 1 to "Job" has incompatible type "Callable[[bool], None]"; expected "Callable[[int], None]" +a4 = Job(f2) +a4 = Job(f1) + +# just like this: +a3: Callable[[int], None] +a3 = f3 # E: Incompatible types in assignment (expression has type "Callable[[bool], None]", variable has type "Callable[[int], None]") +a3 = f2 +a3 = f1 +[builtins fixtures/tuple.pyi] + +[case testGenericsInInferredParamspec] +from typing import Callable, TypeVar, Generic +from typing_extensions import ParamSpec + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +class Job(Generic[_P]): + def __init__(self, target: Callable[_P, None]) -> None: ... + def into_callable(self) -> Callable[_P, None]: ... + +def generic_f(x: _T) -> None: ... + +j = Job(generic_f) +reveal_type(j) # N: Revealed type is "__main__.Job[[x: _T`-1]]" + +jf = j.into_callable() +reveal_type(jf) # N: Revealed type is "def [_T] (x: _T`-1)" +reveal_type(jf(1)) # N: Revealed type is "None" +[builtins fixtures/tuple.pyi] + +[case testStackedConcatenateIsIllegal] +from typing_extensions import Concatenate, ParamSpec +from typing import Callable + +P = ParamSpec("P") + +def x(f: Callable[Concatenate[int, Concatenate[int, P]], None]) -> None: ... # E: Nested Concatenates are invalid +[builtins fixtures/tuple.pyi]