diff --git a/mypy/plugin.py b/mypy/plugin.py index 31bc335b069b..32140d9b3109 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -259,6 +259,11 @@ def fail(self, msg: str, ctx: Context, serious: bool = False, *, """Emit an error message at given location.""" raise NotImplementedError + @abstractmethod + def note(self, msg: str, ctx: Context, code: Optional[ErrorCode] = None) -> None: + """Emit a warning message at given location.""" + raise NotImplementedError + @abstractmethod def anal_type(self, t: Type, *, tvar_scope: Optional[TypeVarScope] = None, diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index ca9d3baad3bb..afb703e1f00c 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -90,6 +90,7 @@ def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: from mypy.plugins import attrs from mypy.plugins import dataclasses + from mypy.plugins import total_ordering if fullname in attrs.attr_class_makers: return attrs.attr_class_maker_callback @@ -100,6 +101,8 @@ def get_class_decorator_hook(self, fullname: str ) elif fullname in dataclasses.dataclass_makers: return dataclasses.dataclass_class_maker_callback + elif fullname == total_ordering.total_ordering_fullname: + return total_ordering.total_ordering_callback return None diff --git a/mypy/plugins/total_ordering.py b/mypy/plugins/total_ordering.py new file mode 100644 index 000000000000..fa67cbc07f17 --- /dev/null +++ b/mypy/plugins/total_ordering.py @@ -0,0 +1,59 @@ +from typing import Dict, Tuple +from typing_extensions import Final + +from mypy.nodes import (Argument, TypeVarExpr, SymbolTableNode, Var, ARG_POS, MDEF) +from mypy.plugin import ClassDefContext +from mypy.plugins.common import add_method +from mypy.types import (TypeVarDef, TypeVarType) + +total_ordering_fullname = "functools.total_ordering" + + +SELF_TVAR_NAME = '_AT' # type: Final + + +def _validate_total_ordering(ctx: ClassDefContext) -> None: + names = dict() # type: Dict[str, str] + for info in ctx.cls.info.mro: + for name in info.names: + if name not in names: + names[name] = info.defn.fullname + + if '__eq__' not in names: + ctx.api.fail("Classes with total_ordering must define __eq__", ctx.cls) + elif names['__eq__'] == "builtins.object": + ctx.api.note("Combining inherited object.__eq__ with total_ordering " + "is unlikely to be correct", ctx.cls) + if not ('__lt__' in names or '__le__' in names or + '__gt__' in names or '__ge__' in names): + ctx.api.fail("Classes with total_ordering must define one of " + "__{lt, gt, le, ge}__", ctx.cls) + + +def _create_typevar_on_class(ctx: ClassDefContext) -> Tuple[TypeVarDef, TypeVarType]: + object_type = ctx.api.named_type('__builtins__.object') + tvar_name = SELF_TVAR_NAME + tvar_fullname = ctx.cls.info.fullname() + '.' + SELF_TVAR_NAME + + tvd = TypeVarDef(tvar_name, tvar_fullname, -1, [], object_type) + tvd_type = TypeVarType(tvd) + + self_tvar_expr = TypeVarExpr(tvar_name, tvar_fullname, [], object_type) + ctx.cls.info.names[tvar_name] = SymbolTableNode(MDEF, self_tvar_expr) + + return tvd, tvd_type + + +def total_ordering_callback(ctx: ClassDefContext) -> None: + """Generate the missing ordering methods for this class.""" + _validate_total_ordering(ctx) + + bool_type = ctx.api.named_type('__builtins__.bool') + tvd, tvd_type = _create_typevar_on_class(ctx) + + args = [Argument(Var('other', tvd_type), tvd_type, None, ARG_POS)] + + existing_names = set(ctx.cls.info.names) + for method in ('__lt__', '__le__', '__gt__', '__ge__'): + if method not in existing_names: + add_method(ctx, method, args, bool_type, self_type=tvd_type, tvar_def=tvd) diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 5fd5405ec4e8..d6cf4f61c2cb 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -88,6 +88,7 @@ 'check-reports.test', 'check-errorcodes.test', 'check-annotated.test', + 'check-total-ordering.test', ] # Tests that use Python 3.8-only AST features (like expression-scoped ignores): diff --git a/test-data/unit/check-total-ordering.test b/test-data/unit/check-total-ordering.test new file mode 100644 index 000000000000..055812c03ae8 --- /dev/null +++ b/test-data/unit/check-total-ordering.test @@ -0,0 +1,52 @@ +[case testTotalOrderingInference] +from functools import total_ordering +from typing import Any + +@total_ordering +class Ord: + def __eq__(self, other: Any) -> bool: + return False + + def __lt__(self, other: "Ord") -> bool: + return False + +Ord() <= Ord() + +[builtins fixtures/dict.pyi] +[case testTotalOrderingObjectEq] +from functools import total_ordering +from typing import Any + +@total_ordering +class Ord: # N: Combining inherited object.__eq__ with total_ordering is unlikely to be correct + def __lt__(self, other: "Ord") -> bool: + return False + +[builtins fixtures/dict.pyi] +[case testTotalOrderingNoLt] +from functools import total_ordering +from typing import Any + +@total_ordering +class Ord: # E: Classes with total_ordering must define one of __{lt, gt, le, ge}__ + def __eq__(self, other: Any) -> bool: + return False + +[builtins fixtures/dict.pyi] + +[case testTotalOrderingInherited] +from functools import total_ordering +from typing import Any + +class Super: + def __eq__(self, other: Any) -> bool: + return False + +@total_ordering +class Ord(Super): + def __lt__(self, other: "Ord") -> bool: + return False + +Ord() <= Ord() + +[builtins fixtures/dict.pyi]