From d7ca2201b20a2cdf78e8d38ff9f9fed310f0dd25 Mon Sep 17 00:00:00 2001 From: KotlinIsland <65446343+kotlinisland@users.noreply.github.com> Date: Wed, 18 Dec 2024 01:37:19 +1000 Subject: [PATCH] variance modifiers --- CHANGELOG.md | 2 + docs/source/based_features.rst | 72 ++++++++++++++ mypy/checkmember.py | 6 +- mypy/expandtype.py | 86 +++++++++++++---- mypy/message_registry.py | 4 +- mypy/messages.py | 4 + mypy/plugins/proper_plugin.py | 1 + mypy/semanal.py | 36 +++++-- mypy/subtypes.py | 3 + mypy/type_visitor.py | 8 ++ mypy/typeanal.py | 40 ++++++++ mypy/types.py | 59 +++++++++++- .../unit/check-based-unsafe-variance.test | 4 +- .../unit/check-based-variance-modifiers.test | 95 +++++++++++++++++++ test-data/unit/check-functions.test | 8 +- test-data/unit/check-protocols.test | 4 +- test-data/unit/lib-stub/basedtyping.pyi | 7 +- 17 files changed, 398 insertions(+), 41 deletions(-) create mode 100644 test-data/unit/check-based-variance-modifiers.test diff --git a/CHANGELOG.md b/CHANGELOG.md index 68d1bae2a..2fcf169a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Basedmypy Changelog ## [Unreleased] +### Added +- explicit and use-site variance modifiers `In`/`Out`/`InOut` ## [2.9.0] ### Added diff --git a/docs/source/based_features.rst b/docs/source/based_features.rst index ffc6c1a00..25495ff2c 100644 --- a/docs/source/based_features.rst +++ b/docs/source/based_features.rst @@ -25,6 +25,78 @@ Using the ``&`` operator or ``basedtyping.Intersection`` you can denote intersec x.reset() x.add("first") + +Explicit Variance +----------------- + +it is frequently desirable to explicitly declare the variance of type parameters on types and classes. +but until dedicated syntax is added: + +.. code-block:: python + + from basedtyping import In, InOut, Out + + class Example[ + Contravariant: In, # In designates contravariant, as values can only pass 'into' the class + Invariant: InOut, # I nOut designates invariant, as values can pass both 'into' and 'out' of the class + Covariant: Out, # Out designates covariant, as the values can only pass 'out' of the class + ]: ... + +The same applies to type declarations: + +.. code-block:: python + + type Example[Contravariant: In, Invariant: InOut, Covariant: Out] = ... + +when a bound is supplied, it is provided as an argument to the variance modifier: + +.. code-block:: python + + class Example[T: Out[int]]: ... + + +Use-site Variance +----------------- + +use-site variance is a concept that can be used to modify an invariant type +parameter to be modified as covariant or contravariant + +given: + +.. code-block:: python + + def f(data: list[object]): # we can't use `Sequence[object]` because we need `clear` + for element in data: + print(element) + data.clear() + + a = [1, 2, 3] + f(a) # error: list[int] is incompatible with list[object] + +we can implement use-site variance here to make the api both type-safe and ergonomic: + +.. code-block:: python + + def f(data: list[Out[object]]): + for element in data: + print(element) + data.clear() + + a = [1, 2, 3] + f(a) # no error, list[int] is a valid subtype of the covariant list[out object] + +what makes this typesafe is that the usages of the type parameter in input positions +are replaced with `Never` (or output positions and the upper bound in the case of contravariance): + +.. code-block:: python + + class A[T: int | str]: + def f(self, t: T) -> T: ... + + A[Out[int]]().f # (t: Never) -> int + A[In[int]]().f # (t: int) -> int | str + + Type Joins ---------- diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 29b9961ec..db1c9a932 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -17,6 +17,8 @@ ARG_POS, ARG_STAR, ARG_STAR2, + CONTRAVARIANT, + COVARIANT, EXCLUDED_ENUM_ATTRIBUTES, SYMBOL_FUNCBASE_TYPES, Context, @@ -811,7 +813,9 @@ def analyze_var( mx.msg.cant_assign_to_classvar(name, mx.context) t = freshen_all_functions_type_vars(typ) t = expand_self_type_if_needed(t, mx, var, original_itype) - t = expand_type_by_instance(t, itype) + t = expand_type_by_instance( + t, itype, use_variance=CONTRAVARIANT if mx.is_lvalue else COVARIANT + ) freeze_all_type_vars(t) result = t typ = get_proper_type(typ) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index ff9d845ec..77a023ca4 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload +from contextlib import contextmanager +from typing import Final, Generator, Iterable, Mapping, Sequence, TypeVar, cast, overload -from mypy.nodes import ARG_STAR, FakeInfo, Var +from mypy.nodes import ARG_STAR, CONTRAVARIANT, COVARIANT, FakeInfo, Var, INVARIANT from mypy.state import state from mypy.types import ( ANY_STRATEGY, @@ -38,6 +39,7 @@ UninhabitedType, UnionType, UnpackType, + VarianceModifier, flatten_nested_unions, get_proper_type, split_with_prefix_and_suffix, @@ -53,37 +55,49 @@ @overload -def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType: ... +def expand_type( + typ: CallableType, env: Mapping[TypeVarId, Type], *, variance: int | None = ... +) -> CallableType: ... @overload -def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: ... +def expand_type( + typ: ProperType, env: Mapping[TypeVarId, Type], *, variance: int | None = ... +) -> ProperType: ... @overload -def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: ... +def expand_type( + typ: Type, env: Mapping[TypeVarId, Type], *, variance: int | None = ... +) -> Type: ... -def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: +def expand_type(typ: Type, env: Mapping[TypeVarId, Type], *, variance=None) -> Type: """Substitute any type variable references in a type given by a type environment. """ - return typ.accept(ExpandTypeVisitor(env)) + return typ.accept(ExpandTypeVisitor(env, variance=variance)) @overload -def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType: ... +def expand_type_by_instance( + typ: CallableType, instance: Instance, *, use_variance: int | None = ... +) -> CallableType: ... @overload -def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType: ... +def expand_type_by_instance( + typ: ProperType, instance: Instance, *, use_variance: int | None = ... +) -> ProperType: ... @overload -def expand_type_by_instance(typ: Type, instance: Instance) -> Type: ... +def expand_type_by_instance( + typ: Type, instance: Instance, *, use_variance: int | None = ... +) -> Type: ... -def expand_type_by_instance(typ: Type, instance: Instance) -> Type: +def expand_type_by_instance(typ: Type, instance: Instance, use_variance=None) -> Type: """Substitute type variables in type using values from an Instance. Type variables are considered to be bound by the class declaration.""" if not instance.args and not instance.type.has_type_var_tuple_type: @@ -108,12 +122,11 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type: else: tvars = tuple(instance.type.defn.type_vars) instance_args = instance.args - for binder, arg in zip(tvars, instance_args): assert isinstance(binder, TypeVarLikeType) variables[binder.id] = arg - return expand_type(typ, variables) + return expand_type(typ, variables, variance=use_variance) F = TypeVar("F", bound=FunctionLike) @@ -181,10 +194,28 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator): variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value - def __init__(self, variables: Mapping[TypeVarId, Type]) -> None: + def __init__( + self, variables: Mapping[TypeVarId, Type], *, variance: int | None = None + ) -> None: super().__init__() self.variables = variables self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {} + self.variance = variance + self.using_variance: int | None = None + + @contextmanager + def in_variance(self) -> Generator[None]: + using_variance = self.using_variance + self.using_variance = CONTRAVARIANT + yield + self.using_variance = using_variance + + @contextmanager + def out_variance(self) -> Generator[None]: + using_variance = self.using_variance + self.using_variance = COVARIANT + yield + self.using_variance = using_variance def visit_unbound_type(self, t: UnboundType) -> Type: return t @@ -238,6 +269,19 @@ def visit_type_var(self, t: TypeVarType) -> Type: if t.id.is_self(): t = t.copy_modified(upper_bound=t.upper_bound.accept(self)) repl = self.variables.get(t.id, t) + use_site_variance = repl.variance if isinstance(repl, VarianceModifier) else None + positional_variance = self.using_variance or self.variance + if ( + positional_variance is not None + and use_site_variance is not None + and use_site_variance is not INVARIANT + and positional_variance != use_site_variance + ): + repl = ( + t.upper_bound.accept(self) + if positional_variance == COVARIANT + else UninhabitedType() + ) if isinstance(repl, ProperType) and isinstance(repl, Instance): # TODO: do we really need to do this? # If I try to remove this special-casing ~40 tests fail on reveal_type(). @@ -414,10 +458,15 @@ def visit_callable_type(self, t: CallableType) -> CallableType: needs_normalization = True arg_types = self.interpolate_args_for_unpack(t, var_arg.typ) else: - arg_types = self.expand_types(t.arg_types) + with self.in_variance(): + arg_types = self.expand_types(t.arg_types) + with self.out_variance(): + ret_type = t.ret_type.accept(self) + if isinstance(ret_type, VarianceModifier): + ret_type = ret_type.value expanded = t.copy_modified( arg_types=arg_types, - ret_type=t.ret_type.accept(self), + ret_type=ret_type, type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)), type_is=(t.type_is.accept(self) if t.type_is is not None else None), ) @@ -538,7 +587,10 @@ def visit_typeguard_type(self, t: TypeGuardType) -> Type: def expand_types(self, types: Iterable[Type]) -> list[Type]: a: list[Type] = [] for t in types: - a.append(t.accept(self)) + typ = t.accept(self) + if isinstance(typ, VarianceModifier): + typ = typ.value + a.append(typ) return a diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 608a3783e..889972abb 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -113,11 +113,11 @@ def with_additional_msg(self, info: str) -> ErrorMessage: ) FORMAT_REQUIRES_MAPPING: Final = "Format requires a mapping" RETURN_TYPE_CANNOT_BE_CONTRAVARIANT: Final = ErrorMessage( - "This usage of this contravariant type variable is unsafe as a return type.", + "This usage of this contravariant type variable is unsafe as a return type", codes.UNSAFE_VARIANCE, ) FUNCTION_PARAMETER_CANNOT_BE_COVARIANT: Final = ErrorMessage( - "This usage of this covariant type variable is unsafe as an input parameter.", + "This usage of this covariant type variable is unsafe as an input parameter", codes.UNSAFE_VARIANCE, ) UNSAFE_VARIANCE_NOTE = ErrorMessage( diff --git a/mypy/messages.py b/mypy/messages.py index 319b6ca89..4e05ae21d 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -93,6 +93,7 @@ UninhabitedType, UnionType, UnpackType, + VarianceModifier, flatten_nested_unions, get_proper_type, get_proper_types, @@ -2676,6 +2677,9 @@ def format_literal_value(typ: LiteralType) -> str: type_str += f"[{format_list(typ.args)}]" return type_str + if isinstance(typ, VarianceModifier): + return typ.render(format) + # TODO: always mention type alias names in errors. typ = get_proper_type(typ) diff --git a/mypy/plugins/proper_plugin.py b/mypy/plugins/proper_plugin.py index f51685c80..2901f6ed6 100644 --- a/mypy/plugins/proper_plugin.py +++ b/mypy/plugins/proper_plugin.py @@ -107,6 +107,7 @@ def is_special_target(right: ProperType) -> bool: "mypy.types.DeletedType", "mypy.types.RequiredType", "mypy.types.ReadOnlyType", + "mypy.types.VarianceModifier", ): # Special case: these are not valid targets for a type alias and thus safe. # TODO: introduce a SyntheticType base to simplify this? diff --git a/mypy/semanal.py b/mypy/semanal.py index 2f34392c4..225f5f792 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -298,6 +298,7 @@ UnionType, UnpackType, UntypedType, + VarianceModifier, get_proper_type, get_proper_types, has_type_vars, @@ -1962,14 +1963,25 @@ def analyze_type_param( self, type_param: TypeParam, context: Context ) -> TypeVarLikeExpr | None: fullname = self.qualified_name(type_param.name) + variance = VARIANCE_NOT_READY + upper_bound = None if type_param.upper_bound: - upper_bound = self.anal_type(type_param.upper_bound, allow_placeholder=True) - # TODO: we should validate the upper bound is valid for a given kind. - if upper_bound is None: - # This and below copies special-casing for old-style type variables, that - # is equally necessary for new-style classes to break a vicious circle. - upper_bound = PlaceholderType(None, [], context.line) - else: + variance_or_bound = self.anal_type( + type_param.upper_bound, + allow_placeholder=True, + is_type_var_bound=isinstance(context, (ClassDef, TypeAliasStmt)), + ) + if isinstance(variance_or_bound, VarianceModifier): + variance = variance_or_bound.variance + upper_bound = variance_or_bound._value + else: + upper_bound = variance_or_bound + # TODO: we should validate the upper bound is valid for a given kind. + if upper_bound is None: + # This and below copies special-casing for old-style type variables, that + # is equally necessary for new-style classes to break a vicious circle. + upper_bound = PlaceholderType(None, [], context.line) + if upper_bound is None: if type_param.kind == TYPE_VAR_TUPLE_KIND: upper_bound = self.named_type("builtins.tuple", [self.object_type()]) else: @@ -2012,7 +2024,7 @@ def analyze_type_param( values=values, upper_bound=upper_bound, default=default, - variance=VARIANCE_NOT_READY, + variance=variance, is_new_style=True, line=context.line, ) @@ -6261,6 +6273,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> list[Type] | None: allow_param_spec_literals=has_param_spec, allow_unpack=allow_unpack, runtime=True, + nested=True, ) if analyzed is None: return None @@ -7545,6 +7558,7 @@ def type_analyzer( report_invalid_types: bool = True, prohibit_self_type: str | None = None, allow_type_any: bool = False, + is_type_var_bound=False, ) -> TypeAnalyser: if tvar_scope is None: tvar_scope = self.tvar_scope @@ -7564,6 +7578,7 @@ def type_analyzer( allow_unpack=allow_unpack, prohibit_self_type=prohibit_self_type, allow_type_any=allow_type_any, + is_type_var_bound=is_type_var_bound, ) tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic()) tpan.global_scope = not self.type and not self.function_stack @@ -7589,6 +7604,8 @@ def anal_type( prohibit_self_type: str | None = None, allow_type_any: bool = False, runtime: bool | None = None, + is_type_var_bound=False, + nested=False, ) -> Type | None: """Semantically analyze a type. @@ -7624,6 +7641,7 @@ def anal_type( report_invalid_types=report_invalid_types, prohibit_self_type=prohibit_self_type, allow_type_any=allow_type_any, + is_type_var_bound=is_type_var_bound, ) if not a.api.is_stub_file and runtime: a.always_allow_new_syntax = False @@ -7631,6 +7649,8 @@ def anal_type( a.always_allow_new_syntax = True if self.is_stub_file: a.always_allow_new_syntax = True + if nested: + a.nesting_level += 1 tag = self.track_incomplete_refs() typ = typ.accept(a) if self.found_incomplete_ref(tag): diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a0d2552bd..e45808a11 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -67,6 +67,7 @@ UnionType, UnpackType, UntypedType, + VarianceModifier, find_unpack_in_list, get_proper_type, is_named_instance, @@ -378,6 +379,8 @@ def check_type_parameter( p_left = get_proper_type(left) if isinstance(p_left, UninhabitedType) and p_left.ambiguous: variance = COVARIANT + if isinstance(right, VarianceModifier): + variance = right.variance # If variance hasn't been inferred yet, we are lenient and default to # covariance. This shouldn't happen often, but it's very difficult to # avoid these cases altogether. diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 044a1afd0..d4372f0d6 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -49,6 +49,7 @@ UninhabitedType, UnionType, UnpackType, + VarianceModifier, get_proper_type, ) @@ -87,6 +88,10 @@ def visit_erased_type(self, t: ErasedType) -> T: def visit_deleted_type(self, t: DeletedType) -> T: pass + def visit_variance_modifier(self, t: VarianceModifier) -> T: + assert t.value + return t.value.accept(self) + @abstractmethod def visit_type_var(self, t: TypeVarType) -> T: pass @@ -245,6 +250,9 @@ def visit_instance(self, t: Instance) -> Type: result.metadata = t.metadata return result + def visit_variance_modifier(self, t: VarianceModifier) -> Type: + return VarianceModifier(t.variance, t.value.accept(self)) + def visit_type_var(self, t: TypeVarType) -> Type: return t diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 9d6d0ec4d..d6f8f63b7 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -30,6 +30,9 @@ ARG_POS, ARG_STAR, ARG_STAR2, + CONTRAVARIANT, + COVARIANT, + INVARIANT, MISSING_FALLBACK, SYMBOL_FUNCBASE_TYPES, ArgKind, @@ -112,6 +115,7 @@ UnionType, UnpackType, UntypedType, + VarianceModifier, callable_with_ellipsis, find_unpack_in_list, flatten_nested_tuples, @@ -247,6 +251,7 @@ def __init__( allowed_alias_tvars: list[TypeVarLikeType] | None = None, allow_type_any: bool = False, alias_type_params_names: list[str] | None = None, + is_type_var_bound=False, ) -> None: self.api = api self.fail_func = api.fail @@ -294,6 +299,7 @@ def __init__( self.allow_type_any = allow_type_any self.allow_type_var_tuple = False self.allow_unpack = allow_unpack + self.is_type_var_bound = is_type_var_bound def lookup_qualified( self, name: str, ctx: Context, suppress_errors: bool = False @@ -534,6 +540,37 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) elif node.fullname in ("typing_extensions.Concatenate", "typing.Concatenate"): # We check the return type further up the stack for valid use locations return self.apply_concatenate_operator(t) + elif node.fullname in ("basedtyping.In", "basedtyping.Out", "basedtyping.InOut"): + if node.fullname == "basedtyping.In": + variance = CONTRAVARIANT + elif node.fullname == "basedtyping.Out": + variance = COVARIANT + elif node.fullname == "basedtyping.InOut": + variance = INVARIANT + else: + raise ValueError(node.fullname) + if not self.nesting_level and not self.is_type_var_bound: + self.fail("Top level use-site variance is invalid", t, code=codes.VALID_TYPE) + return AnyType(TypeOfAny.from_error) + if self.nesting_level and variance is INVARIANT: + self.fail("Use-site invariance is not supported", t, code=codes.VALID_TYPE) + return AnyType(TypeOfAny.from_error) + if not t.args and (not self.is_type_var_bound or self.nesting_level): + self.fail( + "Use-site variance modifiers must take a single argument", + t, + code=codes.VALID_TYPE, + ) + return AnyType(TypeOfAny.from_error) + if t.args: + if len(t.args) > 1: + self.fail( + "Use-site variance modifiers must only take a single argument", + t, + code=codes.VALID_TYPE, + ) + return VarianceModifier(variance, self.anal_type(t.args[0])) + return VarianceModifier(variance, None) else: return self.analyze_unbound_type_without_type_info(t, sym, defining_literal) else: # sym is None @@ -1126,6 +1163,9 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: # TODO: should we do something here? return t + def visit_variance_modifier(self, t: VarianceModifier) -> Type: + return t + def visit_type_var(self, t: TypeVarType) -> Type: return t diff --git a/mypy/types.py b/mypy/types.py index 6036065c4..5b5f4bb5c 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -32,6 +32,7 @@ ARG_STAR, ARG_STAR2, INVARIANT, + VARIANCE_NOT_READY, ArgKind, FakeInfo, FuncDef, @@ -39,7 +40,7 @@ ) from mypy.options import Options from mypy.state import state -from mypy.util import IdMapper +from mypy.util import IdMapper, safe T = TypeVar("T") @@ -335,9 +336,17 @@ def _expand_once(self) -> Type: assert isinstance(self.alias.target, Instance) # type: ignore[misc] return self.alias.target.copy_modified(args=self.args) + def apply_explicit_variance(type_var: TypeVarLikeType, substitute: Type) -> Type: + if isinstance(type_var, TypeVarType) and type_var.variance != VARIANCE_NOT_READY: + return VarianceModifier(type_var.variance, substitute) + return substitute + # TODO: this logic duplicates the one in expand_type_by_instance(). if self.alias.tvar_tuple_index is None: - mapping = {v.id: s for (v, s) in zip(self.alias.alias_tvars, self.args)} + mapping = { + v.id: apply_explicit_variance(v, s) + for (v, s) in zip(self.alias.alias_tvars, self.args) + } else: prefix = self.alias.tvar_tuple_index suffix = len(self.alias.alias_tvars) - self.alias.tvar_tuple_index - 1 @@ -348,7 +357,7 @@ def _expand_once(self) -> Type: for tvar, sub in zip( self.alias.alias_tvars[:prefix] + self.alias.alias_tvars[prefix + 1 :], start + end ): - mapping[tvar.id] = sub + mapping[tvar.id] = apply_explicit_variance(tvar, sub) new_tp = self.alias.target.accept(InstantiateAliasVisitor(mapping)) new_tp.accept(LocationSetter(self.line, self.column)) @@ -1028,6 +1037,43 @@ def deserialize(cls, data: JsonDict) -> UnboundType: ) +class VarianceModifier(Type): + # maybe it would make more sense for this to be a `ProperType` + # the current implementation is quite hacky + __slots__ = ("variance", "_value") + + def __init__(self, variance: int, value: Type | None, line: int = -1, column: int = -1): + super().__init__(line, column) + self.variance = variance + self._value = value + + def render(self, renderer: Callable[[Type], str]) -> str: + return f"{self.variance_keyword}{renderer(self.value)}" + + @property + def value(self) -> Type: + return safe(self._value) + + @property + def variance_keyword(self) -> str: + return {1: "out ", 2: "in "}.get(self.variance, "") + + def accept(self, visitor: TypeVisitor[T]) -> T: + return visitor.visit_variance_modifier(self) + + def serialize(self) -> JsonDict: + return { + ".class": "VarianceModifier", + "variance": self.variance, + "value": self.value.serialize(), + } + + @classmethod + def deserialize(cls, data: JsonDict) -> Self: + assert data[".class"] == "VarianceModifier" + return cls(data["variance"], deserialize_type(data["value"])) + + class CallableArgument(ProperType): """Represents a Arg(type, 'name') inside a Callable's type list. @@ -1569,6 +1615,7 @@ class Instance(ProperType): __slots__ = ( "type", "args", + "_variances", "invalid", "type_ref", "last_known_value", @@ -3521,6 +3568,9 @@ def get_proper_type(typ: Type | None) -> ProperType | None: return None if isinstance(typ, TypeGuardedType): # type: ignore[misc] typ = typ.type_guard + if isinstance(typ, VarianceModifier): + # this is quite hacky, it might be better as a ProperType + typ = typ.value while isinstance(typ, TypeAliasType): typ = typ._expand_once() # TODO: store the name of original type alias on this type, so we can show it in errors. @@ -3653,6 +3703,9 @@ def strip_builtins(s: str) -> str: return s.partition(".")[2] return s + def visit_variance_modifier(self, t: VarianceModifier) -> str: + return t.render(lambda x: x.accept(self)) + @contextlib.contextmanager def own_type_vars( self, type_vars: Sequence[TypeVarLikeType] | None diff --git a/test-data/unit/check-based-unsafe-variance.test b/test-data/unit/check-based-unsafe-variance.test index 1eac651a5..d5e7a7ec0 100644 --- a/test-data/unit/check-based-unsafe-variance.test +++ b/test-data/unit/check-based-unsafe-variance.test @@ -2,7 +2,7 @@ from helper import T_out from typing import Generic class G(Generic[T_out]): - def f(self, t: T_out): ... # E: This usage of this covariant type variable is unsafe as an input parameter. [unsafe-variance] \ + def f(self, t: T_out): ... # E: This usage of this covariant type variable is unsafe as an input parameter [unsafe-variance] \ # N: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored @@ -10,7 +10,7 @@ class G(Generic[T_out]): from helper import T_in from typing import Generic class G(Generic[T_in]): - def f(self) -> T_in: ... # E: This usage of this contravariant type variable is unsafe as a return type. [unsafe-variance] \ + def f(self) -> T_in: ... # E: This usage of this contravariant type variable is unsafe as a return type [unsafe-variance] \ # N: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored diff --git a/test-data/unit/check-based-variance-modifiers.test b/test-data/unit/check-based-variance-modifiers.test new file mode 100644 index 000000000..866b3f958 --- /dev/null +++ b/test-data/unit/check-based-variance-modifiers.test @@ -0,0 +1,95 @@ +[case testVarianceModifiers-3.12] +from basedtyping import In, Out, InOut + +class A[T: In[int]]: + def f(self, t: T) -> T: # E: This usage of this contravariant type variable is unsafe as a return type [unsafe-variance] \ + # N: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored + return t + +class B[T: Out[int]]: + def f(self, t: T) -> T: # E: This usage of this covariant type variable is unsafe as an input parameter [unsafe-variance] \ + # N: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored + return t + +class C[T: InOut[int]]: + def f(self, t: T) -> T: ... + +class D[T: InOut[int]]: + def f(self) -> T: ... +# check it doesn't infer +d1: D[int] +d2: D[object] = d1 # E: # E: Type argument "object" of "D" must be a subtype of "int" [type-var] \ + # E: Incompatible types in assignment (expression has type "D[int]", variable has type "D[object]") [assignment] + + +[case testVarianceModifiersBare-3.12] +from basedtyping import In, Out + +class A[T: Out]: ... +aint1: A[int] +aobject1: A[object] = aint1 +aobject2: A[object] +aint2: A[int] = aobject2 # E: erm + +class B[T]: ... +type BAlias[T: In] = B[T] +bint1: BAlias[int] +bobject1: BAlias[object] = bint1 # E: erm +bobject2: BAlias[object] +bint2: BAlias[int] = bobject2 +[typing fixtures/typing-full.pyi] +[builtins fixtures/tuple.pyi] + + +[case testVarianceModifiersErrors-3.12] +from basedtyping import In, Out, InOut +class A[T: list[In]]: ... # E: Use-site variance modifiers must take a single argument [valid-type] +class B[T: In[int, str]]: ... # E: Use-site variance modifiers must only take a single argument [valid-type] +def f[T: Out](): ... # E: Top level use-site variance is invalid [valid-type] + + +[case testUseSiteVarianceModifiersErrors] +from basedtyping import In, Out, InOut +from typing import List +a: List[In] # E: Use-site variance modifiers must take a single argument [valid-type] +b: List[InOut] # E: Use-site invariance is not supported [valid-type] +c: List[Out] # E: Use-site variance modifiers must take a single argument [valid-type] +d: Out[int] # E: Top level use-site variance is invalid [valid-type] + + +[case testUseSiteVarianceModifiers] +from __future__ import annotations +from basedtyping import In, Out, InOut +from helper import T +from typing import Generic + +class A(Generic[T]): + t: T + def do(self, value: T) -> T: return value +def f(out_a: A[Out[int | str]], in_a: A[In[bool]]): + reveal_type(out_a.do) # N: Revealed type is "_NamedCallable & (value: Never) -> int | str" + reveal_type(in_a.do) # N: Revealed type is "_NamedCallable & (value: bool) -> object" + reveal_type(out_a.t) # N: Revealed type is "int | str" + reveal_type(in_a.t) # N: Revealed type is "object" + out_a.t = True # E: Incompatible types in assignment (expression has type "bool", variable has type "Never") [assignment] + in_a.t = True # no error, takes bool + in_a.t = 1 # error, needs bool # E: Incompatible types in assignment (expression has type "int", variable has type "bool") [assignment] +a: A[int] +f(a, a) +b: A[None] +f(b, b) # E: Argument 1 to "f" has incompatible type "A[None]"; expected "A[out int | str]" [arg-type] \ + # E: Argument 2 to "f" has incompatible type "A[None]"; expected "A[in bool]" [arg-type] +[builtins fixtures/tuple.pyi] + + +[case testUseSiteAlias-3.12] +from basedtyping import Out +from helper import T + +OutList1 = list[Out[T]] +type OutList2[T2] = list[Out[T2]] +a: list[int] +b: OutList1[object] = a +c: OutList2[object] = a +[typing fixtures/typing-full.pyi] +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index c15265345..402ac8b35 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -2177,7 +2177,7 @@ class A(Generic[t]): return None [builtins fixtures/bool.pyi] [out] -main:5: error: This usage of this covariant type variable is unsafe as an input parameter. +main:5: error: This usage of this covariant type variable is unsafe as an input parameter main:5: note: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored [case testRejectCovariantArgumentSplitLine] @@ -2190,7 +2190,7 @@ class A(Generic[t]): return None [builtins fixtures/bool.pyi] [out] -main:6: error: This usage of this covariant type variable is unsafe as an input parameter. +main:6: error: This usage of this covariant type variable is unsafe as an input parameter main:6: note: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored [case testRejectCovariantArgumentInLambda] @@ -2204,7 +2204,7 @@ class Thing(Generic[t]): lambda _: None) [builtins fixtures/bool.pyi] [out] -main:8: error: This usage of this covariant type variable is unsafe as an input parameter. +main:8: error: This usage of this covariant type variable is unsafe as an input parameter main:8: note: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored [case testRejectCovariantArgumentInLambdaSplitLine] @@ -2220,7 +2220,7 @@ class A(Generic[t]): return None [builtins fixtures/bool.pyi] [out] -main:6: error: This usage of this contravariant type variable is unsafe as a return type. +main:6: error: This usage of this contravariant type variable is unsafe as a return type main:6: note: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored [case testAcceptCovariantReturnType] diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 9cb0845d5..6a9dd4709 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -843,10 +843,10 @@ T_co = TypeVar('T_co', covariant=True) T_contra = TypeVar('T_contra', contravariant=True) class Proto(Protocol[T_co, T_contra]): # type: ignore - def one(self, x: T_co) -> None: # E: This usage of this covariant type variable is unsafe as an input parameter. \ + def one(self, x: T_co) -> None: # E: This usage of this covariant type variable is unsafe as an input parameter \ # N: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored pass - def other(self) -> T_contra: # E: This usage of this contravariant type variable is unsafe as a return type. \ + def other(self) -> T_contra: # E: This usage of this contravariant type variable is unsafe as a return type \ # N: If you are using the value in a 'variance safe' way (not storing or retrieving values), this error could be ignored pass diff --git a/test-data/unit/lib-stub/basedtyping.pyi b/test-data/unit/lib-stub/basedtyping.pyi index 6433b360f..98ba5ca9b 100644 --- a/test-data/unit/lib-stub/basedtyping.pyi +++ b/test-data/unit/lib-stub/basedtyping.pyi @@ -5,5 +5,8 @@ # will slow down tests. Untyped = 0 -Intersection = 1 -FunctionType = 2 +Intersection = 0 +FunctionType = 0 +In = 0 +Out = 0 +InOut = 0