Skip to content

Commit

Permalink
variance modifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Jan 2, 2025
1 parent 8cd497f commit d7ca220
Show file tree
Hide file tree
Showing 17 changed files with 398 additions and 41 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Basedmypy Changelog

## [Unreleased]
### Added
- explicit and use-site variance modifiers `In`/`Out`/`InOut`

## [2.9.0]
### Added
Expand Down
72 changes: 72 additions & 0 deletions docs/source/based_features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,78 @@ Using the ``&`` operator or ``basedtyping.Intersection`` you can denote intersec
x.reset()
x.add("first")
Explicit Variance
-----------------

it is frequently desirable to explicitly declare the variance of type parameters on types and classes.
but until dedicated syntax is added:

.. code-block:: python
from basedtyping import In, InOut, Out
class Example[
Contravariant: In, # In designates contravariant, as values can only pass 'into' the class
Invariant: InOut, # I nOut designates invariant, as values can pass both 'into' and 'out' of the class
Covariant: Out, # Out designates covariant, as the values can only pass 'out' of the class
]: ...
The same applies to type declarations:

.. code-block:: python
type Example[Contravariant: In, Invariant: InOut, Covariant: Out] = ...
when a bound is supplied, it is provided as an argument to the variance modifier:

.. code-block:: python
class Example[T: Out[int]]: ...
Use-site Variance
-----------------

use-site variance is a concept that can be used to modify an invariant type
parameter to be modified as covariant or contravariant

given:

.. code-block:: python
def f(data: list[object]): # we can't use `Sequence[object]` because we need `clear`
for element in data:
print(element)
data.clear()
a = [1, 2, 3]
f(a) # error: list[int] is incompatible with list[object]
we can implement use-site variance here to make the api both type-safe and ergonomic:

.. code-block:: python
def f(data: list[Out[object]]):
for element in data:
print(element)
data.clear()
a = [1, 2, 3]
f(a) # no error, list[int] is a valid subtype of the covariant list[out object]
what makes this typesafe is that the usages of the type parameter in input positions
are replaced with `Never` (or output positions and the upper bound in the case of contravariance):

.. code-block:: python
class A[T: int | str]:
def f(self, t: T) -> T: ...
A[Out[int]]().f # (t: Never) -> int
A[In[int]]().f # (t: int) -> int | str
Type Joins
----------

Expand Down
6 changes: 5 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
ARG_POS,
ARG_STAR,
ARG_STAR2,
CONTRAVARIANT,
COVARIANT,
EXCLUDED_ENUM_ATTRIBUTES,
SYMBOL_FUNCBASE_TYPES,
Context,
Expand Down Expand Up @@ -811,7 +813,9 @@ def analyze_var(
mx.msg.cant_assign_to_classvar(name, mx.context)
t = freshen_all_functions_type_vars(typ)
t = expand_self_type_if_needed(t, mx, var, original_itype)
t = expand_type_by_instance(t, itype)
t = expand_type_by_instance(
t, itype, use_variance=CONTRAVARIANT if mx.is_lvalue else COVARIANT
)
freeze_all_type_vars(t)
result = t
typ = get_proper_type(typ)
Expand Down
86 changes: 69 additions & 17 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload
from contextlib import contextmanager
from typing import Final, Generator, Iterable, Mapping, Sequence, TypeVar, cast, overload

from mypy.nodes import ARG_STAR, FakeInfo, Var
from mypy.nodes import ARG_STAR, CONTRAVARIANT, COVARIANT, FakeInfo, Var, INVARIANT
from mypy.state import state
from mypy.types import (
ANY_STRATEGY,
Expand Down Expand Up @@ -38,6 +39,7 @@
UninhabitedType,
UnionType,
UnpackType,
VarianceModifier,
flatten_nested_unions,
get_proper_type,
split_with_prefix_and_suffix,
Expand All @@ -53,37 +55,49 @@


@overload
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType: ...
def expand_type(
typ: CallableType, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
) -> CallableType: ...


@overload
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: ...
def expand_type(
typ: ProperType, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
) -> ProperType: ...


@overload
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: ...
def expand_type(
typ: Type, env: Mapping[TypeVarId, Type], *, variance: int | None = ...
) -> Type: ...


def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
def expand_type(typ: Type, env: Mapping[TypeVarId, Type], *, variance=None) -> Type:
"""Substitute any type variable references in a type given by a type
environment.
"""
return typ.accept(ExpandTypeVisitor(env))
return typ.accept(ExpandTypeVisitor(env, variance=variance))


@overload
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType: ...
def expand_type_by_instance(
typ: CallableType, instance: Instance, *, use_variance: int | None = ...
) -> CallableType: ...


@overload
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType: ...
def expand_type_by_instance(
typ: ProperType, instance: Instance, *, use_variance: int | None = ...
) -> ProperType: ...


@overload
def expand_type_by_instance(typ: Type, instance: Instance) -> Type: ...
def expand_type_by_instance(
typ: Type, instance: Instance, *, use_variance: int | None = ...
) -> Type: ...


def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
def expand_type_by_instance(typ: Type, instance: Instance, use_variance=None) -> Type:
"""Substitute type variables in type using values from an Instance.
Type variables are considered to be bound by the class declaration."""
if not instance.args and not instance.type.has_type_var_tuple_type:
Expand All @@ -108,12 +122,11 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
else:
tvars = tuple(instance.type.defn.type_vars)
instance_args = instance.args

for binder, arg in zip(tvars, instance_args):
assert isinstance(binder, TypeVarLikeType)
variables[binder.id] = arg

return expand_type(typ, variables)
return expand_type(typ, variables, variance=use_variance)


F = TypeVar("F", bound=FunctionLike)
Expand Down Expand Up @@ -181,10 +194,28 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):

variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
def __init__(
self, variables: Mapping[TypeVarId, Type], *, variance: int | None = None
) -> None:
super().__init__()
self.variables = variables
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
self.variance = variance
self.using_variance: int | None = None

@contextmanager
def in_variance(self) -> Generator[None]:
using_variance = self.using_variance
self.using_variance = CONTRAVARIANT
yield
self.using_variance = using_variance

@contextmanager
def out_variance(self) -> Generator[None]:
using_variance = self.using_variance
self.using_variance = COVARIANT
yield
self.using_variance = using_variance

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
Expand Down Expand Up @@ -238,6 +269,19 @@ def visit_type_var(self, t: TypeVarType) -> Type:
if t.id.is_self():
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
repl = self.variables.get(t.id, t)
use_site_variance = repl.variance if isinstance(repl, VarianceModifier) else None
positional_variance = self.using_variance or self.variance
if (
positional_variance is not None
and use_site_variance is not None
and use_site_variance is not INVARIANT
and positional_variance != use_site_variance
):
repl = (
t.upper_bound.accept(self)
if positional_variance == COVARIANT
else UninhabitedType()
)
if isinstance(repl, ProperType) and isinstance(repl, Instance):
# TODO: do we really need to do this?
# If I try to remove this special-casing ~40 tests fail on reveal_type().
Expand Down Expand Up @@ -414,10 +458,15 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
needs_normalization = True
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
else:
arg_types = self.expand_types(t.arg_types)
with self.in_variance():
arg_types = self.expand_types(t.arg_types)
with self.out_variance():
ret_type = t.ret_type.accept(self)
if isinstance(ret_type, VarianceModifier):
ret_type = ret_type.value
expanded = t.copy_modified(
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
ret_type=ret_type,
type_guard=t.type_guard and cast(TypeGuardType, t.type_guard.accept(self)),
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
)
Expand Down Expand Up @@ -538,7 +587,10 @@ def visit_typeguard_type(self, t: TypeGuardType) -> Type:
def expand_types(self, types: Iterable[Type]) -> list[Type]:
a: list[Type] = []
for t in types:
a.append(t.accept(self))
typ = t.accept(self)
if isinstance(typ, VarianceModifier):
typ = typ.value
a.append(typ)
return a


Expand Down
4 changes: 2 additions & 2 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
)
FORMAT_REQUIRES_MAPPING: Final = "Format requires a mapping"
RETURN_TYPE_CANNOT_BE_CONTRAVARIANT: Final = ErrorMessage(
"This usage of this contravariant type variable is unsafe as a return type.",
"This usage of this contravariant type variable is unsafe as a return type",
codes.UNSAFE_VARIANCE,
)
FUNCTION_PARAMETER_CANNOT_BE_COVARIANT: Final = ErrorMessage(
"This usage of this covariant type variable is unsafe as an input parameter.",
"This usage of this covariant type variable is unsafe as an input parameter",
codes.UNSAFE_VARIANCE,
)
UNSAFE_VARIANCE_NOTE = ErrorMessage(
Expand Down
4 changes: 4 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
UninhabitedType,
UnionType,
UnpackType,
VarianceModifier,
flatten_nested_unions,
get_proper_type,
get_proper_types,
Expand Down Expand Up @@ -2676,6 +2677,9 @@ def format_literal_value(typ: LiteralType) -> str:
type_str += f"[{format_list(typ.args)}]"
return type_str

if isinstance(typ, VarianceModifier):
return typ.render(format)

# TODO: always mention type alias names in errors.
typ = get_proper_type(typ)

Expand Down
1 change: 1 addition & 0 deletions mypy/plugins/proper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def is_special_target(right: ProperType) -> bool:
"mypy.types.DeletedType",
"mypy.types.RequiredType",
"mypy.types.ReadOnlyType",
"mypy.types.VarianceModifier",
):
# Special case: these are not valid targets for a type alias and thus safe.
# TODO: introduce a SyntheticType base to simplify this?
Expand Down
Loading

0 comments on commit d7ca220

Please sign in to comment.