From b966a9549e689525f7f1949c40baf9779545690a Mon Sep 17 00:00:00 2001 From: Max Fischer Date: Sun, 16 Apr 2023 12:03:41 +0200 Subject: [PATCH] Complete itertools.chain interface (#108) * added chain.aclose method for cleanup (closes #107) * do not reconstruct chain implementation again and again * use same implementation for chain and chain.from_iterable * chain owns explicitly passed iterables --- asyncstdlib/itertools.py | 49 +++++++++++++++++++++++++---------- unittests/test_itertools.py | 51 +++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 14 deletions(-) diff --git a/asyncstdlib/itertools.py b/asyncstdlib/itertools.py index fc12573..f6f06ee 100644 --- a/asyncstdlib/itertools.py +++ b/asyncstdlib/itertools.py @@ -148,33 +148,54 @@ class chain(AsyncIterator[T]): The resulting iterator consecutively iterates over and yields all values from each of the ``iterables``. This is similar to converting all ``iterables`` to sequences and concatenating them, but lazily exhausts each iterable. + + The ``chain`` assumes ownership of its ``iterables`` and closes them reliably + when the ``chain`` is closed. Pass the ``iterables`` via a :py:class:`tuple` to + ``chain.from_iterable`` to avoid closing all iterables but those already processed. """ - __slots__ = ("_impl",) + __slots__ = ("_iterator", "_owned_iterators") - def __init__(self, *iterables: AnyIterable[T]): - async def impl() -> AsyncIterator[T]: - for iterable in iterables: + @staticmethod + async def _chain_iterator( + any_iterables: AnyIterable[AnyIterable[T]], + ) -> AsyncGenerator[T, None]: + async with ScopedIter(any_iterables) as iterables: + async for iterable in iterables: async with ScopedIter(iterable) as iterator: async for item in iterator: yield item - self._impl = impl() + def __init__( + self, *iterables: AnyIterable[T], _iterables: AnyIterable[AnyIterable[T]] = () + ): + self._iterator = self._chain_iterator(iterables or _iterables) + self._owned_iterators = ( + iterable + for iterable in iterables + if isinstance(iterable, AsyncIterator) and hasattr(iterable, "aclose") + ) - @staticmethod - async def from_iterable(iterable: AnyIterable[AnyIterable[T]]) -> AsyncIterator[T]: + @classmethod + def from_iterable(cls, iterable: AnyIterable[AnyIterable[T]]) -> "chain[T]": """ Alternate constructor for :py:func:`~.chain` that lazily exhausts - iterables as well + the ``iterable`` of iterables as well + + This is suitable for chaining iterables from a lazy or infinite ``iterable``. + In turn, closing the ``chain`` only closes those iterables + already fetched from ``iterable``. """ - async with ScopedIter(iterable) as iterables: - async for sub_iterable in iterables: - async with ScopedIter(sub_iterable) as iterator: - async for item in iterator: - yield item + return cls(_iterables=iterable) def __anext__(self) -> Awaitable[T]: - return self._impl.__anext__() + return self._iterator.__anext__() + + async def aclose(self) -> None: + for iterable in self._owned_iterators: + if hasattr(iterable, "aclose"): + await iterable.aclose() + await self._iterator.aclose() async def compress( diff --git a/unittests/test_itertools.py b/unittests/test_itertools.py index a4d3984..43ce466 100644 --- a/unittests/test_itertools.py +++ b/unittests/test_itertools.py @@ -87,6 +87,57 @@ async def test_chain(iterables): ) +class ACloseFacade: + """Wrapper to check if an iterator has been closed""" + + def __init__(self, iterable): + self.closed = False + self.__wrapped__ = iterable + self._iterator = a.iter(iterable) + + async def __anext__(self): + if self.closed: + raise StopAsyncIteration() + return await self._iterator.__anext__() + + def __aiter__(self): + return self + + async def aclose(self): + if hasattr(self._iterator, "aclose"): + await self._iterator.aclose() + self.closed = True + + +@pytest.mark.parametrize("iterables", chains) +@sync +async def test_chain_close_auto(iterables): + """Test that `chain` closes exhausted iterators""" + closeable_iterables = [ACloseFacade(iterable) for iterable in iterables] + assert await a.list(a.chain(*closeable_iterables)) == list( + itertools.chain(*iterables) + ) + assert all(iterable.closed for iterable in closeable_iterables) + + +# insert a known filled iterable since chain closes all that are exhausted +@pytest.mark.parametrize("iterables", [([1], *chain) for chain in chains]) +@pytest.mark.parametrize( + "chain_type, must_close", + [(lambda iterators: a.chain(*iterators), True), (a.chain.from_iterable, False)], +) +@sync +async def test_chain_close_partial(iterables, chain_type, must_close): + """Test that `chain` closes owned iterators""" + closeable_iterables = [ACloseFacade(iterable) for iterable in iterables] + chain = chain_type(closeable_iterables) + assert await a.anext(chain) == next(itertools.chain(*iterables)) + await chain.aclose() + assert all(iterable.closed == must_close for iterable in closeable_iterables[1:]) + # closed chain must remain closed regardless of iterators + assert await a.anext(chain, "sentinel") == "sentinel" + + compress_cases = [ (range(20), [idx % 2 for idx in range(20)]), ([1] * 5, [True, True, False, True, True]),