Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plugin to typecheck functools.total_ordering #7848

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ def fail(self, msg: str, ctx: Context, serious: bool = False, *,
"""Emit an error message at given location."""
raise NotImplementedError

@abstractmethod
def note(self, msg: str, ctx: Context, code: Optional[ErrorCode] = None) -> None:
"""Emit a warning message at given location."""
raise NotImplementedError

@abstractmethod
def anal_type(self, t: Type, *,
tvar_scope: Optional[TypeVarScope] = None,
Expand Down
3 changes: 3 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def get_class_decorator_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
from mypy.plugins import attrs
from mypy.plugins import dataclasses
from mypy.plugins import total_ordering

if fullname in attrs.attr_class_makers:
return attrs.attr_class_maker_callback
Expand All @@ -100,6 +101,8 @@ def get_class_decorator_hook(self, fullname: str
)
elif fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
elif fullname == total_ordering.total_ordering_fullname:
return total_ordering.total_ordering_callback
return None


Expand Down
59 changes: 59 additions & 0 deletions mypy/plugins/total_ordering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Dict, Tuple
from typing_extensions import Final

from mypy.nodes import (Argument, TypeVarExpr, SymbolTableNode, Var, ARG_POS, MDEF)
from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method
from mypy.types import (TypeVarDef, TypeVarType)

total_ordering_fullname = "functools.total_ordering"


SELF_TVAR_NAME = '_AT' # type: Final


def _validate_total_ordering(ctx: ClassDefContext) -> None:
names = dict() # type: Dict[str, str]
for info in ctx.cls.info.mro:
for name in info.names:
if name not in names:
names[name] = info.defn.fullname

if '__eq__' not in names:
ctx.api.fail("Classes with total_ordering must define __eq__", ctx.cls)
elif names['__eq__'] == "builtins.object":
ctx.api.note("Combining inherited object.__eq__ with total_ordering "
"is unlikely to be correct", ctx.cls)
if not ('__lt__' in names or '__le__' in names or
'__gt__' in names or '__ge__' in names):
ctx.api.fail("Classes with total_ordering must define one of "
"__{lt, gt, le, ge}__", ctx.cls)


def _create_typevar_on_class(ctx: ClassDefContext) -> Tuple[TypeVarDef, TypeVarType]:
object_type = ctx.api.named_type('__builtins__.object')
tvar_name = SELF_TVAR_NAME
tvar_fullname = ctx.cls.info.fullname() + '.' + SELF_TVAR_NAME

tvd = TypeVarDef(tvar_name, tvar_fullname, -1, [], object_type)
tvd_type = TypeVarType(tvd)

self_tvar_expr = TypeVarExpr(tvar_name, tvar_fullname, [], object_type)
ctx.cls.info.names[tvar_name] = SymbolTableNode(MDEF, self_tvar_expr)

return tvd, tvd_type


def total_ordering_callback(ctx: ClassDefContext) -> None:
"""Generate the missing ordering methods for this class."""
_validate_total_ordering(ctx)

bool_type = ctx.api.named_type('__builtins__.bool')
tvd, tvd_type = _create_typevar_on_class(ctx)

args = [Argument(Var('other', tvd_type), tvd_type, None, ARG_POS)]

existing_names = set(ctx.cls.info.names)
for method in ('__lt__', '__le__', '__gt__', '__ge__'):
if method not in existing_names:
add_method(ctx, method, args, bool_type, self_type=tvd_type, tvar_def=tvd)
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
'check-reports.test',
'check-errorcodes.test',
'check-annotated.test',
'check-total-ordering.test',
]

# Tests that use Python 3.8-only AST features (like expression-scoped ignores):
Expand Down
52 changes: 52 additions & 0 deletions test-data/unit/check-total-ordering.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
[case testTotalOrderingInference]
from functools import total_ordering
from typing import Any

@total_ordering
class Ord:
def __eq__(self, other: Any) -> bool:
return False

def __lt__(self, other: "Ord") -> bool:
return False

Ord() <= Ord()

[builtins fixtures/dict.pyi]
[case testTotalOrderingObjectEq]
from functools import total_ordering
from typing import Any

@total_ordering
class Ord: # N: Combining inherited object.__eq__ with total_ordering is unlikely to be correct
def __lt__(self, other: "Ord") -> bool:
return False

[builtins fixtures/dict.pyi]
[case testTotalOrderingNoLt]
from functools import total_ordering
from typing import Any

@total_ordering
class Ord: # E: Classes with total_ordering must define one of __{lt, gt, le, ge}__
def __eq__(self, other: Any) -> bool:
return False

[builtins fixtures/dict.pyi]

[case testTotalOrderingInherited]
from functools import total_ordering
from typing import Any

class Super:
def __eq__(self, other: Any) -> bool:
return False

@total_ordering
class Ord(Super):
def __lt__(self, other: "Ord") -> bool:
return False

Ord() <= Ord()

[builtins fixtures/dict.pyi]