diff --git a/aries_cloudagent/messaging/task_queue.py b/aries_cloudagent/messaging/task_queue.py index 316206cbae..b9ea4d759e 100644 --- a/aries_cloudagent/messaging/task_queue.py +++ b/aries_cloudagent/messaging/task_queue.py @@ -108,7 +108,7 @@ async def _drain_loop(self): ): coro, task_complete, fut = self.pending_tasks.pop(0) task = self.run(coro, task_complete) - if fut: + if fut and not fut.done(): fut.set_result(task) if self.pending_tasks: await self._updated_evt.wait() @@ -179,9 +179,9 @@ def put(self, coro: Coroutine, task_complete: Callable = None) -> asyncio.Future """ fut = self.loop.create_future() if self._cancelled: + coro.close() fut.cancel() - return - if self.ready: + elif self.ready: task = self.run(coro, task_complete) fut.set_result(task) else: @@ -193,21 +193,22 @@ def completed_task(self, task: asyncio.Task, task_complete: Callable): exc_info = task_exc_info(task) if exc_info and not task_complete: LOGGER.exception("Error running task", exc_info=exc_info) - try: - self.active_tasks.remove(task) - except ValueError: - pass if task_complete: try: task_complete(CompletedTask(task, exc_info)) except Exception: LOGGER.exception("Error finalizing task") + try: + self.active_tasks.remove(task) + except ValueError: + pass self.drain() def cancel_pending(self): """Cancel any pending tasks in the queue.""" if self._drain_task: self._drain_task.cancel() + self._drain_task = None for coro, task_complete, fut in self.pending_tasks: coro.close() fut.cancel() @@ -228,24 +229,26 @@ async def complete(self, timeout: float = None, cleanup: bool = True): if timeout or timeout is None: try: await self.wait_for(timeout) - except TimeoutError: + except asyncio.TimeoutError: pass for task in self.active_tasks: if not task.done(): task.cancel() if cleanup: - while self.active_tasks: - await self._updated_evt.wait() + while True: + drain = self.drain() + if not drain: + break + await drain async def flush(self): """Wait for any active or pending tasks to be completed.""" - if self.pending_tasks and not self._drain_task: - self.drain() + self.drain() while self.active_tasks or self._drain_task: if self._drain_task: await self._drain_task if self.active_tasks: - await asyncio.gather(*self.active_tasks) + await asyncio.wait(self.active_tasks) def __await__(self): """Handle the builtin await operator.""" diff --git a/aries_cloudagent/messaging/tests/test_task_queue.py b/aries_cloudagent/messaging/tests/test_task_queue.py new file mode 100644 index 0000000000..4b1ae08a71 --- /dev/null +++ b/aries_cloudagent/messaging/tests/test_task_queue.py @@ -0,0 +1,159 @@ +import asyncio +from asynctest import TestCase + +from ..task_queue import CompletedTask, TaskQueue + + +async def retval(val): + return val + + +class TestTaskQueue(TestCase): + async def test_run(self): + queue = TaskQueue() + task = None + completed = [] + + def done(complete: CompletedTask): + assert complete.task is task + assert not complete.exc_info + completed.append(complete.task.result()) + + task = queue.run(retval(1), done) + assert queue.current_active == 1 + assert len(queue) == queue.current_size == 1 + assert not queue.current_pending + await queue.flush() + assert completed == [1] + assert task.result() == 1 + + with self.assertRaises(ValueError): + queue.run(None, done) + + async def test_put_no_limit(self): + queue = TaskQueue() + completed = [] + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.append(complete.task.result()) + + fut = queue.put(retval(1), done) + assert not queue.pending_tasks + await queue.flush() + assert completed == [1] + assert fut.result().result() == 1 + + with self.assertRaises(ValueError): + queue.add_pending(None, done) + + async def test_put_limited(self): + queue = TaskQueue(1) + assert queue.max_active == 1 + assert not queue.cancelled + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + fut1 = queue.put(retval(1), done) + fut2 = queue.put(retval(2), done) + assert queue.pending_tasks + await queue.flush() + assert completed == {1, 2} + assert fut1.result().result() == 1 + assert fut2.result().result() == 2 + + async def test_complete(self): + queue = TaskQueue() + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + queue.run(retval(1), done) + await queue.put(retval(2), done) + queue.put(retval(3), done) + await queue.complete() + assert completed == {1, 2, 3} + + async def test_cancel_pending(self): + queue = TaskQueue(1) + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + queue.run(retval(1), done) + queue.put(retval(2), done) + queue.put(retval(3), done) + queue.cancel_pending() + await queue.flush() + assert completed == {1} + + async def test_cancel_all(self): + queue = TaskQueue(1) + completed = set() + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.add(complete.task.result()) + + queue.run(retval(1), done) + queue.put(retval(2), done) + queue.put(retval(3), done) + queue.cancel() + assert queue.cancelled + await queue.flush() + assert not completed + assert not queue.current_size + + co = retval(1) + with self.assertRaises(RuntimeError): + queue.run(co, done) + co.close() + + co = retval(1) + fut = queue.put(co) + assert fut.cancelled() + + async def test_cancel_long(self): + queue = TaskQueue() + task = queue.run(asyncio.sleep(5)) + queue.cancel() + await queue + + # cancellation may take a second + # assert task.cancelled() + + with self.assertRaises(asyncio.CancelledError): + await task + + async def test_complete_with_timeout(self): + queue = TaskQueue() + task = queue.run(asyncio.sleep(5)) + await queue.complete(0.01) + + # cancellation may take a second + # assert task.cancelled() + + with self.assertRaises(asyncio.CancelledError): + await task + + async def test_repeat_callback(self): + # check that running the callback twice does not throw an exception + + queue = TaskQueue() + completed = [] + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.append(complete.task.result()) + + task = queue.run(retval(1), done) + await task + queue.completed_task(task, done) + assert completed == [1, 1]