Commit: Add sources operator
vxgmichel committed May 7, 2024
1 parent f8b51c1 commit 2ec465a
Showing 3 changed files with 202 additions and 19 deletions.
175 changes: 174 additions & 1 deletion aiostream/core.py
@@ -280,7 +280,19 @@ def pipe(
) -> Callable[[AsyncIterable[A]], Stream[T]]: ...


# Operator decorator
class SourcesOperator(Protocol[P, T]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]: ...

@staticmethod
def raw(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]: ...

@staticmethod
def pipe(
*args: P.args, **kwargs: P.kwargs
) -> Callable[[AsyncIterable[Any]], Stream[T]]: ...


# Operator decorators


def operator(
@@ -573,3 +585,164 @@ def __str__(self) -> str:
# Create operator class
operator_instance = PipableOperatorImplementation()
return operator_instance


def sources_operator(
func: Callable[P, AsyncIterator[T]],
) -> SourcesOperator[P, T]:
"""Create a pipable stream operator from an asynchronous generator
(or any function returning an asynchronous iterable) that takes
a variadic ``*args`` of sources as argument.
Decorator usage::
@sources_operator
async def chain(*sources):
for source in sources:
async with streamcontext(source) as streamer:
async for item in streamer:
yield item
Positional arguments are expected to be the asynchronous iteratables.
Keyword arguments are not supported at the moment.
When used in a pipable context, the asynchronous iterable injected by
the pipe operator is used as the first argument.
The return value is a dynamically created class.
It has the same name, module and doc as the original function.
A new stream is created by simply instanciating the operator::
empty_chained = chain()
single_chained = chain(random())
multiple_chained = chain(stream.just(0.0), stream.just(1.0), random())
The original function is called at instanciation to check that
signature match. The source is also checked for asynchronous iteration.
The operator also have a pipe class method that can be used along
with the piping synthax::
just_zero = stream.just(0.0)
multiple_chained = just_zero | chain.pipe(stream.just(1.0, random())
This is strictly equivalent to the previous example.
Other methods are available:
- `original`: the original function as a static method
- `raw`: same as original but add extra checking
The raw method is useful to create new operators from existing ones::
@chain_operator
def chain_twice(*sources):
return chain.raw(*sources, *sources)
"""
# First check for classmethod instance, to avoid more confusing errors later on
if isinstance(func, classmethod):
raise ValueError(
    "An operator cannot be created from a class method, "
    "since the decorated function becomes an operator class"
)

# Gather data
name = func.__name__
module = func.__module__
extra_doc = func.__doc__
doc = extra_doc or f"Regular {name} stream operator."

# Extract signature
signature = inspect.signature(func)
parameters = list(signature.parameters.values())
return_annotation = signature.return_annotation
if parameters and parameters[0].name in ("self", "cls"):
raise ValueError(
    "An operator cannot be created from a method, "
    "since the decorated function becomes an operator class"
)

# Check for positional first parameter
if not parameters or parameters[0].kind != inspect.Parameter.VAR_POSITIONAL:
raise ValueError(
    "The first parameter of the sources operator must be var-positional"
)

# Wrapped static method
original_func = func
original_func.__qualname__ = name + ".original"

# Gather attributes
class PipableOperatorImplementation:
__qualname__ = name
__module__ = module
__doc__ = doc

original = staticmethod(original_func)

@staticmethod
def raw(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[T]:
for source in args:
assert_async_iterable(source)
return func(*args, **kwargs)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Stream[T]:
for source in args:
assert_async_iterable(source)
factory = functools.partial(self.raw, *args, **kwargs)
return Stream(factory)

@staticmethod
def pipe(
*args: P.args,
**kwargs: P.kwargs,
) -> Callable[[AsyncIterable[Any]], Stream[T]]:
return lambda source: operator_instance(source, *args, **kwargs) # type: ignore

def __repr__(self) -> str:
return f"{module}.{name}"

Check warning on line 705 in aiostream/core.py

View check run for this annotation

Codecov / codecov/patch

aiostream/core.py#L705

Added line #L705 was not covered by tests

def __str__(self) -> str:
return f"{module}.{name}"

Check warning on line 708 in aiostream/core.py

View check run for this annotation

Codecov / codecov/patch

aiostream/core.py#L708

Added line #L708 was not covered by tests

# Customize raw method
PipableOperatorImplementation.raw.__signature__ = signature # type: ignore[attr-defined]
PipableOperatorImplementation.raw.__qualname__ = name + ".raw"
PipableOperatorImplementation.raw.__module__ = module
PipableOperatorImplementation.raw.__doc__ = doc

# Customize call method
self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
new_parameters = [self_parameter] + parameters
new_return_annotation = (
return_annotation.replace("AsyncIterator", "Stream")
if isinstance(return_annotation, str)
else return_annotation
)
PipableOperatorImplementation.__call__.__signature__ = signature.replace( # type: ignore[attr-defined]
parameters=new_parameters, return_annotation=new_return_annotation
)
PipableOperatorImplementation.__call__.__qualname__ = name + ".__call__"
PipableOperatorImplementation.__call__.__name__ = "__call__"
PipableOperatorImplementation.__call__.__module__ = module
PipableOperatorImplementation.__call__.__doc__ = doc

# Customize pipe method
pipe_parameters = parameters
pipe_return_annotation = f"Callable[[AsyncIterable[Any]], {new_return_annotation}]"
PipableOperatorImplementation.pipe.__signature__ = signature.replace( # type: ignore[attr-defined]
parameters=pipe_parameters, return_annotation=pipe_return_annotation
)
PipableOperatorImplementation.pipe.__qualname__ = name + ".pipe"
PipableOperatorImplementation.pipe.__module__ = module
PipableOperatorImplementation.pipe.__doc__ = (
f'Piped version of the "{name}" stream operator.'
)
if extra_doc:
PipableOperatorImplementation.pipe.__doc__ += "\n\n " + extra_doc

# Create operator class
operator_instance = PipableOperatorImplementation()
return operator_instance
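
For a sense of how the new decorator is consumed end to end, here is a minimal
runnable sketch (it assumes an aiostream install that includes this commit;
``stream.just`` and ``stream.list`` are existing aiostream operators):

    import asyncio

    from aiostream import stream
    from aiostream.core import sources_operator, streamcontext

    @sources_operator
    async def chain(*sources):
        # Iterate each source in turn, in the order given.
        for source in sources:
            async with streamcontext(source) as streamer:
                async for item in streamer:
                    yield item

    async def main():
        # Direct instantiation, including the zero-source case:
        assert await stream.list(chain()) == []
        assert await stream.list(chain(stream.just(1), stream.just(2))) == [1, 2]
        # Piping: the left operand becomes the first positional source.
        assert await stream.list(stream.just(0) | chain.pipe(stream.just(1))) == [0, 1]
        # Raw form: a bare async iterator, checked but not wrapped in a Stream.
        assert [x async for x in chain.raw(stream.just(3))] == [3]

    asyncio.run(main())
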
30 changes: 12 additions & 18 deletions aiostream/stream/combine.py
@@ -17,7 +17,7 @@
from typing_extensions import ParamSpec

from ..aiter_utils import AsyncExitStack, anext
from ..core import streamcontext, pipable_operator
from ..core import sources_operator, streamcontext, pipable_operator

from . import create
from . import select
@@ -32,27 +32,22 @@
P = ParamSpec("P")


@pipable_operator
async def chain(
source: AsyncIterable[T], *more_sources: AsyncIterable[T]
) -> AsyncIterator[T]:
@sources_operator
async def chain(*sources: AsyncIterable[T]) -> AsyncIterator[T]:
"""Chain asynchronous sequences together, in the order they are given.
Note: the sequences are not iterated until it is required,
so if the operation is interrupted, the remaining sequences
will be left untouched.
"""
sources = source, *more_sources
for source in sources:
async with streamcontext(source) as streamer:
async for item in streamer:
yield item
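
Compared to the previous ``source, *more_sources`` signature, ``chain`` now
accepts any number of sources, including none. A quick usage sketch (assuming
this commit; ``pipe.chain`` is the usual pipe-namespace alias):

    import asyncio

    from aiostream import pipe, stream

    async def main():
        assert await stream.list(stream.chain()) == []  # empty chain, issue #95
        xs = stream.range(3) | pipe.chain(stream.range(3, 5))
        assert await stream.list(xs) == [0, 1, 2, 3, 4]

    asyncio.run(main())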


@pipable_operator
async def zip(
source: AsyncIterable[T], *more_sources: AsyncIterable[T]
) -> AsyncIterator[tuple[T, ...]]:
@sources_operator
async def zip(*sources: AsyncIterable[T]) -> AsyncIterator[tuple[T, ...]]:
"""Combine and forward the elements of several asynchronous sequences.
Each generated value is a tuple of elements, using the same order as
@@ -62,7 +57,9 @@ async def zip(
Note: the different sequences are awaited in parallel, so that their
waiting times don't add up.
"""
sources = source, *more_sources
# No sources
if not sources:
return

# One source
if len(sources) == 1:
@@ -209,25 +206,23 @@ def map(
return smap.raw(source, sync_func, *more_sources)


@pipable_operator
@sources_operator
def merge(
source: AsyncIterable[T], *more_sources: AsyncIterable[T]
*sources: AsyncIterable[T],
) -> AsyncIterator[T]:
"""Merge several asynchronous sequences together.
All the sequences are iterated simultaneously and their elements
are forwarded as soon as they're available. The generation continues
until all the sequences are exhausted.
"""
sources = [source, *more_sources]
source_stream: AsyncIterable[AsyncIterable[T]] = create.iterate.raw(sources)
return advanced.flatten.raw(source_stream)
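
Since ``merge`` now takes a bare ``*sources``, the zero-source case falls out
naturally. A short sketch (same assumptions as above; element order across
sources depends on scheduling, hence the ``sorted``):

    import asyncio

    from aiostream import stream

    async def main():
        assert await stream.list(stream.merge()) == []  # empty merge, issue #95
        xs = stream.merge(stream.just(1), stream.just(2))
        assert sorted(await stream.list(xs)) == [1, 2]

    asyncio.run(main())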


@pipable_operator
@sources_operator
def ziplatest(
source: AsyncIterable[T],
*more_sources: AsyncIterable[T],
*sources: AsyncIterable[T],
partial: bool = True,
default: T | None = None,
) -> AsyncIterator[tuple[T | None, ...]]:
@@ -244,7 +239,6 @@ def ziplatest(
are forwarded as soon as they're available. The generation continues
until all the sequences are exhausted.
"""
sources = source, *more_sources
n = len(sources)

# Custom getter
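
``ziplatest`` keeps its keyword arguments while its sources become a plain
``*sources``. A final sketch (same assumptions; with ``partial=True`` the
intermediate tuples depend on scheduling order, so only the empty case is
asserted):

    import asyncio

    from aiostream import stream

    async def main():
        assert await stream.list(stream.ziplatest()) == []  # issue #95
        xs = stream.ziplatest(stream.just(0), stream.just(1))
        print(await stream.list(xs))  # e.g. [(0, None), (0, 1)]

    asyncio.run(main())
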
16 changes: 16 additions & 0 deletions tests/test_combine.py
@@ -17,6 +17,10 @@ async def test_chain(assert_run, assert_cleanup):
xs += stream.range(15, 20) | add_resource.pipe(1)
await assert_run(xs, list(range(10, 20)))

# Empty chain (issue #95)
xs = stream.chain()
await assert_run(xs, [])


@pytest.mark.asyncio
async def test_zip(assert_run):
@@ -25,6 +29,10 @@ async def test_zip(assert_run):
expected = [(x,) * 3 for x in range(5)]
await assert_run(ys, expected)

# Empty zip (issue #95)
xs = stream.zip()
await assert_run(xs, [])


@pytest.mark.asyncio
async def test_map(assert_run, assert_cleanup):
@@ -173,6 +181,10 @@ async def agen2():
xs = stream.merge(agen1(), agen2()) | pipe.delay(1) | pipe.take(1)
await assert_run(xs, [1])

# Empty merge (issue #95)
xs = stream.merge()
await assert_run(xs, [])


@pytest.mark.asyncio
async def test_ziplatest(assert_run, assert_cleanup):
@@ -189,3 +201,7 @@ async def test_ziplatest(assert_run, assert_cleanup):
zs = stream.ziplatest(xs, ys, partial=False)
await assert_run(zs, [(0, 1), (2, 1), (2, 3), (4, 3)])
assert loop.steps == [1, 1, 1, 1]

# Empty ziplatest (issue #95)
xs = stream.ziplatest()
await assert_run(xs, [])
