From 604b08d5ba7c1b6e3d2f2ddd50dcf020f7e2794a Mon Sep 17 00:00:00 2001 From: Anton Agestam Date: Sat, 21 Sep 2024 20:13:41 +0200 Subject: [PATCH] Use get_protocol_members in protocol checking (#490) This changes `check_protocol()` to make use of `get_protocol_members` from typing-extensions. This allows removing an existing hard-coded exclusion list for attributes existing on Protocol, but also handles the cases `__orig_bases__` and `__weakref__` that was breaking when checking intersecting protocols (a subclass of two or more protocols). This has the effect of turning some false positives into true negatives, but it also leaves some false negatives. To make that clear, xfail test cases are added for the resulting false negatives. --- docs/versionhistory.rst | 5 +++ src/typeguard/_checkers.py | 22 ++++------ tests/test_checkers.py | 83 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 14 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index afedbaa..64e33d8 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -4,6 +4,11 @@ Version history This library adheres to `Semantic Versioning 2.0 `_. +**UNRELEASED** + +- Fixed basic support for intersection protocols + (`#490 `_; PR by @antonagestam) + **4.3.0** (2024-05-27) - Added support for checking against static protocols diff --git a/src/typeguard/_checkers.py b/src/typeguard/_checkers.py index 485bcb7..52ec2b8 100644 --- a/src/typeguard/_checkers.py +++ b/src/typeguard/_checkers.py @@ -654,19 +654,13 @@ def check_protocol( else: return - # Collect a set of methods and non-method attributes present in the protocol - ignored_attrs = set(dir(typing.Protocol)) | { - "__annotations__", - "__non_callable_proto_members__", - } expected_methods: dict[str, tuple[Any, Any]] = {} expected_noncallable_members: dict[str, Any] = {} - for attrname in dir(origin_type): - # Skip attributes present in typing.Protocol - if attrname in ignored_attrs: - continue + origin_annotations = typing.get_type_hints(origin_type) + + for attrname in typing_extensions.get_protocol_members(origin_type): + member = getattr(origin_type, attrname, None) - member = getattr(origin_type, attrname) if callable(member): signature = inspect.signature(member) argtypes = [ @@ -681,10 +675,10 @@ def check_protocol( ) expected_methods[attrname] = argtypes, return_annotation else: - expected_noncallable_members[attrname] = member - - for attrname, annotation in typing.get_type_hints(origin_type).items(): - expected_noncallable_members[attrname] = annotation + try: + expected_noncallable_members[attrname] = origin_annotations[attrname] + except KeyError: + expected_noncallable_members[attrname] = member subject_annotations = typing.get_type_hints(subject) diff --git a/tests/test_checkers.py b/tests/test_checkers.py index f8b21d6..d9237a9 100644 --- a/tests/test_checkers.py +++ b/tests/test_checkers.py @@ -16,14 +16,17 @@ Dict, ForwardRef, FrozenSet, + Iterable, Iterator, List, Literal, Mapping, MutableMapping, Optional, + Protocol, Sequence, Set, + Sized, TextIO, Tuple, Type, @@ -995,6 +998,86 @@ def test_text_real_file(self, tmp_path: Path): check_type(f, TextIO) +class TestIntersectingProtocol: + SIT = TypeVar("SIT", covariant=True) + + class SizedIterable( + Sized, + Iterable[SIT], + Protocol[SIT], + ): ... + + @pytest.mark.parametrize( + "subject, predicate_type", + ( + pytest.param( + (), + SizedIterable, + id="empty_tuple_unspecialized", + ), + pytest.param( + range(2), + SizedIterable, + id="range", + ), + pytest.param( + (), + SizedIterable[int], + id="empty_tuple_int_specialized", + ), + pytest.param( + (1, 2, 3), + SizedIterable[int], + id="tuple_int_specialized", + ), + pytest.param( + ("1", "2", "3"), + SizedIterable[str], + id="tuple_str_specialized", + ), + ), + ) + def test_valid_member_passes(self, subject: object, predicate_type: type) -> None: + for _ in range(2): # Makes sure that the cache is also exercised + check_type(subject, predicate_type) + + xfail_nested_protocol_checks = pytest.mark.xfail( + reason="false negative due to missing support for nested protocol checks", + ) + + @pytest.mark.parametrize( + "subject, predicate_type", + ( + pytest.param( + (1 for _ in ()), + SizedIterable, + id="generator", + ), + pytest.param( + range(2), + SizedIterable[str], + marks=xfail_nested_protocol_checks, + id="range_str_specialized", + ), + pytest.param( + (1, 2, 3), + SizedIterable[str], + marks=xfail_nested_protocol_checks, + id="int_tuple_str_specialized", + ), + pytest.param( + ("1", "2", "3"), + SizedIterable[int], + marks=xfail_nested_protocol_checks, + id="str_tuple_int_specialized", + ), + ), + ) + def test_raises_for_non_member(self, subject: object, predicate_type: type) -> None: + with pytest.raises(TypeCheckError): + check_type(subject, predicate_type) + + @pytest.mark.parametrize( "instantiate, annotation", [