From b9cb7d54229ec144f09c5476e1e1dd543134fdee Mon Sep 17 00:00:00 2001 From: Aaron Ecay Date: Fri, 1 Nov 2019 15:33:20 +0000 Subject: [PATCH 1/6] Add plugin to typecheck functools.total_ordering --- mypy/plugins/default.py | 3 ++ mypy/plugins/total_ordering.py | 47 ++++++++++++++++++++++++ mypy/test/testcheck.py | 1 + test-data/unit/check-total-ordering.test | 35 ++++++++++++++++++ 4 files changed, 86 insertions(+) create mode 100644 mypy/plugins/total_ordering.py create mode 100644 test-data/unit/check-total-ordering.test diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index ca9d3baad3bb..afb703e1f00c 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -90,6 +90,7 @@ def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: from mypy.plugins import attrs from mypy.plugins import dataclasses + from mypy.plugins import total_ordering if fullname in attrs.attr_class_makers: return attrs.attr_class_maker_callback @@ -100,6 +101,8 @@ def get_class_decorator_hook(self, fullname: str ) elif fullname in dataclasses.dataclass_makers: return dataclasses.dataclass_class_maker_callback + elif fullname == total_ordering.total_ordering_fullname: + return total_ordering.total_ordering_callback return None diff --git a/mypy/plugins/total_ordering.py b/mypy/plugins/total_ordering.py new file mode 100644 index 000000000000..4e06caa3e7ee --- /dev/null +++ b/mypy/plugins/total_ordering.py @@ -0,0 +1,47 @@ +from mypy.nodes import (Argument, TypeVarExpr, SymbolTableNode, Var, ARG_POS, MDEF) +from mypy.plugin import ClassDefContext +from mypy.plugins.common import add_method +from mypy.types import (TypeVarDef, TypeVarType) + +total_ordering_fullname = "functools.total_ordering" + + +SELF_TVAR_NAME = '_AT' # type: Final + + +def _validate_total_ordering(ctx: ClassDefContext) -> None: + names = set(ctx.cls.info.names) + if '__eq__' not in names: + ctx.api.fail("Classes with total_ordering must define __eq__", ctx.cls) + if not ('__lt__' in names or '__le__' in names or + '__gt__' in names or '__ge__' in names): + ctx.api.fail("Classes with total_ordering must define one of __{lt, gt, le, ge}__", ctx.cls) + + +def _create_typevar_on_class(ctx: ClassDefContext) -> TypeVarType: + object_type = ctx.api.named_type('__builtins__.object') + tvar_name = SELF_TVAR_NAME + tvar_fullname = ctx.cls.info.fullname() + '.' + SELF_TVAR_NAME + + tvd = TypeVarDef(tvar_name, tvar_fullname, -1, [], object_type) + tvd_type = TypeVarType(tvd) + + self_tvar_expr = TypeVarExpr(tvar_name, tvar_fullname, [], object_type) + ctx.cls.info.names[tvar_name] = SymbolTableNode(MDEF, self_tvar_expr) + + return tvd, tvd_type + + +def total_ordering_callback(ctx: ClassDefContext) -> None: + """Generate the missing ordering methods for this class.""" + _validate_total_ordering(ctx) + + bool_type = ctx.api.named_type('__builtins__.bool') + tvd, tvd_type = _create_typevar_on_class(ctx) + + args = [Argument(Var('other', tvd_type), tvd_type, None, ARG_POS)] + + existing_names = set(ctx.cls.info.names) + for method in ('__lt__', '__le__', '__gt__', '__ge__'): + if method not in existing_names: + add_method(ctx, method, args, bool_type, self_type=tvd_type, tvar_def=tvd) diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 5fd5405ec4e8..d6cf4f61c2cb 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -88,6 +88,7 @@ 'check-reports.test', 'check-errorcodes.test', 'check-annotated.test', + 'check-total-ordering.test', ] # Tests that use Python 3.8-only AST features (like expression-scoped ignores): diff --git a/test-data/unit/check-total-ordering.test b/test-data/unit/check-total-ordering.test new file mode 100644 index 000000000000..d87b13034866 --- /dev/null +++ b/test-data/unit/check-total-ordering.test @@ -0,0 +1,35 @@ +[case testTotalOrderingInference] +from functools import total_ordering +from typing import Any + +@total_ordering +class Ord: + def __eq__(self, other: Any) -> bool: + return False + + def __lt__(self, other: "Ord") -> bool: + return False + +Ord() <= Ord() + +[builtins fixtures/dict.pyi] +[case testTotalOrderingNoEq] +from functools import total_ordering +from typing import Any + +@total_ordering +class Ord: # E: Classes with total_ordering must define __eq__ + def __lt__(self, other: "Ord") -> bool: + return False + +[builtins fixtures/dict.pyi] +[case testTotalOrderingNoLt] +from functools import total_ordering +from typing import Any + +@total_ordering +class Ord: # E: Classes with total_ordering must define one of __{lt, gt, le, ge}__ + def __eq__(self, other: Any) -> bool: + return False + +[builtins fixtures/dict.pyi] From 00ab9b3dfd20807f07a2e0cc50048ecfb783ccad Mon Sep 17 00:00:00 2001 From: Aaron Ecay Date: Fri, 1 Nov 2019 16:39:42 +0000 Subject: [PATCH 2/6] Count __eq__ etc methods from superclasses when validating total_ordering --- mypy/plugins/total_ordering.py | 5 ++++- test-data/unit/check-total-ordering.test | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/mypy/plugins/total_ordering.py b/mypy/plugins/total_ordering.py index 4e06caa3e7ee..649cadfd2c76 100644 --- a/mypy/plugins/total_ordering.py +++ b/mypy/plugins/total_ordering.py @@ -10,7 +10,10 @@ def _validate_total_ordering(ctx: ClassDefContext) -> None: - names = set(ctx.cls.info.names) + names = set() + for info in ctx.cls.info.mro: + names = names.union(info.names) + if '__eq__' not in names: ctx.api.fail("Classes with total_ordering must define __eq__", ctx.cls) if not ('__lt__' in names or '__le__' in names or diff --git a/test-data/unit/check-total-ordering.test b/test-data/unit/check-total-ordering.test index d87b13034866..aaf1412e23be 100644 --- a/test-data/unit/check-total-ordering.test +++ b/test-data/unit/check-total-ordering.test @@ -33,3 +33,20 @@ class Ord: # E: Classes with total_ordering must define one of __{lt, gt, le, g return False [builtins fixtures/dict.pyi] + +[case testTotalOrderingInherited] +from functools import total_ordering +from typing import Any + +class Super: + def __eq__(self, other: Any) -> bool: + return False + +@total_ordering +class Ord(Super): + def __lt__(self, other: "Ord") -> bool: + return False + +Ord() <= Ord() + +[builtins fixtures/dict.pyi] From e1f7142bf1672f007fd6eef7003b3062760b91fa Mon Sep 17 00:00:00 2001 From: Aaron Ecay Date: Sat, 2 Nov 2019 00:19:07 +0000 Subject: [PATCH 3/6] Warn if __eq__ is inherited from object Now that we consider inherited methods from superclasses, a class will never fail typechecking for lack of __eq__ (which is inherited from object). However, this method is unlikely to provide the correct semantics. --- mypy/plugins/total_ordering.py | 8 ++++++-- test-data/unit/check-total-ordering.test | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mypy/plugins/total_ordering.py b/mypy/plugins/total_ordering.py index 649cadfd2c76..fc0a97449607 100644 --- a/mypy/plugins/total_ordering.py +++ b/mypy/plugins/total_ordering.py @@ -10,12 +10,16 @@ def _validate_total_ordering(ctx: ClassDefContext) -> None: - names = set() + names = dict() for info in ctx.cls.info.mro: - names = names.union(info.names) + for name in info.names: + if name not in names: + names[name] = info.defn.fullname if '__eq__' not in names: ctx.api.fail("Classes with total_ordering must define __eq__", ctx.cls) + elif names['__eq__'] == "builtins.object": + ctx.api.note("Combining inherited object.__eq__ with total_ordering is unlikely to be correct", ctx.cls) if not ('__lt__' in names or '__le__' in names or '__gt__' in names or '__ge__' in names): ctx.api.fail("Classes with total_ordering must define one of __{lt, gt, le, ge}__", ctx.cls) diff --git a/test-data/unit/check-total-ordering.test b/test-data/unit/check-total-ordering.test index aaf1412e23be..055812c03ae8 100644 --- a/test-data/unit/check-total-ordering.test +++ b/test-data/unit/check-total-ordering.test @@ -13,12 +13,12 @@ class Ord: Ord() <= Ord() [builtins fixtures/dict.pyi] -[case testTotalOrderingNoEq] +[case testTotalOrderingObjectEq] from functools import total_ordering from typing import Any @total_ordering -class Ord: # E: Classes with total_ordering must define __eq__ +class Ord: # N: Combining inherited object.__eq__ with total_ordering is unlikely to be correct def __lt__(self, other: "Ord") -> bool: return False From d412c6471657a26845cb90b3f4eec7492d3748c2 Mon Sep 17 00:00:00 2001 From: Aaron Ecay Date: Sat, 2 Nov 2019 00:47:36 +0000 Subject: [PATCH 4/6] Make it typecheck --- mypy/plugin.py | 5 +++++ mypy/plugins/total_ordering.py | 7 +++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mypy/plugin.py b/mypy/plugin.py index 31bc335b069b..32140d9b3109 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -259,6 +259,11 @@ def fail(self, msg: str, ctx: Context, serious: bool = False, *, """Emit an error message at given location.""" raise NotImplementedError + @abstractmethod + def note(self, msg: str, ctx: Context, code: Optional[ErrorCode] = None) -> None: + """Emit a warning message at given location.""" + raise NotImplementedError + @abstractmethod def anal_type(self, t: Type, *, tvar_scope: Optional[TypeVarScope] = None, diff --git a/mypy/plugins/total_ordering.py b/mypy/plugins/total_ordering.py index fc0a97449607..dd62ea230f6d 100644 --- a/mypy/plugins/total_ordering.py +++ b/mypy/plugins/total_ordering.py @@ -1,3 +1,6 @@ +from typing import Dict, Tuple +from typing_extensions import Final + from mypy.nodes import (Argument, TypeVarExpr, SymbolTableNode, Var, ARG_POS, MDEF) from mypy.plugin import ClassDefContext from mypy.plugins.common import add_method @@ -10,7 +13,7 @@ def _validate_total_ordering(ctx: ClassDefContext) -> None: - names = dict() + names: Dict[str, str] = dict() for info in ctx.cls.info.mro: for name in info.names: if name not in names: @@ -25,7 +28,7 @@ def _validate_total_ordering(ctx: ClassDefContext) -> None: ctx.api.fail("Classes with total_ordering must define one of __{lt, gt, le, ge}__", ctx.cls) -def _create_typevar_on_class(ctx: ClassDefContext) -> TypeVarType: +def _create_typevar_on_class(ctx: ClassDefContext) -> Tuple[TypeVarDef, TypeVarType]: object_type = ctx.api.named_type('__builtins__.object') tvar_name = SELF_TVAR_NAME tvar_fullname = ctx.cls.info.fullname() + '.' + SELF_TVAR_NAME From 69d0b9ad44f11efb496074211b6afaf63a8ae535 Mon Sep 17 00:00:00 2001 From: Aaron Ecay Date: Sat, 2 Nov 2019 00:48:37 +0000 Subject: [PATCH 5/6] Fix flake8 warnings --- mypy/plugins/total_ordering.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mypy/plugins/total_ordering.py b/mypy/plugins/total_ordering.py index dd62ea230f6d..4491d790bbf5 100644 --- a/mypy/plugins/total_ordering.py +++ b/mypy/plugins/total_ordering.py @@ -22,10 +22,12 @@ def _validate_total_ordering(ctx: ClassDefContext) -> None: if '__eq__' not in names: ctx.api.fail("Classes with total_ordering must define __eq__", ctx.cls) elif names['__eq__'] == "builtins.object": - ctx.api.note("Combining inherited object.__eq__ with total_ordering is unlikely to be correct", ctx.cls) + ctx.api.note("Combining inherited object.__eq__ with total_ordering " + "is unlikely to be correct", ctx.cls) if not ('__lt__' in names or '__le__' in names or '__gt__' in names or '__ge__' in names): - ctx.api.fail("Classes with total_ordering must define one of __{lt, gt, le, ge}__", ctx.cls) + ctx.api.fail("Classes with total_ordering must define one of " + "__{lt, gt, le, ge}__", ctx.cls) def _create_typevar_on_class(ctx: ClassDefContext) -> Tuple[TypeVarDef, TypeVarType]: From 1fa4248441f68622d71a1cc736aa636cfaa8ba2f Mon Sep 17 00:00:00 2001 From: Aaron Ecay Date: Sat, 2 Nov 2019 14:29:25 +0000 Subject: [PATCH 6/6] Use type comment for 3.5 compatibility --- mypy/plugins/total_ordering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/plugins/total_ordering.py b/mypy/plugins/total_ordering.py index 4491d790bbf5..fa67cbc07f17 100644 --- a/mypy/plugins/total_ordering.py +++ b/mypy/plugins/total_ordering.py @@ -13,7 +13,7 @@ def _validate_total_ordering(ctx: ClassDefContext) -> None: - names: Dict[str, str] = dict() + names = dict() # type: Dict[str, str] for info in ctx.cls.info.mro: for name in info.names: if name not in names: