Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Can understand functools.total_ordering #7831

Merged
merged 4 commits into from
Apr 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def get_class_decorator_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
from mypy.plugins import attrs
from mypy.plugins import dataclasses
from mypy.plugins import functools

if fullname in attrs.attr_class_makers:
return attrs.attr_class_maker_callback
Expand All @@ -114,6 +115,9 @@ def get_class_decorator_hook(self, fullname: str
)
elif fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
elif fullname in functools.functools_total_ordering_makers:
return functools.functools_total_ordering_maker_callback

return None


Expand Down
105 changes: 105 additions & 0 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Plugin for supporting the functools standard library module."""
from typing import Dict, NamedTuple, Optional

import mypy.plugin
from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType


functools_total_ordering_makers = {
'functools.total_ordering',
}

_ORDERING_METHODS = {
'__lt__',
'__le__',
'__gt__',
'__ge__',
}


_MethodInfo = NamedTuple('_MethodInfo', [('is_static', bool), ('type', CallableType)])


def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext,
auto_attribs_default: bool = False) -> None:
"""Add dunder methods to classes decorated with functools.total_ordering."""
if ctx.api.options.python_version < (3,):
ctx.api.fail('"functools.total_ordering" is not supported in Python 2', ctx.reason)
return

comparison_methods = _analyze_class(ctx)
if not comparison_methods:
ctx.api.fail(
'No ordering operation defined when using "functools.total_ordering": < > <= >=',
ctx.reason)
return

# prefer __lt__ to __le__ to __gt__ to __ge__
root = max(comparison_methods, key=lambda k: (comparison_methods[k] is None, k))
root_method = comparison_methods[root]
if not root_method:
# None of the defined comparison methods can be analysed
return

other_type = _find_other_type(root_method)
bool_type = ctx.api.named_type('__builtins__.bool')
ret_type = bool_type # type: Type
if root_method.type.ret_type != ctx.api.named_type('__builtins__.bool'):
proper_ret_type = get_proper_type(root_method.type.ret_type)
if not (isinstance(proper_ret_type, UnboundType)
and proper_ret_type.name.split('.')[-1] == 'bool'):
ret_type = AnyType(TypeOfAny.implementation_artifact)
for additional_op in _ORDERING_METHODS:
# Either the method is not implemented
# or has an unknown signature that we can now extrapolate.
if not comparison_methods.get(additional_op):
args = [Argument(Var('other', other_type), other_type, None, ARG_POS)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If other_type is None, the argument type shoud probably be AnyType instead of None.

add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type)


def _find_other_type(method: _MethodInfo) -> Type:
"""Find the type of the ``other`` argument in a comparison method."""
first_arg_pos = 0 if method.is_static else 1
cur_pos_arg = 0
other_arg = None
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
if arg_kind in (ARG_POS, ARG_OPT):
if cur_pos_arg == first_arg_pos:
other_arg = arg_type
break

cur_pos_arg += 1
elif arg_kind != ARG_STAR2:
other_arg = arg_type
break

if other_arg is None:
return AnyType(TypeOfAny.implementation_artifact)

return other_arg


def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> Dict[str, Optional[_MethodInfo]]:
"""Analyze the class body, its parents, and return the comparison methods found."""
# Traverse the MRO and collect ordering methods.
comparison_methods = {} # type: Dict[str, Optional[_MethodInfo]]
# Skip object because total_ordering does not use methods from object
for cls in ctx.cls.info.mro[:-1]:
for name in _ORDERING_METHODS:
if name in cls.names and name not in comparison_methods:
node = cls.names[name].node
if isinstance(node, FuncItem) and isinstance(node.type, CallableType):
comparison_methods[name] = _MethodInfo(node.is_static, node.type)
continue

if isinstance(node, Var):
proper_type = get_proper_type(node.type)
if isinstance(proper_type, CallableType):
comparison_methods[name] = _MethodInfo(node.is_staticmethod, proper_type)
continue

comparison_methods[name] = None

return comparison_methods
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'check-parameter-specification.test',
'check-generic-alias.test',
'check-typeguard.test',
'check-functools.test',
]

# Tests that use Python 3.8-only AST features (like expression-scoped ignores):
Expand Down
109 changes: 109 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
[case testTotalOrderingEqLt]
from functools import total_ordering

@total_ordering
class Ord:
def __eq__(self, other: object) -> bool:
return False

def __lt__(self, other: "Ord") -> bool:
return False

reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() <= Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() > Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() >= Ord()) # N: Revealed type is "builtins.bool"

Ord() < 1 # E: Unsupported operand types for < ("Ord" and "int")
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
Ord() == 1
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add reveal_type(Ord() <= Ord()) or similar to verify the type of the operation is as expected. Maybe use a non-bool return type in some test case.

[builtins fixtures/ops.pyi]
[builtins fixtures/dict.pyi]

[case testTotalOrderingLambda]
from functools import total_ordering
from typing import Any, Callable

@total_ordering
class Ord:
__eq__: Callable[[Any, object], bool] = lambda self, other: False
__lt__: Callable[[Any, "Ord"], bool] = lambda self, other: False

reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() <= Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() > Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() >= Ord()) # N: Revealed type is "builtins.bool"

Ord() < 1 # E: Argument 1 has incompatible type "int"; expected "Ord"
Ord() <= 1 # E: Unsupported operand types for <= ("Ord" and "int")
Ord() == 1
Ord() > 1 # E: Unsupported operand types for > ("Ord" and "int")
Ord() >= 1 # E: Unsupported operand types for >= ("Ord" and "int")
[builtins fixtures/ops.pyi]
[builtins fixtures/dict.pyi]

[case testTotalOrderingNonCallable]
from functools import total_ordering

@total_ordering
class Ord(object):
def __eq__(self, other: object) -> bool:
return False

__lt__ = 5

Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord")
Ord() > Ord() # E: "int" not callable
Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord")

[builtins fixtures/ops.pyi]
[builtins fixtures/dict.pyi]

[case testTotalOrderingReturnNotBool]
from functools import total_ordering

@total_ordering
class Ord:
def __eq__(self, other: object) -> bool:
return False

def __lt__(self, other: "Ord") -> str:
return "blah"

reveal_type(Ord() < Ord()) # N: Revealed type is "builtins.str"
reveal_type(Ord() <= Ord()) # N: Revealed type is "Any"
reveal_type(Ord() == Ord()) # N: Revealed type is "builtins.bool"
reveal_type(Ord() > Ord()) # N: Revealed type is "Any"
reveal_type(Ord() >= Ord()) # N: Revealed type is "Any"

[builtins fixtures/ops.pyi]
[builtins fixtures/dict.pyi]

[case testTotalOrderingAllowsAny]
from functools import total_ordering

@total_ordering
class Ord:
def __eq__(self, other):
return False

def __gt__(self, other):
return False

reveal_type(Ord() < Ord()) # N: Revealed type is "Any"
Ord() <= Ord() # E: Unsupported left operand type for <= ("Ord")
reveal_type(Ord() == Ord()) # N: Revealed type is "Any"
reveal_type(Ord() > Ord()) # N: Revealed type is "Any"
Ord() >= Ord() # E: Unsupported left operand type for >= ("Ord")

Ord() < 1 # E: Unsupported left operand type for < ("Ord")
Ord() <= 1 # E: Unsupported left operand type for <= ("Ord")
Ord() == 1
Ord() > 1
Ord() >= 1 # E: Unsupported left operand type for >= ("Ord")
[builtins fixtures/ops.pyi]
[builtins fixtures/dict.pyi]