diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9a95000..2abafed 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
         additional_dependencies: [pytest, typing-extensions]
         types: [python]
   - repo: https://github.com/RobertCraigie/pyright-python
-    rev: v1.1.361
+    rev: v1.1.362
     hooks:
       - id: pyright
         additional_dependencies: [pytest, typing-extensions]
diff --git a/aiostream/core.py b/aiostream/core.py
index 5bea582..0cb9469 100644
--- a/aiostream/core.py
+++ b/aiostream/core.py
@@ -280,7 +280,19 @@ def pipe(
     ) -> Callable[[AsyncIterable[A]], Stream[T]]: ...
 
 
-# Operator decorator
+class SourcesOperator(Protocol[P, T]):
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]: ...
+
+    @staticmethod
+    def raw(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]: ...
+
+    @staticmethod
+    def pipe(
+        *args: P.args, **kwargs: P.kwargs
+    ) -> Callable[[AsyncIterable[Any]], Stream[T]]: ...
+
+
+# Operator decorators
 
 
 def operator(
@@ -297,18 +309,18 @@ async def random(offset=0., width=1.):
         while True:
             yield offset + width * random.random()
 
-    The return value is a dynamically created class.
-    It has the same name, module and doc as the original function.
+    The return value is a dynamically created callable.
+    It has the same name, module and documentation as the original function.
 
-    A new stream is created by simply instanciating the operator::
+    A new stream is created by simply calling the operator::
 
         xs = random()
 
-    The original function is called at instanciation to check that
-    signature match. Other methods are available:
+    The original function is called right away to check that the
+    signatures match. Other methods are available:
 
       - `original`: the original function as a static method
-      - `raw`: same as original but add extra checking
+      - `raw`: same as original with extra checking
 
     The `pipable` argument is deprecated, use `pipable_operator` instead.
     """
@@ -357,9 +369,6 @@ async def random(offset=0., width=1.):
 
     # Gather attributes
     class OperatorImplementation:
-        __qualname__ = name
-        __module__ = module
-        __doc__ = doc
 
         original = staticmethod(original_func)
 
@@ -399,8 +408,18 @@ def __str__(self) -> str:
     OperatorImplementation.__call__.__module__ = module
     OperatorImplementation.__call__.__doc__ = doc
 
-    # Create operator class
-    return OperatorImplementation()
+    # Create operator singleton
+    properly_named_class = type(
+        name,
+        (OperatorImplementation,),
+        {
+            "__qualname__": name,
+            "__module__": module,
+            "__doc__": doc,
+        },
+    )
+    operator_instance = properly_named_class()
+    return operator_instance
 
 
 def pipable_operator(
@@ -420,19 +439,19 @@ async def multiply(source, factor):
     The first argument is expected to be the asynchronous iteratable
     used for piping.
 
-    The return value is a dynamically created class.
-    It has the same name, module and doc as the original function.
+    The return value is a dynamically created callable.
+    It has the same name, module and documentation as the original function.
 
-    A new stream is created by simply instanciating the operator::
+    A new stream is created by simply calling the operator::
 
         xs = random()
         ys = multiply(xs, 2)
 
-    The original function is called at instanciation to check that
-    signature match. The source is also checked for asynchronous iteration.
+    The original function is called right away (but not awaited) to check that
+    signatures match. The sources are also checked for asynchronous iteration.
 
-    The operator also have a pipe class method that can be used along
-    with the piping synthax::
+    The operator also has a `pipe` method that can be used with the pipe
+    syntax::
 
         xs = random()
         ys = xs | multiply.pipe(2)
@@ -442,7 +461,7 @@ async def multiply(source, factor):
     Other methods are available:
 
       - `original`: the original function as a static method
-      - `raw`: same as original but add extra checking
+      - `raw`: same as original with extra checking
 
     The raw method is useful to create new operators from existing ones::
 
@@ -495,9 +514,6 @@ def double(source):
 
     # Gather attributes
     class PipableOperatorImplementation:
-        __qualname__ = name
-        __module__ = module
-        __doc__ = doc
 
         original = staticmethod(original_func)
 
@@ -570,6 +586,180 @@ def __str__(self) -> str:
     if extra_doc:
         PipableOperatorImplementation.pipe.__doc__ += "\n\n    " + extra_doc
 
-    # Create operator class
-    operator_instance = PipableOperatorImplementation()
+    # Create operator singleton
+    properly_named_class = type(
+        name,
+        (PipableOperatorImplementation,),
+        {
+            "__qualname__": name,
+            "__module__": module,
+            "__doc__": doc,
+        },
+    )
+    operator_instance = properly_named_class()
+    return operator_instance
+
+
+def sources_operator(
+    func: Callable[P, AsyncIterator[T]],
+) -> SourcesOperator[P, T]:
+    """Create a pipable stream operator from an asynchronous generator
+    (or any function returning an asynchronous iterable) that takes
+    a variadic ``*args`` of sources.
+
+    Decorator usage::
+
+        @sources_operator
+        async def chain(*sources, repeat=1):
+            for source in (sources * repeat):
+                async with streamcontext(source) as streamer:
+                    async for item in streamer:
+                        yield item
+
+    Positional arguments are expected to be asynchronous iterables.
+
+    When used in a pipable context, the asynchronous iterable injected by
+    the pipe operator is used as the first argument.
+
+    The return value is a dynamically created callable.
+    It has the same name, module and documentation as the original function.
+
+    A new stream is created by simply calling the operator::
+
+        xs = chain()
+        ys = chain(random())
+        zs = chain(stream.just(0.0), stream.just(1.0), random())
+
+    The original function is called right away (but not awaited) to check that
+    signatures match. The sources are also checked for asynchronous iteration.
+
+    The operator also has a `pipe` method that can be used with the pipe
+    syntax::
+
+        just_zero = stream.just(0.0)
+        zs = just_zero | chain.pipe(stream.just(1.0), random())
+
+    This is strictly equivalent to the previous ``zs`` example.
+
+    Other methods are available:
+
+      - `original`: the original function as a static method
+      - `raw`: same as original with extra checking
+
+    The raw method is useful to create new operators from existing ones::
+
+        @sources_operator
+        def chain_twice(*sources):
+            return chain.raw(*sources, repeat=2)
+    """
+    # First check for classmethod instance, to avoid more confusing errors later on
+    if isinstance(func, classmethod):
+        raise ValueError(
+            "An operator cannot be created from a class method, "
+            "since the decorated function becomes an operator class"
+        )
+
+    # Gather data
+    name = func.__name__
+    module = func.__module__
+    extra_doc = func.__doc__
+    doc = extra_doc or f"Regular {name} stream operator."
+
+    # Extract signature
+    signature = inspect.signature(func)
+    parameters = list(signature.parameters.values())
+    return_annotation = signature.return_annotation
+    if parameters and parameters[0].name in ("self", "cls"):
+        raise ValueError(
+            "An operator cannot be created from a method, "
+            "since the decorated function becomes an operator class"
+        )
+
+    # Check for positional first parameter
+    if not parameters or parameters[0].kind != inspect.Parameter.VAR_POSITIONAL:
+        raise ValueError(
+            "The first parameter of the sources operator must be var-positional"
+        )
+
+    # Wrapped static method
+    original_func = func
+    original_func.__qualname__ = name + ".original"
+
+    # Gather attributes
+    class SourcesOperatorImplementation:
+
+        original = staticmethod(original_func)
+
+        @staticmethod
+        def raw(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]:
+            for source in args:
+                assert_async_iterable(source)
+            return func(*args, **kwargs)
+
+        def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]:
+            for source in args:
+                assert_async_iterable(source)
+            factory = functools.partial(self.raw, *args, **kwargs)
+            return Stream(factory)
+
+        @staticmethod
+        def pipe(
+            *args: P.args,
+            **kwargs: P.kwargs,
+        ) -> Callable[[AsyncIterable[Any]], Stream[T]]:
+            return lambda source: operator_instance(source, *args, **kwargs)  # type: ignore
+
+        def __repr__(self) -> str:
+            return f"{module}.{name}"
+
+        def __str__(self) -> str:
+            return f"{module}.{name}"
+
+    # Customize raw method
+    SourcesOperatorImplementation.raw.__signature__ = signature  # type: ignore[attr-defined]
+    SourcesOperatorImplementation.raw.__qualname__ = name + ".raw"
+    SourcesOperatorImplementation.raw.__module__ = module
+    SourcesOperatorImplementation.raw.__doc__ = doc
+
+    # Customize call method
+    self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
+    new_parameters = [self_parameter] + parameters
+    new_return_annotation = (
+        return_annotation.replace("AsyncIterator", "Stream")
+        if isinstance(return_annotation, str)
+        else return_annotation
+    )
+    SourcesOperatorImplementation.__call__.__signature__ = signature.replace(  # type: ignore[attr-defined]
+        parameters=new_parameters, return_annotation=new_return_annotation
+    )
+    SourcesOperatorImplementation.__call__.__qualname__ = name + ".__call__"
+    SourcesOperatorImplementation.__call__.__name__ = "__call__"
+    SourcesOperatorImplementation.__call__.__module__ = module
+    SourcesOperatorImplementation.__call__.__doc__ = doc
+
+    # Customize pipe method
+    pipe_parameters = parameters
+    pipe_return_annotation = f"Callable[[AsyncIterable[Any]], {new_return_annotation}]"
+    SourcesOperatorImplementation.pipe.__signature__ = signature.replace(  # type: ignore[attr-defined]
+        parameters=pipe_parameters, return_annotation=pipe_return_annotation
+    )
+    SourcesOperatorImplementation.pipe.__qualname__ = name + ".pipe"
+    SourcesOperatorImplementation.pipe.__module__ = module
+    SourcesOperatorImplementation.pipe.__doc__ = (
+        f'Piped version of the "{name}" stream operator.'
+    )
+    if extra_doc:
+        SourcesOperatorImplementation.pipe.__doc__ += "\n\n    " + extra_doc
+
+    # Create operator singleton
+    properly_named_class = type(
+        name,
+        (SourcesOperatorImplementation,),
+        {
+            "__qualname__": name,
+            "__module__": module,
+            "__doc__": doc,
+        },
+    )
+    operator_instance = properly_named_class()
     return operator_instance
diff --git a/aiostream/stream/combine.py b/aiostream/stream/combine.py
index b3534c2..e7a4dbf 100644
--- a/aiostream/stream/combine.py
+++ b/aiostream/stream/combine.py
@@ -17,7 +17,7 @@
 from typing_extensions import ParamSpec
 
 from ..aiter_utils import AsyncExitStack, anext
-from ..core import streamcontext, pipable_operator
+from ..core import sources_operator, streamcontext, pipable_operator
 
 from . import create
 from . import select
@@ -32,27 +32,22 @@
 P = ParamSpec("P")
 
 
-@pipable_operator
-async def chain(
-    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
-) -> AsyncIterator[T]:
+@sources_operator
+async def chain(*sources: AsyncIterable[T]) -> AsyncIterator[T]:
     """Chain asynchronous sequences together, in the order they are given.
 
     Note: the sequences are not iterated until it is required,
     so if the operation is interrupted, the remaining sequences
     will be left untouched.
     """
-    sources = source, *more_sources
     for source in sources:
         async with streamcontext(source) as streamer:
             async for item in streamer:
                 yield item
 
 
-@pipable_operator
-async def zip(
-    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
-) -> AsyncIterator[tuple[T, ...]]:
+@sources_operator
+async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]:
     """Combine and forward the elements of several asynchronous sequences.
 
     Each generated value is a tuple of elements, using the same order as
@@ -62,7 +57,9 @@ def zip(
     Note: the different sequences are awaited in parrallel, so that their
     waiting times don't add up.
     """
-    sources = source, *more_sources
+    # No sources
+    if not sources:
+        return
 
     # One sources
     if len(sources) == 1:
@@ -209,9 +206,9 @@ def map(
     return smap.raw(source, sync_func, *more_sources)
 
 
-@pipable_operator
+@sources_operator
 def merge(
-    source: AsyncIterable[T], *more_sources: AsyncIterable[T]
+    *sources: AsyncIterable[T],
 ) -> AsyncIterator[T]:
     """Merge several asynchronous sequences together.
 
@@ -219,15 +216,13 @@ def merge(
     are forwarded as soon as they're available. The generation continues
     until all the sequences are exhausted.
     """
-    sources = [source, *more_sources]
     source_stream: AsyncIterable[AsyncIterable[T]] = create.iterate.raw(sources)
     return advanced.flatten.raw(source_stream)
 
 
-@pipable_operator
+@sources_operator
 def ziplatest(
-    source: AsyncIterable[T],
-    *more_sources: AsyncIterable[T],
+    *sources: AsyncIterable[T],
     partial: bool = True,
     default: T | None = None,
 ) -> AsyncIterator[tuple[T | None, ...]]:
@@ -244,7 +239,6 @@ def ziplatest(
     are forwarded as soon as they're available. The generation continues
     until all the sequences are exhausted.
     """
-    sources = source, *more_sources
     n = len(sources)
 
     # Custom getter
diff --git a/docs/core.rst b/docs/core.rst
index 0aedf43..0203ef1 100644
--- a/docs/core.rst
+++ b/docs/core.rst
@@ -22,3 +22,5 @@ Operator decorators
 
 .. autofunction:: operator
 .. autofunction:: pipable_operator
+
+.. autofunction:: sources_operator
diff --git a/tests/test_combine.py b/tests/test_combine.py
index 161a66b..27892a0 100644
--- a/tests/test_combine.py
+++ b/tests/test_combine.py
@@ -17,6 +17,10 @@ async def test_chain(assert_run, assert_cleanup):
     xs += stream.range(15, 20) | add_resource.pipe(1)
     await assert_run(xs, list(range(10, 20)))
 
+    # Empty chain (issue #95)
+    xs = stream.chain()
+    await assert_run(xs, [])
+
 
 @pytest.mark.asyncio
 async def test_zip(assert_run):
@@ -25,6 +29,10 @@ async def test_zip(assert_run):
     expected = [(x,) * 3 for x in range(5)]
     await assert_run(ys, expected)
 
+    # Empty zip (issue #95)
+    xs = stream.zip()
+    await assert_run(xs, [])
+
 
 @pytest.mark.asyncio
 async def test_map(assert_run, assert_cleanup):
@@ -173,6 +181,10 @@ async def agen2():
     xs = stream.merge(agen1(), agen2()) | pipe.delay(1) | pipe.take(1)
     await assert_run(xs, [1])
 
+    # Empty merge (issue #95)
+    xs = stream.merge()
+    await assert_run(xs, [])
+
 
 @pytest.mark.asyncio
 async def test_ziplatest(assert_run, assert_cleanup):
@@ -189,3 +201,7 @@ async def test_ziplatest(assert_run, assert_cleanup):
     zs = stream.ziplatest(xs, ys, partial=False)
     await assert_run(zs, [(0, 1), (2, 1), (2, 3), (4, 3)])
     assert loop.steps == [1, 1, 1, 1]
+
+    # Empty ziplatest (issue #95)
+    xs = stream.ziplatest()
+    await assert_run(xs, [])
diff --git a/tests/test_core.py b/tests/test_core.py
index 9b1ecde..fabb97a 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,7 +1,7 @@
 import inspect
 import pytest
 
-from aiostream.core import pipable_operator
+from aiostream.core import pipable_operator, sources_operator
 from aiostream.test_utils import add_resource
 from aiostream import stream, streamcontext, operator
@@ -25,7 +25,9 @@ async def test_streamcontext(assert_cleanup):
     assert loop.steps == [1]
 
 
-@pytest.mark.parametrize("operator_param", [operator, pipable_operator])
+@pytest.mark.parametrize(
+    "operator_param", [operator, pipable_operator, sources_operator]
+)
 def test_operator_from_method(operator_param):
     with pytest.raises(ValueError):
 
@@ -131,6 +133,14 @@ async def test1(*args):
         yield 1
 
 
+def test_sources_operator_with_positional_args():
+    with pytest.raises(ValueError):
+
+        @sources_operator
+        async def test1(source):
+            yield 1
+
+
 def test_introspection_for_operator():
     # Extract original information
     original = stream.range.original  # type: ignore
@@ -214,3 +224,53 @@ def test_introspection_for_pipable_operator():
         str(inspect.signature(stream.take.pipe))
         == "(n: 'int') -> 'Callable[[AsyncIterable[X]], Stream[T]]'"
     )
+
+
+def test_introspection_for_sources_operator():
+    # Extract original information
+    original = stream.zip.original  # type: ignore
+    original_doc = original.__doc__
+    assert original_doc is not None
+    assert (
+        original_doc.splitlines()[0]
+        == "Combine and forward the elements of several asynchronous sequences."
+    )
+    assert (
+        str(inspect.signature(original))
+        == "(*sources: 'AsyncIterable[T]') -> 'AsyncIterator[tuple[T, ...]]'"
+    )
+
+    # Check the stream operator
+    assert str(stream.zip) == repr(stream.zip) == "aiostream.stream.combine.zip"
+    assert stream.zip.__module__ == "aiostream.stream.combine"
+    assert stream.zip.__doc__ == original_doc
+
+    # Check the raw method
+    assert stream.zip.raw.__qualname__ == "zip.raw"
+    assert stream.zip.raw.__module__ == "aiostream.stream.combine"
+    assert stream.zip.raw.__doc__ == original_doc
+    assert (
+        str(inspect.signature(stream.zip.raw))
+        == "(*sources: 'AsyncIterable[T]') -> 'AsyncIterator[tuple[T, ...]]'"
+    )
+
+    # Check the __call__ method
+    assert stream.zip.__call__.__qualname__ == "zip.__call__"
+    assert stream.zip.__call__.__module__ == "aiostream.stream.combine"
+    assert stream.zip.__call__.__doc__ == original_doc
+    assert (
+        str(inspect.signature(stream.zip.__call__))
+        == "(*sources: 'AsyncIterable[T]') -> 'Stream[tuple[T, ...]]'"
+    )
+
+    # Check the pipe method
+    assert stream.zip.pipe.__qualname__ == "zip.pipe"
+    assert stream.zip.pipe.__module__ == "aiostream.stream.combine"
+    assert (
+        stream.zip.pipe.__doc__
+        == 'Piped version of the "zip" stream operator.\n\n    ' + original_doc
+    )
+    assert (
+        str(inspect.signature(stream.zip.pipe))
+        == "(*sources: 'AsyncIterable[T]') -> 'Callable[[AsyncIterable[Any]], Stream[tuple[T, ...]]]'"
+    )
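
Note for reviewers: a minimal usage sketch of the new decorator, not part of the diff. The operator name `chain_repeat` is hypothetical; everything else uses the public aiostream API, assuming this branch is installed::

    import asyncio

    from aiostream import stream
    from aiostream.core import sources_operator, streamcontext

    @sources_operator
    async def chain_repeat(*sources, repeat=1):
        # Every positional argument is a source; keyword arguments
        # remain free parameters of the operator.
        for source in sources * repeat:
            async with streamcontext(source) as streamer:
                async for item in streamer:
                    yield item

    async def main():
        # Direct call: zero or more sources as positional arguments
        xs = chain_repeat(stream.range(2), repeat=2)
        print(await stream.list(xs))  # [0, 1, 0, 1]

        # Pipe syntax: the piped stream becomes the first source
        ys = stream.just(0) | chain_repeat.pipe(stream.just(1))
        print(await stream.list(ys))  # [0, 1]

    asyncio.run(main())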
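The behavioral side of the change (issue #95) can be sanity-checked the same way: now that the combine operators take a var-positional `*sources` argument, calling them with no sources produces an empty stream instead of raising a TypeError. A quick script mirroring the new tests, under the same installation assumption as above::

    import asyncio

    from aiostream import stream

    async def main():
        # Each combine operator now accepts zero sources (issue #95)
        assert await stream.list(stream.chain()) == []
        assert await stream.list(stream.zip()) == []
        assert await stream.list(stream.merge()) == []
        assert await stream.list(stream.ziplatest()) == []
        print("empty-source combine operators: OK")

    asyncio.run(main())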