Skip to content

Commit

Permalink
add task queue unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <[email protected]>
  • Loading branch information
andrewwhitehead committed Nov 16, 2019
1 parent 1c3dc0c commit b0554e5
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 13 deletions.
29 changes: 16 additions & 13 deletions aries_cloudagent/messaging/task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def _drain_loop(self):
):
coro, task_complete, fut = self.pending_tasks.pop(0)
task = self.run(coro, task_complete)
if fut:
if fut and not fut.done():
fut.set_result(task)
if self.pending_tasks:
await self._updated_evt.wait()
Expand Down Expand Up @@ -179,9 +179,9 @@ def put(self, coro: Coroutine, task_complete: Callable = None) -> asyncio.Future
"""
fut = self.loop.create_future()
if self._cancelled:
coro.close()
fut.cancel()
return
if self.ready:
elif self.ready:
task = self.run(coro, task_complete)
fut.set_result(task)
else:
Expand All @@ -193,21 +193,22 @@ def completed_task(self, task: asyncio.Task, task_complete: Callable):
exc_info = task_exc_info(task)
if exc_info and not task_complete:
LOGGER.exception("Error running task", exc_info=exc_info)
try:
self.active_tasks.remove(task)
except ValueError:
pass
if task_complete:
try:
task_complete(CompletedTask(task, exc_info))
except Exception:
LOGGER.exception("Error finalizing task")
try:
self.active_tasks.remove(task)
except ValueError:
pass
self.drain()

def cancel_pending(self):
"""Cancel any pending tasks in the queue."""
if self._drain_task:
self._drain_task.cancel()
self._drain_task = None
for coro, task_complete, fut in self.pending_tasks:
coro.close()
fut.cancel()
Expand All @@ -228,24 +229,26 @@ async def complete(self, timeout: float = None, cleanup: bool = True):
if timeout or timeout is None:
try:
await self.wait_for(timeout)
except TimeoutError:
except asyncio.TimeoutError:
pass
for task in self.active_tasks:
if not task.done():
task.cancel()
if cleanup:
while self.active_tasks:
await self._updated_evt.wait()
while True:
drain = self.drain()
if not drain:
break
await drain

async def flush(self):
"""Wait for any active or pending tasks to be completed."""
if self.pending_tasks and not self._drain_task:
self.drain()
self.drain()
while self.active_tasks or self._drain_task:
if self._drain_task:
await self._drain_task
if self.active_tasks:
await asyncio.gather(*self.active_tasks)
await asyncio.wait(self.active_tasks)

def __await__(self):
"""Handle the builtin await operator."""
Expand Down
159 changes: 159 additions & 0 deletions aries_cloudagent/messaging/tests/test_task_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import asyncio
from asynctest import TestCase

from ..task_queue import CompletedTask, TaskQueue


async def retval(val):
return val


class TestTaskQueue(TestCase):
async def test_run(self):
queue = TaskQueue()
task = None
completed = []

def done(complete: CompletedTask):
assert complete.task is task
assert not complete.exc_info
completed.append(complete.task.result())

task = queue.run(retval(1), done)
assert queue.current_active == 1
assert len(queue) == queue.current_size == 1
assert not queue.current_pending
await queue.flush()
assert completed == [1]
assert task.result() == 1

with self.assertRaises(ValueError):
queue.run(None, done)

async def test_put_no_limit(self):
queue = TaskQueue()
completed = []

def done(complete: CompletedTask):
assert not complete.exc_info
completed.append(complete.task.result())

fut = queue.put(retval(1), done)
assert not queue.pending_tasks
await queue.flush()
assert completed == [1]
assert fut.result().result() == 1

with self.assertRaises(ValueError):
queue.add_pending(None, done)

async def test_put_limited(self):
queue = TaskQueue(1)
assert queue.max_active == 1
assert not queue.cancelled
completed = set()

def done(complete: CompletedTask):
assert not complete.exc_info
completed.add(complete.task.result())

fut1 = queue.put(retval(1), done)
fut2 = queue.put(retval(2), done)
assert queue.pending_tasks
await queue.flush()
assert completed == {1, 2}
assert fut1.result().result() == 1
assert fut2.result().result() == 2

async def test_complete(self):
queue = TaskQueue()
completed = set()

def done(complete: CompletedTask):
assert not complete.exc_info
completed.add(complete.task.result())

queue.run(retval(1), done)
await queue.put(retval(2), done)
queue.put(retval(3), done)
await queue.complete()
assert completed == {1, 2, 3}

async def test_cancel_pending(self):
queue = TaskQueue(1)
completed = set()

def done(complete: CompletedTask):
assert not complete.exc_info
completed.add(complete.task.result())

queue.run(retval(1), done)
queue.put(retval(2), done)
queue.put(retval(3), done)
queue.cancel_pending()
await queue.flush()
assert completed == {1}

async def test_cancel_all(self):
queue = TaskQueue(1)
completed = set()

def done(complete: CompletedTask):
assert not complete.exc_info
completed.add(complete.task.result())

queue.run(retval(1), done)
queue.put(retval(2), done)
queue.put(retval(3), done)
queue.cancel()
assert queue.cancelled
await queue.flush()
assert not completed
assert not queue.current_size

co = retval(1)
with self.assertRaises(RuntimeError):
queue.run(co, done)
co.close()

co = retval(1)
fut = queue.put(co)
assert fut.cancelled()

async def test_cancel_long(self):
queue = TaskQueue()
task = queue.run(asyncio.sleep(5))
queue.cancel()
await queue

# cancellation may take a second
# assert task.cancelled()

with self.assertRaises(asyncio.CancelledError):
await task

async def test_complete_with_timeout(self):
queue = TaskQueue()
task = queue.run(asyncio.sleep(5))
await queue.complete(0.01)

# cancellation may take a second
# assert task.cancelled()

with self.assertRaises(asyncio.CancelledError):
await task

async def test_repeat_callback(self):
# check that running the callback twice does not throw an exception

queue = TaskQueue()
completed = []

def done(complete: CompletedTask):
assert not complete.exc_info
completed.append(complete.task.result())

task = queue.run(retval(1), done)
await task
queue.completed_task(task, done)
assert completed == [1, 1]

0 comments on commit b0554e5

Please sign in to comment.