diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 67cb2b8c3f03..b2960ba0c7b7 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -48,6 +48,7 @@ _CONVERTED_COLLECTIONS = [ collections.abc.Set, collections.abc.MutableSet, + collections.abc.Collection, ] @@ -118,6 +119,10 @@ def _match_is_exactly_iterable(user_type): return getattr(user_type, '__origin__', None) is expected_origin +def _match_is_exactly_collection(user_type): + return getattr(user_type, '__origin__', None) is collections.abc.Collection + + def match_is_named_tuple(user_type): return ( _safe_issubclass(user_type, typing.Tuple) and @@ -322,6 +327,10 @@ def convert_to_beam_type(typ): match=_match_issubclass(typing.Iterator), arity=1, beam_type=typehints.Iterator), + _TypeMapEntry( + match=_match_is_exactly_collection, + arity=1, + beam_type=typehints.Collection), ] # Find the first matching entry. diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index 9c2762dff710..2e6db6a7733c 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -189,7 +189,15 @@ def test_convert_to_beam_type_with_collections_types(self): ( 'enum mutable set', collections.abc.MutableSet[_TestEnum], - typehints.Set[_TestEnum]) + typehints.Set[_TestEnum]), + ( + 'collection enum', + collections.abc.Collection[_TestEnum], + typehints.Collection[_TestEnum]), + ( + 'collection of tuples', + collections.abc.Collection[tuple[str, int]], + typehints.Collection[typehints.Tuple[str, int]]), ] for test_case in test_cases: diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 238bf8c321d6..5726a8a8ca92 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -82,6 +82,7 @@ 'Dict', 'Set', 'FrozenSet', + 'Collection', 'Iterable', 'Iterator', 'Generator', @@ -1017,6 +1018,62 @@ def __getitem__(self, type_param): FrozenSetTypeConstraint = FrozenSetHint.FrozenSetTypeConstraint +class CollectionHint(CompositeTypeHint): + """ A Collection type-hint. + + Collection[X] defines a type-hint for a collection of homogenous types. 'X' + may be either a built-in Python type or another nested TypeConstraint. + + This represents a collections.abc.Collection type, which implements + __contains__, __iter__, and __len__. This acts as a parent type for + sets but has fewer guarantees for mixins. + """ + class CollectionTypeConstraint(SequenceTypeConstraint): + def __init__(self, type_param): + super().__init__(type_param, abc.Collection) + + def __repr__(self): + return 'Collection[%s]' % repr(self.inner_type) + + @staticmethod + def _is_subclass_constraint(sub): + return isinstance( + sub, ( + CollectionTypeConstraint, + FrozenSetTypeConstraint, + SetTypeConstraint)) + + # TODO(https://github.com/apache/beam/issues/29135): allow for consistency + # with Mapping types + def _consistent_with_check_(self, sub): + if self._is_subclass_constraint(sub): + return is_consistent_with(sub.inner_type, self.inner_type) + elif isinstance(sub, TupleConstraint): + if not sub.tuple_types: + # The empty tuple is consistent with Iterator[T] for any T. + return True + # Each element in the hetrogenious tuple must be consistent with + # the collection type. + # E.g. Tuple[A, B] < Collection[C] if A < C and B < C. + return all( + is_consistent_with(elem, self.inner_type) + for elem in sub.tuple_types) + elif not isinstance(sub, TypeConstraint): + if getattr(sub, '__origin__', None) is not None and getattr( + sub, '__args__', None) is not None: + return issubclass(sub, abc.Collection) and is_consistent_with( + sub.__args__, self.inner_type) + return False + + def __getitem__(self, type_param): + validate_composite_type_param( + type_param, error_msg_prefix='Parameter to a Collection hint') + return self.CollectionTypeConstraint(type_param) + + +CollectionTypeConstraint = CollectionHint.CollectionTypeConstraint + + class IterableHint(CompositeTypeHint): """An Iterable type-hint. @@ -1187,6 +1244,7 @@ def __getitem__(self, type_params): Dict = DictHint() Set = SetHint() FrozenSet = FrozenSetHint() +Collection = CollectionHint() Iterable = IterableHint() Iterator = IteratorHint() Generator = GeneratorHint() diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 7f8c322f9f40..1d938edcc24b 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -865,6 +865,33 @@ class FrozenSetHintTestCase(BaseSetHintTest.CommonTests): string_type = 'FrozenSet' +class CollectionHintTestCase(TypeHintTestCase): + def test_type_constraint_compatibility(self): + self.assertCompatible(typehints.Collection[int], typehints.Set[int]) + self.assertCompatible(typehints.Iterable[int], typehints.Collection[int]) + self.assertCompatible(typehints.Collection[int], typehints.FrozenSet[int]) + self.assertCompatible( + typehints.Collection[typehints.Any], typehints.Collection[int]) + self.assertCompatible(typehints.Collection[int], typehints.Tuple[int]) + self.assertCompatible(typehints.Any, typehints.Collection[str]) + + def test_one_way_compatibility(self): + self.assertNotCompatible(typehints.Set[int], typehints.Collection[int]) + self.assertNotCompatible( + typehints.FrozenSet[int], typehints.Collection[int]) + self.assertNotCompatible(typehints.Tuple[int], typehints.Collection[int]) + self.assertNotCompatible(typehints.Collection[int], typehints.Iterable[int]) + + def test_getitem_invalid_composite_type_param(self): + with self.assertRaises(TypeError) as e: + typehints.Collection[5] + self.assertEqual( + 'Parameter to a Collection hint must be a ' + 'non-sequence, a type, or a TypeConstraint. 5 is ' + 'an instance of int.', + e.exception.args[0]) + + class IterableHintTestCase(TypeHintTestCase): def test_getitem_invalid_composite_type_param(self): with self.assertRaises(TypeError) as e: @@ -893,6 +920,7 @@ def test_compatibility(self): self.assertCompatible( typehints.Iterable[typehints.Any], typehints.List[typehints.Tuple[int, bool]]) + self.assertCompatible(typehints.Iterable[int], typehints.Collection[int]) def test_tuple_compatibility(self): self.assertCompatible(typehints.Iterable[int], typehints.Tuple[int, ...])