Skip to content

Commit

Permalink
Simplifies event handler decorator (microsoft#196)
Browse files Browse the repository at this point in the history
To avoid overload
  • Loading branch information
markwaddle authored Nov 1, 2024
1 parent 583bb96 commit d4a37db
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Protocol,
TypeVar,
Union,
overload,
)

import typing_extensions
Expand Down Expand Up @@ -126,15 +125,24 @@ class ObjectEventHandlers(Generic[EventHandlerT]):
def __init__(self, on_created=True, on_updated=True, on_deleted=True) -> None:
if on_created:
self._on_created_handlers = EventHandlerList[EventHandlerT]()
self.on_created = _create_decorator(self._on_created_handlers)
self.on_created = _create_decorator(self._on_created_handlers, "others")
"""event handler for created event; excluding events from this assistant service"""
self.on_created_including_mine = _create_decorator(self._on_created_handlers, "all")
"""event handler for created event; including events from this assistant service"""

if on_updated:
self._on_updated_handlers = EventHandlerList[EventHandlerT]()
self.on_updated = _create_decorator(self._on_updated_handlers)
self.on_updated = _create_decorator(self._on_updated_handlers, "others")
"""event handler for updated event; excluding events from this assistant service"""
self.on_updated_including_mine = _create_decorator(self._on_updated_handlers, "all")
"""event handler for updated event; including events from this assistant service"""

if on_deleted:
self._on_deleted_handlers = EventHandlerList[EventHandlerT]()
self.on_deleted = _create_decorator(self._on_deleted_handlers)
self.on_deleted = _create_decorator(self._on_deleted_handlers, "others")
"""event handler for deleted event; excluding events from this assistant service"""
self.on_deleted_including_mine = _create_decorator(self._on_deleted_handlers, "all")
"""event handler for deleted event; including events from this assistant service"""


LifecycleEventHandler = Callable[[], Awaitable[None] | None]
Expand All @@ -143,45 +151,20 @@ def __init__(self, on_created=True, on_updated=True, on_deleted=True) -> None:
class LifecycleEventHandlers:
def __init__(self) -> None:
self._on_service_start_handlers = EventHandlerList[LifecycleEventHandler]()
self.on_service_start = _create_decorator(self._on_service_start_handlers)
self.on_service_start = _create_decorator(self._on_service_start_handlers, "all")

self._on_service_shutdown_handlers = EventHandlerList[LifecycleEventHandler]()
self.on_service_shutdown = _create_decorator(self._on_service_shutdown_handlers)
self.on_service_shutdown = _create_decorator(self._on_service_shutdown_handlers, "all")


def _create_decorator(
handler_list: EventHandlerList[EventHandlerT],
):
@overload
def decorator(func_or_include: EventHandlerT) -> EventHandlerT: ...

@overload
def decorator(
func_or_include: IncludeEventsFromActors | None = "others",
) -> Callable[[EventHandlerT], EventHandlerT]: ...

def decorator(
func_or_include: EventHandlerT | IncludeEventsFromActors | None = "others",
) -> EventHandlerT | Callable[[EventHandlerT], EventHandlerT]:
filter: IncludeEventsFromActors = "others"
match func_or_include:
case "all":
filter = "all"
case "this_assistant_service":
filter = "this_assistant_service"

def _decorator(func: EventHandlerT) -> EventHandlerT:
handler_list.append((func, filter))
return func

# decorator with no arguments
if callable(func_or_include):
return _decorator(func_or_include)

# decorator with arguments
return _decorator

return decorator
handler_list: EventHandlerList[EventHandlerT], filter: IncludeEventsFromActors
) -> Callable[[EventHandlerT], EventHandlerT]:
def _decorator(func: EventHandlerT) -> EventHandlerT:
handler_list.append((func, filter))
return func

return _decorator


AssistantEventHandler = Callable[[AssistantContext], Awaitable[None] | None]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ async def test_assistant_with_event_handlers(
assistant_created_calls = 0
conversation_created_calls = 0
message_created_calls = 0
message_created_with_parens_calls = 0
message_created_all_calls = 0
message_chat_created_calls = 0

Expand All @@ -90,16 +89,7 @@ def on_message_created(
nonlocal message_created_calls
message_created_calls += 1

@app.events.conversation.message.on_created()
def on_message_created_with_parens(
conversation_context: ConversationContext,
_: workbench_model.ConversationEvent,
message: workbench_model.ConversationMessage,
) -> None:
nonlocal message_created_with_parens_calls
message_created_with_parens_calls += 1

@app.events.conversation.message.on_created("all")
@app.events.conversation.message.on_created_including_mine
def on_message_created_all(
conversation_context: ConversationContext,
_: workbench_model.ConversationEvent,
Expand Down Expand Up @@ -174,7 +164,6 @@ async def on_chat_message(

assert message_created_calls == 1
assert message_created_all_calls == 1
assert message_created_with_parens_calls == 1
assert message_chat_created_calls == 1

# send a message of type "notice"
Expand All @@ -201,7 +190,6 @@ async def on_chat_message(
)

assert message_created_calls == 2
assert message_created_with_parens_calls == 2
assert message_created_all_calls == 2
assert message_chat_created_calls == 1

Expand Down Expand Up @@ -230,7 +218,6 @@ async def on_chat_message(

# these should remain unchanged
assert message_chat_created_calls == 1
assert message_created_with_parens_calls == 2
assert message_created_calls == 2

# this should have been called
Expand Down

0 comments on commit d4a37db

Please sign in to comment.