Skip to content

Commit

Permalink
Better handling of generic functions in partial plugin (#17925)
Browse files Browse the repository at this point in the history
Fixes #17411

The fix is that we remove type variables that can never be inferred from
the initial `check_call()` call. Actual diff is tiny, I just moved a
bunch of code, since I need formal to actual mapping sooner now.
  • Loading branch information
ilevkivskyi authored Oct 14, 2024
1 parent c32d11e commit 1a074b6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 25 deletions.
62 changes: 39 additions & 23 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Overloaded,
Type,
TypeOfAny,
TypeVarType,
UnboundType,
UnionType,
get_proper_type,
Expand Down Expand Up @@ -164,21 +166,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
ctx.api.type_context[-1] = None
wrapped_return = False

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
],
ret_type=ret_type,
)
if defaulted.line < 0:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

# Flatten actual to formal mapping, since this is what check_call() expects.
actual_args = []
actual_arg_kinds = []
Expand All @@ -199,6 +186,43 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
actual_arg_names.append(ctx.arg_names[i][j])
actual_types.append(ctx.arg_types[i][j])

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

# We need to remove any type variables that appear only in formals that have
# no actuals, to avoid eagerly binding them in check_call() below.
can_infer_ids = set()
for i, arg_type in enumerate(fn_type.arg_types):
if not formal_to_actual[i]:
continue
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
],
ret_type=ret_type,
variables=[
tv
for tv in fn_type.variables
# Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
],
)
if defaulted.line < 0:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

# Create a valid context for various ad-hoc inspections in check_call().
call_expr = CallExpr(
callee=ctx.args[0][0],
Expand Down Expand Up @@ -231,14 +255,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
return ctx.default_return_type
bound = bound.copy_modified(ret_type=ret_type.args[0])

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

partial_kinds = []
partial_types = []
partial_names = []
Expand Down
31 changes: 29 additions & 2 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,6 @@ def bar(f: S) -> S:
return f
[builtins fixtures/primitives.pyi]


[case testFunctoolsPartialAbstractType]
# flags: --python-version 3.9
from abc import ABC, abstractmethod
Expand All @@ -597,7 +596,6 @@ def f2() -> None:
partial_cls() # E: Cannot instantiate abstract class "A" with abstract attribute "method"
[builtins fixtures/tuple.pyi]


[case testFunctoolsPartialSelfType]
from functools import partial
from typing_extensions import Self
Expand All @@ -610,3 +608,32 @@ class A:
factory = partial(cls, ts=0)
return factory(msg=msg)
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialTypeVarValues]
from functools import partial
from typing import TypeVar

T = TypeVar("T", int, str)

def f(x: int, y: T) -> T:
return y

def g(x: T, y: int) -> T:
return x

def h(x: T, y: T) -> T:
return x

fp = partial(f, 1)
reveal_type(fp(1)) # N: Revealed type is "builtins.int"
reveal_type(fp("a")) # N: Revealed type is "builtins.str"
fp(object()) # E: Value of type variable "T" of "f" cannot be "object"

gp = partial(g, 1)
reveal_type(gp(1)) # N: Revealed type is "builtins.int"
gp("a") # E: Argument 1 to "g" has incompatible type "str"; expected "int"

hp = partial(h, 1)
reveal_type(hp(1)) # N: Revealed type is "builtins.int"
hp("a") # E: Argument 1 to "h" has incompatible type "str"; expected "int"
[builtins fixtures/tuple.pyi]

0 comments on commit 1a074b6

Please sign in to comment.