diff --git a/aries_cloudagent/admin/server.py b/aries_cloudagent/admin/server.py index c3163033d3..6240943330 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -85,10 +85,9 @@ def __init__( ): """Initialize the webhook target.""" self.endpoint = endpoint - self._topic_filter = None self.retries = retries - # call setter - self.topic_filter = topic_filter + self._topic_filter = None + self.topic_filter = topic_filter # call setter @property def topic_filter(self) -> Set[str]: @@ -176,6 +175,8 @@ async def check_token(request, handler): middlewares.append(check_token) + collector: Collector = await self.context.inject(Collector, required=False) + if self.task_queue: @web.middleware @@ -185,14 +186,11 @@ async def apply_limiter(request, handler): middlewares.append(apply_limiter) - stats: Collector = await self.context.inject(Collector, required=False) - if stats: + elif collector: @web.middleware async def collect_stats(request, handler): - handler = stats.wrap_coro( - handler, [handler.__qualname__, "any-admin-request"] - ) + handler = collector.wrap_coro(handler, [handler.__qualname__]) return await handler(request) middlewares.append(collect_stats) @@ -231,7 +229,7 @@ async def collect_stats(request, handler): for route in app.router.routes(): cors.add(route) # get agent label - agent_label = self.context.settings.get("default_label"), + agent_label = self.context.settings.get("default_label") version_string = f"v{__version__}" setup_aiohttp_apispec( @@ -288,7 +286,6 @@ async def plugins_handler(self, request: web.BaseRequest): registry: PluginRegistry = await self.context.inject( PluginRegistry, required=False ) - print(registry) plugins = registry and sorted(registry.plugin_names) or [] return web.json_response({"result": plugins}) diff --git a/aries_cloudagent/conductor.py b/aries_cloudagent/conductor.py index af7e32e4c5..725d6438c3 100644 --- a/aries_cloudagent/conductor.py +++ b/aries_cloudagent/conductor.py @@ -64,6 
+64,7 @@ async def setup(self): context = await self.context_builder.build() self.dispatcher = Dispatcher(context) + await self.dispatcher.setup() wire_format = await context.inject(BaseWireFormat, required=False) if wire_format and hasattr(wire_format, "task_queue"): @@ -118,12 +119,11 @@ async def setup(self): # "create_inbound_session", ), ) - collector.wrap(self.dispatcher, "handle_message") # at the class level (!) should not be performed multiple times collector.wrap( ConnectionManager, ( - "get_connection_targets", + # "get_connection_targets", "fetch_did_document", "find_inbound_connection", ), @@ -214,6 +214,8 @@ async def start(self) -> None: async def stop(self, timeout=1.0): """Stop the agent.""" shutdown = TaskQueue() + if self.dispatcher: + shutdown.run(self.dispatcher.complete()) if self.admin_server: shutdown.run(self.admin_server.stop()) if self.inbound_transport_manager: diff --git a/aries_cloudagent/config/default_context.py b/aries_cloudagent/config/default_context.py index e1501762ad..7ce066251b 100644 --- a/aries_cloudagent/config/default_context.py +++ b/aries_cloudagent/config/default_context.py @@ -20,7 +20,6 @@ from ..stats import Collector from ..storage.base import BaseStorage from ..storage.provider import StorageProvider -from ..transport.pack_format import PackWireFormat from ..transport.wire_format import BaseWireFormat from ..wallet.base import BaseWallet from ..wallet.provider import WalletProvider @@ -67,14 +66,12 @@ async def bind_providers(self, context: InjectionContext): StatsProvider( WalletProvider(), ( - "create", - "open", "sign_message", "verify_message", "encrypt_message", "decrypt_message", - "pack_message", - "unpack_message", + # "pack_message", + # "unpack_message", "get_local_did", ), ) @@ -128,7 +125,12 @@ async def bind_providers(self, context: InjectionContext): BaseWireFormat, CachedProvider( StatsProvider( - ClassProvider(PackWireFormat), ("encode_message", "parse_message"), + ClassProvider( + 
"aries_cloudagent.transport.pack_format.PackWireFormat" + ), + ( + # "encode_message", "parse_message" + ), ) ), ) diff --git a/aries_cloudagent/config/provider.py b/aries_cloudagent/config/provider.py index e51da07b15..3aba0f39a8 100644 --- a/aries_cloudagent/config/provider.py +++ b/aries_cloudagent/config/provider.py @@ -103,7 +103,10 @@ def __init__( async def provide(self, config: BaseSettings, injector: BaseInjector): """Provide the object instance given a config and injector.""" instance = await self._provider.provide(config, injector) - collector: Collector = await injector.inject(Collector, required=False) - if collector: - collector.wrap(instance, self._methods, ignore_missing=self._ignore_missing) + if self._methods: + collector: Collector = await injector.inject(Collector, required=False) + if collector: + collector.wrap( + instance, self._methods, ignore_missing=self._ignore_missing + ) return instance diff --git a/aries_cloudagent/dispatcher.py b/aries_cloudagent/dispatcher.py index 91b7842691..10cd2c8443 100644 --- a/aries_cloudagent/dispatcher.py +++ b/aries_cloudagent/dispatcher.py @@ -7,8 +7,11 @@ import asyncio import logging +import os from typing import Callable, Coroutine, Union +from aiohttp.web import HTTPException + from .config.injection_context import InjectionContext from .messaging.agent_message import AgentMessage from .messaging.error import MessageParseError @@ -18,7 +21,7 @@ from .messaging.protocol_registry import ProtocolRegistry from .messaging.request_context import RequestContext from .messaging.responder import BaseResponder -from .messaging.task_queue import TaskQueue +from .messaging.task_queue import CompletedTask, PendingTask, TaskQueue from .messaging.util import datetime_now from .stats import Collector from .transport.inbound.message import InboundMessage @@ -38,15 +41,44 @@ class Dispatcher: def __init__(self, context: InjectionContext): """Initialize an instance of Dispatcher.""" self.context = context - 
self.task_queue = TaskQueue(max_active=20) + self.collector: Collector = None + self.task_queue: TaskQueue = None + + async def setup(self): + """Perform async instance setup.""" + self.collector = await self.context.inject(Collector, required=False) + max_active = int(os.getenv("DISPATCHER_MAX_ACTIVE", 50)) + self.task_queue = TaskQueue( + max_active=max_active, timed=bool(self.collector), trace_fn=self.log_task + ) - def put_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Future: + def put_task( + self, coro: Coroutine, complete: Callable = None, ident: str = None + ) -> PendingTask: """Run a task in the task queue, potentially blocking other handlers.""" - return self.task_queue.put(coro, complete) + return self.task_queue.put(coro, complete, ident) - def run_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Task: + def run_task( + self, coro: Coroutine, complete: Callable = None, ident: str = None + ) -> asyncio.Task: """Run a task in the task queue, potentially blocking other handlers.""" - return self.task_queue.run(coro, complete) + return self.task_queue.run(coro, complete, ident) + + def log_task(self, task: CompletedTask): + """Log a completed task using the stats collector.""" + if task.exc_info and not issubclass(task.exc_info[0], HTTPException): + # skip errors intentionally returned to HTTP clients + LOGGER.exception( + "Handler error: %s", task.ident or "", exc_info=task.exc_info + ) + if self.collector: + timing = task.timing + if "queued" in timing: + self.collector.log( + f"Dispatcher:queued", timing["unqueued"] - timing["queued"] + ) + if task.ident: + self.collector.log(task.ident, timing["ended"] - timing["started"]) def queue_message( self, @@ -54,7 +86,7 @@ def queue_message( send_outbound: Coroutine, send_webhook: Coroutine = None, complete: Callable = None, - ) -> asyncio.Future: + ) -> PendingTask: """ Add a message to the processing queue for handling. 
@@ -65,7 +97,7 @@ def queue_message( complete: Function to call when the handler has completed Returns: - A future resolving to the handler task + A pending task instance resolving to the handler task """ return self.put_task( @@ -131,11 +163,10 @@ async def handle_message( context.injector.bind_instance(BaseResponder, responder) handler_cls = context.message.Handler - handler_obj = handler_cls() - collector: Collector = await context.inject(Collector, required=False) - if collector: - collector.wrap(handler_obj, "handle", ["any-message-handler"]) - await handler_obj.handle(context, responder) + handler = handler_cls().handle + if self.collector: + handler = self.collector.wrap_coro(handler, [handler.__qualname__]) + await handler(context, responder) async def make_message(self, parsed_msg: dict) -> AgentMessage: """ @@ -173,6 +204,10 @@ async def make_message(self, parsed_msg: dict) -> AgentMessage: return instance + async def complete(self, timeout: float = 0.1): + """Wait for pending tasks to complete.""" + await self.task_queue.complete(timeout=timeout) + class DispatcherResponder(BaseResponder): """Handle outgoing messages from message handlers.""" diff --git a/aries_cloudagent/messaging/task_queue.py b/aries_cloudagent/messaging/task_queue.py index 4a513b182e..04256ed3f8 100644 --- a/aries_cloudagent/messaging/task_queue.py +++ b/aries_cloudagent/messaging/task_queue.py @@ -2,11 +2,26 @@ import asyncio import logging +import time from typing import Callable, Coroutine, Tuple LOGGER = logging.getLogger(__name__) +def coro_ident(coro: Coroutine): + """Extract an identifier for a coroutine.""" + return coro and (hasattr(coro, "__qualname__") and coro.__qualname__ or repr(coro)) + + +async def coro_timed(coro: Coroutine, timing: dict): + """Capture timing for a coroutine.""" + timing["started"] = time.perf_counter() + try: + return await coro + finally: + timing["ended"] = time.perf_counter() + + def task_exc_info(task: asyncio.Task): """Extract exception info 
from an asyncio task.""" if not task or not task.done(): @@ -22,29 +37,112 @@ def task_exc_info(task: asyncio.Task): class CompletedTask: """Represent the result of a queued task.""" - # Note: this would be a good place to return timing information - - def __init__(self, task: asyncio.Task, exc_info: Tuple): + def __init__( + self, + task: asyncio.Task, + exc_info: Tuple, + ident: str = None, + timing: dict = None, + ): """Initialize the completed task.""" self.exc_info = exc_info + self.ident = ident self.task = task + self.timing = timing + + def __repr__(self) -> str: + """Generate string representation for logging.""" + return f"<{self.__class__.__name__} ident={self.ident} timing={self.timing}>" + + +class PendingTask: + """Represent a task in the queue.""" + + def __init__( + self, + coro: Coroutine, + complete_hook: Callable = None, + ident: str = None, + task_future: asyncio.Future = None, + queued_time: float = None, + ): + """ + Initialize the pending task. + + Args: + coro: The coroutine to be run + complete_hook: A callback to run on completion + ident: A string identifier for the task + task_future: A future to be resolved to the asyncio Task + queued_time: When the pending task was added to the queue + """ + if not asyncio.iscoroutine(coro): + raise ValueError(f"Expected coroutine, got {coro}") + self._cancelled = False + self.complete_hook = complete_hook + self.coro = coro + self.queued_time: float = queued_time + self.unqueued_time: float = None + self.ident = ident or coro_ident(coro) + self.task_future = task_future or asyncio.get_event_loop().create_future() + + def cancel(self): + """Cancel the pending task.""" + self.coro.close() + if not self.task_future.done(): + self.task_future.cancel() + self._cancelled = True + + @property + def cancelled(self): + """Accessor for the cancelled property.""" + return self._cancelled + + @property + def task(self) -> asyncio.Task: + """Accessor for the task.""" + return self.task_future.done() and 
self.task_future.result() + + @task.setter + def task(self, task: asyncio.Task): + """Setter for the task.""" + if self.task_future.cancelled(): + return + elif self.task_future.done(): + raise ValueError("Cannot set pending task future, already done") + self.task_future.set_result(task) + + def __await__(self): + """Wait for the task to be queued.""" + return self.task_future.__await__() + + def __repr__(self) -> str: + """Generate string representation for logging.""" + return f"<{self.__class__.__name__} ident={self.ident}>" class TaskQueue: """A class for managing a set of asyncio tasks.""" - def __init__(self, max_active: int = 0): + def __init__( + self, max_active: int = 0, timed: bool = False, trace_fn: Callable = None + ): """ Initialize the task queue. Args: max_active: The maximum number of tasks to automatically run + timed: A flag indicating that timing should be collected for tasks + trace_fn: A callback for all completed tasks """ self.loop = asyncio.get_event_loop() self.active_tasks = [] self.pending_tasks = [] + self.timed = timed self.total_done = 0 self.total_failed = 0 + self.total_started = 0 + self._trace_fn = trace_fn self._cancelled = False self._drain_evt = asyncio.Event() self._drain_task: asyncio.Task = None @@ -84,6 +182,14 @@ def current_size(self) -> int: """Accessor for the total number of tasks in the queue.""" return len(self.active_tasks) + len(self.pending_tasks) + def __bool__(self) -> bool: + """ + Support for the bool() builtin. + + Otherwise, evaluates as false when there are no tasks. 
+ """ + return True + def __len__(self) -> int: """Support for the len() builtin.""" return self.current_size @@ -108,42 +214,51 @@ def _drain_done(self, task: asyncio.Task): async def _drain_loop(self): """Run pending tasks while there is room in the queue.""" # Note: this method should not call async methods apart from - # waiting for the updated event, to avoid yielding to other queue methods + # waiting for the drain event, to avoid yielding to other queue methods while True: self._drain_evt.clear() while self.pending_tasks and ( not self._max_active or len(self.active_tasks) < self._max_active ): - coro, task_complete, fut = self.pending_tasks.pop(0) - task = self.run(coro, task_complete) - if fut and not fut.done(): - fut.set_result(task) + pending: PendingTask = self.pending_tasks.pop(0) + if pending.queued_time: + pending.unqueued_time = time.perf_counter() + timing = { + "queued": pending.queued_time, + "unqueued": pending.unqueued_time, + } + else: + timing = None + task = self.run( + pending.coro, pending.complete_hook, pending.ident, timing + ) + try: + pending.task = task + except ValueError: + LOGGER.warning("Pending task future already fulfilled") if self.pending_tasks: await self._drain_evt.wait() else: break - def add_pending( - self, - coro: Coroutine, - task_complete: Callable = None, - fut: asyncio.Future = None, - ): + def add_pending(self, pending: PendingTask): """ Add a task to the pending queue. 
Args: - coro: The coroutine to run - task_complete: An optional callback when the task has completed - fut: A future that resolves to the task once it is queued + pending: The `PendingTask` to add to the task queue """ - if not asyncio.iscoroutine(coro): - raise ValueError(f"Expected coroutine, got {coro}") - self.pending_tasks.append((coro, task_complete, fut)) + if self.timed and not pending.queued_time: + pending.queued_time = time.perf_counter() + self.pending_tasks.append(pending) self.drain() def add_active( - self, task: asyncio.Task, task_complete: Callable = None + self, + task: asyncio.Task, + task_complete: Callable = None, + ident: str = None, + timing: dict = None, ) -> asyncio.Task: """ Register an active async task with an optional completion callback. @@ -151,18 +266,31 @@ def add_active( Args: task: The asyncio task instance task_complete: An optional callback to run on completion + ident: A string identifier for the task + timing: An optional dictionary of timing information """ self.active_tasks.append(task) - task.add_done_callback(lambda fut: self.completed_task(task, task_complete)) + task.add_done_callback( + lambda fut: self.completed_task(task, task_complete, ident, timing) + ) + self.total_started += 1 return task - def run(self, coro: Coroutine, task_complete: Callable = None) -> asyncio.Task: + def run( + self, + coro: Coroutine, + task_complete: Callable = None, + ident: str = None, + timing: dict = None, + ) -> asyncio.Task: """ Start executing a coroutine as an async task, bypassing the pending queue. 
Args: coro: The coroutine to run - task_complete: A callback to run on completion + task_complete: An optional callback to run on completion + ident: A string identifier for the task + timing: An optional dictionary of timing information Returns: the new asyncio task instance @@ -171,45 +299,64 @@ def run(self, coro: Coroutine, task_complete: Callable = None) -> asyncio.Task: raise RuntimeError("Task queue has been cancelled") if not asyncio.iscoroutine(coro): raise ValueError(f"Expected coroutine, got {coro}") + if not ident: + ident = coro_ident(coro) + if self.timed: + if not timing: + timing = dict() + coro = coro_timed(coro, timing) task = self.loop.create_task(coro) - return self.add_active(task, task_complete) + return self.add_active(task, task_complete, ident, timing) - def put(self, coro: Coroutine, task_complete: Callable = None) -> asyncio.Future: + def put( + self, coro: Coroutine, task_complete: Callable = None, ident: str = None + ) -> PendingTask: """ Add a new task to the queue, delaying execution if busy. 
Args: coro: The coroutine to run task_complete: A callback to run on completion + ident: A string identifier for the task Returns: a future resolving to the asyncio task instance once queued """ - fut = self.loop.create_future() + pending = PendingTask(coro, task_complete, ident) if self._cancelled: - coro.close() - fut.cancel() + pending.cancel() elif self.ready: - task = self.run(coro, task_complete) - fut.set_result(task) + pending.task = self.run(coro, task_complete, pending.ident) else: - self.add_pending(coro, task_complete, fut) - return fut + self.add_pending(pending) + return pending - def completed_task(self, task: asyncio.Task, task_complete: Callable): + def completed_task( + self, + task: asyncio.Task, + task_complete: Callable, + ident: str, + timing: dict = None, + ): """Clean up after a task has completed and run callbacks.""" exc_info = task_exc_info(task) if exc_info: self.total_failed += 1 - if not task_complete: - LOGGER.exception("Error running task", exc_info=exc_info) + if not task_complete and not self._trace_fn: + LOGGER.exception( + "Error running task %s", ident or "", exc_info=exc_info + ) else: self.total_done += 1 - if task_complete: + if task_complete or self._trace_fn: + completed = CompletedTask(task, exc_info, ident, timing) try: - task_complete(CompletedTask(task, exc_info)) + if task_complete: + task_complete(completed) + if self._trace_fn: + self._trace_fn(completed) except Exception: - LOGGER.exception("Error finalizing task") + LOGGER.exception("Error finalizing task %s", completed) try: self.active_tasks.remove(task) except ValueError: @@ -221,9 +368,8 @@ def cancel_pending(self): if self._drain_task: self._drain_task.cancel() self._drain_task = None - for coro, task_complete, fut in self.pending_tasks: - coro.close() - fut.cancel() + for pending in self.pending_tasks: + pending.cancel() self.pending_tasks = [] def cancel(self): diff --git a/aries_cloudagent/messaging/tests/test_task_queue.py 
b/aries_cloudagent/messaging/tests/test_task_queue.py index 4b1ae08a71..fc3fad77a3 100644 --- a/aries_cloudagent/messaging/tests/test_task_queue.py +++ b/aries_cloudagent/messaging/tests/test_task_queue.py @@ -1,10 +1,12 @@ import asyncio from asynctest import TestCase -from ..task_queue import CompletedTask, TaskQueue +from ..task_queue import CompletedTask, PendingTask, TaskQueue, task_exc_info -async def retval(val): +async def retval(val, *, delay=0): + if delay: + await asyncio.sleep(delay) return val @@ -38,14 +40,14 @@ def done(complete: CompletedTask): assert not complete.exc_info completed.append(complete.task.result()) - fut = queue.put(retval(1), done) + pend = queue.put(retval(1), done) assert not queue.pending_tasks await queue.flush() assert completed == [1] - assert fut.result().result() == 1 + assert pend.task.result() == 1 with self.assertRaises(ValueError): - queue.add_pending(None, done) + queue.put(None, done) async def test_put_limited(self): queue = TaskQueue(1) @@ -57,13 +59,27 @@ def done(complete: CompletedTask): assert not complete.exc_info completed.add(complete.task.result()) - fut1 = queue.put(retval(1), done) - fut2 = queue.put(retval(2), done) + pend1 = queue.put(retval(1), done) + pend2 = queue.put(retval(2), done) assert queue.pending_tasks await queue.flush() assert completed == {1, 2} - assert fut1.result().result() == 1 - assert fut2.result().result() == 2 + assert pend1.task.result() == 1 + assert pend2.task.result() == 2 + + async def test_pending(self): + coro = retval(1, delay=1) + pend = PendingTask(coro, None) + task = asyncio.get_event_loop().create_task(coro) + assert task_exc_info(task) is None + pend.task = task + assert pend.task is task + assert pend.task_future.result() is task + with self.assertRaises(ValueError): + pend.task = task + pend.cancel() + assert pend.cancelled + task.cancel() async def test_complete(self): queue = TaskQueue() @@ -88,9 +104,11 @@ def done(complete: CompletedTask): 
completed.add(complete.task.result()) queue.run(retval(1), done) + sleep = queue.run(retval(1, delay=1), done) queue.put(retval(2), done) queue.put(retval(3), done) queue.cancel_pending() + sleep.cancel() await queue.flush() assert completed == {1} @@ -102,7 +120,7 @@ def done(complete: CompletedTask): assert not complete.exc_info completed.add(complete.task.result()) - queue.run(retval(1), done) + queue.run(retval(1, delay=1), done) queue.put(retval(2), done) queue.put(retval(3), done) queue.cancel() @@ -117,12 +135,12 @@ def done(complete: CompletedTask): co.close() co = retval(1) - fut = queue.put(co) - assert fut.cancelled() + pend = queue.put(co) + assert pend.cancelled async def test_cancel_long(self): queue = TaskQueue() - task = queue.run(asyncio.sleep(5)) + task = queue.run(retval(1, delay=5)) queue.cancel() await queue @@ -134,7 +152,7 @@ async def test_cancel_long(self): async def test_complete_with_timeout(self): queue = TaskQueue() - task = queue.run(asyncio.sleep(5)) + task = queue.run(retval(1, delay=5)) await queue.complete(0.01) # cancellation may take a second @@ -155,5 +173,21 @@ def done(complete: CompletedTask): task = queue.run(retval(1), done) await task - queue.completed_task(task, done) + queue.completed_task(task, done, None, dict()) assert completed == [1, 1] + + async def test_timed(self): + completed = [] + + def done(complete: CompletedTask): + assert not complete.exc_info + completed.append((complete.task.result(), complete.timing)) + + queue = TaskQueue(max_active=1, timed=True, trace_fn=done) + task1 = queue.run(retval(1)) + task2 = await queue.put(retval(2)) + await queue.complete(0.1) + + assert len(completed) == 2 + assert "queued" not in completed[0][1] + assert "queued" in completed[1][1] diff --git a/aries_cloudagent/stats.py b/aries_cloudagent/stats.py index a41e42f191..e57d208cf7 100644 --- a/aries_cloudagent/stats.py +++ b/aries_cloudagent/stats.py @@ -82,7 +82,7 @@ def stop(self): if self.start_time: dur = self.now() - 
self.start_time for grp in self.groups: - self.collector.log(grp, dur) + self.collector.log(grp, dur, self.start_time) self.start_time = None def __enter__(self): @@ -124,12 +124,13 @@ def enabled(self, val: bool): """Setter for the collector's enabled property.""" self._enabled = val - def log(self, name: str, duration: float): + def log(self, name: str, duration: float, start: float = None): """Log an entry in the statistics if the collector is enabled.""" if self._enabled: self._stats.log(name, duration) if self._log_file: - start = time.perf_counter() - duration + if start is None: + start = time.perf_counter() - duration self._log_file.write(f"{name} {start:.5f} {duration:.5f}\n") def mark(self, *names): diff --git a/aries_cloudagent/tests/test_dispatcher.py b/aries_cloudagent/tests/test_dispatcher.py index 1cce81907b..66c1b5e7df 100644 --- a/aries_cloudagent/tests/test_dispatcher.py +++ b/aries_cloudagent/tests/test_dispatcher.py @@ -68,6 +68,7 @@ async def test_dispatch(self): {StubAgentMessage.Meta.message_type: StubAgentMessage} ) dispatcher = test_module.Dispatcher(context) + await dispatcher.setup() rcv = Receiver() message = {"@type": StubAgentMessage.Meta.message_type} @@ -84,6 +85,7 @@ async def test_dispatch(self): async def test_bad_message_dispatch(self): dispatcher = test_module.Dispatcher(make_context()) + await dispatcher.setup() rcv = Receiver() bad_message = {"bad": "message"} await dispatcher.queue_message(make_inbound(bad_message), rcv.send)