Skip to content

Commit

Permalink
Merge pull request #287 from andrewwhitehead/feature/dispatch-active
Browse files Browse the repository at this point in the history
Dispatcher queue improvements
  • Loading branch information
swcurran authored Dec 10, 2019
2 parents df82c8e + a3a0f1c commit 0a9788c
Show file tree
Hide file tree
Showing 9 changed files with 317 additions and 95 deletions.
17 changes: 7 additions & 10 deletions aries_cloudagent/admin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ def __init__(
):
"""Initialize the webhook target."""
self.endpoint = endpoint
self._topic_filter = None
self.retries = retries
# call setter
self.topic_filter = topic_filter
self._topic_filter = None
self.topic_filter = topic_filter # call setter

@property
def topic_filter(self) -> Set[str]:
Expand Down Expand Up @@ -176,6 +175,8 @@ async def check_token(request, handler):

middlewares.append(check_token)

collector: Collector = await self.context.inject(Collector, required=False)

if self.task_queue:

@web.middleware
Expand All @@ -185,14 +186,11 @@ async def apply_limiter(request, handler):

middlewares.append(apply_limiter)

stats: Collector = await self.context.inject(Collector, required=False)
if stats:
elif collector:

@web.middleware
async def collect_stats(request, handler):
handler = stats.wrap_coro(
handler, [handler.__qualname__, "any-admin-request"]
)
handler = collector.wrap_coro(handler, [handler.__qualname__])
return await handler(request)

middlewares.append(collect_stats)
Expand Down Expand Up @@ -231,7 +229,7 @@ async def collect_stats(request, handler):
for route in app.router.routes():
cors.add(route)
# get agent label
agent_label = self.context.settings.get("default_label"),
agent_label = self.context.settings.get("default_label")
version_string = f"v{__version__}"

setup_aiohttp_apispec(
Expand Down Expand Up @@ -288,7 +286,6 @@ async def plugins_handler(self, request: web.BaseRequest):
registry: PluginRegistry = await self.context.inject(
PluginRegistry, required=False
)
print(registry)
plugins = registry and sorted(registry.plugin_names) or []
return web.json_response({"result": plugins})

Expand Down
6 changes: 4 additions & 2 deletions aries_cloudagent/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def setup(self):
context = await self.context_builder.build()

self.dispatcher = Dispatcher(context)
await self.dispatcher.setup()

wire_format = await context.inject(BaseWireFormat, required=False)
if wire_format and hasattr(wire_format, "task_queue"):
Expand Down Expand Up @@ -118,12 +119,11 @@ async def setup(self):
# "create_inbound_session",
),
)
collector.wrap(self.dispatcher, "handle_message")
# at the class level (!) should not be performed multiple times
collector.wrap(
ConnectionManager,
(
"get_connection_targets",
# "get_connection_targets",
"fetch_did_document",
"find_inbound_connection",
),
Expand Down Expand Up @@ -214,6 +214,8 @@ async def start(self) -> None:
async def stop(self, timeout=1.0):
"""Stop the agent."""
shutdown = TaskQueue()
if self.dispatcher:
shutdown.run(self.dispatcher.complete())
if self.admin_server:
shutdown.run(self.admin_server.stop())
if self.inbound_transport_manager:
Expand Down
14 changes: 8 additions & 6 deletions aries_cloudagent/config/default_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ..stats import Collector
from ..storage.base import BaseStorage
from ..storage.provider import StorageProvider
from ..transport.pack_format import PackWireFormat
from ..transport.wire_format import BaseWireFormat
from ..wallet.base import BaseWallet
from ..wallet.provider import WalletProvider
Expand Down Expand Up @@ -67,14 +66,12 @@ async def bind_providers(self, context: InjectionContext):
StatsProvider(
WalletProvider(),
(
"create",
"open",
"sign_message",
"verify_message",
"encrypt_message",
"decrypt_message",
"pack_message",
"unpack_message",
# "pack_message",
# "unpack_message",
"get_local_did",
),
)
Expand Down Expand Up @@ -128,7 +125,12 @@ async def bind_providers(self, context: InjectionContext):
BaseWireFormat,
CachedProvider(
StatsProvider(
ClassProvider(PackWireFormat), ("encode_message", "parse_message"),
ClassProvider(
"aries_cloudagent.transport.pack_format.PackWireFormat"
),
(
# "encode_message", "parse_message"
),
)
),
)
Expand Down
9 changes: 6 additions & 3 deletions aries_cloudagent/config/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def __init__(
async def provide(self, config: BaseSettings, injector: BaseInjector):
"""Provide the object instance given a config and injector."""
instance = await self._provider.provide(config, injector)
collector: Collector = await injector.inject(Collector, required=False)
if collector:
collector.wrap(instance, self._methods, ignore_missing=self._ignore_missing)
if self._methods:
collector: Collector = await injector.inject(Collector, required=False)
if collector:
collector.wrap(
instance, self._methods, ignore_missing=self._ignore_missing
)
return instance
61 changes: 48 additions & 13 deletions aries_cloudagent/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

import asyncio
import logging
import os
from typing import Callable, Coroutine, Union

from aiohttp.web import HTTPException

from .config.injection_context import InjectionContext
from .messaging.agent_message import AgentMessage
from .messaging.error import MessageParseError
Expand All @@ -18,7 +21,7 @@
from .messaging.protocol_registry import ProtocolRegistry
from .messaging.request_context import RequestContext
from .messaging.responder import BaseResponder
from .messaging.task_queue import TaskQueue
from .messaging.task_queue import CompletedTask, PendingTask, TaskQueue
from .messaging.util import datetime_now
from .stats import Collector
from .transport.inbound.message import InboundMessage
Expand All @@ -38,23 +41,52 @@ class Dispatcher:
def __init__(self, context: InjectionContext):
"""Initialize an instance of Dispatcher."""
self.context = context
self.task_queue = TaskQueue(max_active=20)
self.collector: Collector = None
self.task_queue: TaskQueue = None

async def setup(self):
"""Perform async instance setup."""
self.collector = await self.context.inject(Collector, required=False)
max_active = int(os.getenv("DISPATCHER_MAX_ACTIVE", 50))
self.task_queue = TaskQueue(
max_active=max_active, timed=bool(self.collector), trace_fn=self.log_task
)

def put_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Future:
def put_task(
self, coro: Coroutine, complete: Callable = None, ident: str = None
) -> PendingTask:
"""Run a task in the task queue, potentially blocking other handlers."""
return self.task_queue.put(coro, complete)
return self.task_queue.put(coro, complete, ident)

def run_task(self, coro: Coroutine, complete: Callable = None) -> asyncio.Task:
def run_task(
self, coro: Coroutine, complete: Callable = None, ident: str = None
) -> asyncio.Task:
"""Run a task in the task queue, potentially blocking other handlers."""
return self.task_queue.run(coro, complete)
return self.task_queue.run(coro, complete, ident)

def log_task(self, task: CompletedTask):
"""Log a completed task using the stats collector."""
if task.exc_info and not issubclass(task.exc_info[0], HTTPException):
# skip errors intentionally returned to HTTP clients
LOGGER.exception(
"Handler error: %s", task.ident or "", exc_info=task.exc_info
)
if self.collector:
timing = task.timing
if "queued" in timing:
self.collector.log(
f"Dispatcher:queued", timing["unqueued"] - timing["queued"]
)
if task.ident:
self.collector.log(task.ident, timing["ended"] - timing["started"])

def queue_message(
self,
inbound_message: InboundMessage,
send_outbound: Coroutine,
send_webhook: Coroutine = None,
complete: Callable = None,
) -> asyncio.Future:
) -> PendingTask:
"""
Add a message to the processing queue for handling.
Expand All @@ -65,7 +97,7 @@ def queue_message(
complete: Function to call when the handler has completed
Returns:
A future resolving to the handler task
A pending task instance resolving to the handler task
"""
return self.put_task(
Expand Down Expand Up @@ -131,11 +163,10 @@ async def handle_message(
context.injector.bind_instance(BaseResponder, responder)

handler_cls = context.message.Handler
handler_obj = handler_cls()
collector: Collector = await context.inject(Collector, required=False)
if collector:
collector.wrap(handler_obj, "handle", ["any-message-handler"])
await handler_obj.handle(context, responder)
handler = handler_cls().handle
if self.collector:
handler = self.collector.wrap_coro(handler, [handler.__qualname__])
await handler(context, responder)

async def make_message(self, parsed_msg: dict) -> AgentMessage:
"""
Expand Down Expand Up @@ -173,6 +204,10 @@ async def make_message(self, parsed_msg: dict) -> AgentMessage:

return instance

async def complete(self, timeout: float = 0.1):
"""Wait for pending tasks to complete."""
await self.task_queue.complete(timeout=timeout)


class DispatcherResponder(BaseResponder):
"""Handle outgoing messages from message handlers."""
Expand Down
Loading

0 comments on commit 0a9788c

Please sign in to comment.