From b3249839135944182586039223ab761df02a1424 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 12 Feb 2025 12:57:35 -0500 Subject: [PATCH] fix: Fix type hints, update tests and samples --- .../framework/distributed-agent-runtime.ipynb | 446 +++++++++--------- .../runtimes/grpc/_worker_runtime.py | 4 +- .../autogen-ext/tests/test_worker_runtime.py | 48 +- 3 files changed, 249 insertions(+), 249 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/distributed-agent-runtime.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/distributed-agent-runtime.ipynb index c67c998c0a65..ce11bc2d4960 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/distributed-agent-runtime.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/distributed-agent-runtime.ipynb @@ -1,225 +1,225 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Distributed Agent Runtime\n", - "\n", - "```{attention}\n", - "The distributed agent runtime is an experimental feature. Expect breaking changes\n", - "to the API.\n", - "```\n", - "\n", - "A distributed agent runtime facilitates communication and agent lifecycle management\n", - "across process boundaries.\n", - "It consists of a host service and at least one worker runtime.\n", - "\n", - "The host service maintains connections to all active worker runtimes,\n", - "facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n", - "A worker runtime processes application code (agents) and connects to the host service.\n", - "It also advertises the agents which they support to the host service,\n", - "so the host service can deliver messages to the correct worker.\n", - "\n", - "````{note}\n", - "The distributed agent runtime requires extra dependencies, install them using:\n", - "```bash\n", - "pip install \"autogen-ext[grpc]\"\n", - "```\n", - "````\n", - "\n", - "We can start a host service using {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n", - "\n", - "host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n", - "host.start() # Start a host service in the background." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above code starts the host service in the background and accepts\n", - "worker connections on port 50051.\n", - "\n", - "Before running worker runtimes, let's define our agent.\n", - "The agent will publish a new message on every message it receives.\n", - "It also keeps track of how many messages it has published, and \n", - "stops publishing new messages once it has published 5 messages." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from dataclasses import dataclass\n", - "\n", - "from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n", - "\n", - "\n", - "@dataclass\n", - "class MyMessage:\n", - " content: str\n", - "\n", - "\n", - "@default_subscription\n", - "class MyAgent(RoutedAgent):\n", - " def __init__(self, name: str) -> None:\n", - " super().__init__(\"My agent\")\n", - " self._name = name\n", - " self._counter = 0\n", - "\n", - " @message_handler\n", - " async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n", - " self._counter += 1\n", - " if self._counter > 5:\n", - " return\n", - " content = f\"{self._name}: Hello x {self._counter}\"\n", - " print(content)\n", - " await self.publish_message(MyMessage(content=content), DefaultTopicId())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can set up the worker agent runtimes.\n", - "We use {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime`.\n", - "We set up two worker runtimes. Each runtime hosts one agent.\n", - "All agents publish and subscribe to the default topic, so they can see all\n", - "messages being published.\n", - "\n", - "To run the agents, we publishes a message from a worker." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "worker1: Hello x 1\n", - "worker2: Hello x 1\n", - "worker2: Hello x 2\n", - "worker1: Hello x 2\n", - "worker1: Hello x 3\n", - "worker2: Hello x 3\n", - "worker2: Hello x 4\n", - "worker1: Hello x 4\n", - "worker1: Hello x 5\n", - "worker2: Hello x 5\n" - ] - } - ], - "source": [ - "import asyncio\n", - "\n", - "from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n", - "\n", - "worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n", - "worker1.start()\n", - "await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n", - "\n", - "worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n", - "worker2.start()\n", - "await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n", - "\n", - "await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n", - "\n", - "# Let the agents run for a while.\n", - "await asyncio.sleep(5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can see each agent published exactly 5 messages.\n", - "\n", - "To stop the worker runtimes, we can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime.stop`." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "await worker1.stop()\n", - "await worker2.stop()\n", - "\n", - "# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n", - "# await worker1.stop_when_signal()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost.stop`\n", - "to stop the host service." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "await host.stop()\n", - "\n", - "# To keep the host service running until a termination signal (e.g., SIGTERM)\n", - "# await host.stop_when_signal()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Cross-Language Runtimes\n", - "The process described above is largely the same, however all message types MUST use shared protobuf schemas for all cross-agent message types.\n", - "\n", - "# Next Steps\n", - "To see complete examples of using distributed runtime, please take a look at the following samples:\n", - "\n", - "- [Distributed Workers](https://github.com/microsoft/autogen/tree/main/python/samples/core_grpc_worker_runtime) \n", - "- [Distributed Semantic Router](https://github.com/microsoft/autogen/tree/main/python/samples/core_semantic_router) \n", - "- [Distributed Group Chat](https://github.com/microsoft/autogen/tree/main/python/samples/core_distributed-group-chat) \n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "agnext", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Distributed Agent Runtime\n", + "\n", + "```{attention}\n", + "The distributed agent runtime is an experimental feature. Expect breaking changes\n", + "to the API.\n", + "```\n", + "\n", + "A distributed agent runtime facilitates communication and agent lifecycle management\n", + "across process boundaries.\n", + "It consists of a host service and at least one worker runtime.\n", + "\n", + "The host service maintains connections to all active worker runtimes,\n", + "facilitates message delivery, and keeps sessions for all direct messages (i.e., RPCs).\n", + "A worker runtime processes application code (agents) and connects to the host service.\n", + "It also advertises the agents which they support to the host service,\n", + "so the host service can deliver messages to the correct worker.\n", + "\n", + "````{note}\n", + "The distributed agent runtime requires extra dependencies, install them using:\n", + "```bash\n", + "pip install \"autogen-ext[grpc]\"\n", + "```\n", + "````\n", + "\n", + "We can start a host service using {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntimeHost\n", + "\n", + "host = GrpcWorkerAgentRuntimeHost(address=\"localhost:50051\")\n", + "host.start() # Start a host service in the background." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The above code starts the host service in the background and accepts\n", + "worker connections on port 50051.\n", + "\n", + "Before running worker runtimes, let's define our agent.\n", + "The agent will publish a new message on every message it receives.\n", + "It also keeps track of how many messages it has published, and \n", + "stops publishing new messages once it has published 5 messages." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass\n", + "\n", + "from autogen_core import DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler\n", + "\n", + "\n", + "@dataclass\n", + "class MyMessage:\n", + " content: str\n", + "\n", + "\n", + "@default_subscription\n", + "class MyAgent(RoutedAgent):\n", + " def __init__(self, name: str) -> None:\n", + " super().__init__(\"My agent\")\n", + " self._name = name\n", + " self._counter = 0\n", + "\n", + " @message_handler\n", + " async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:\n", + " self._counter += 1\n", + " if self._counter > 5:\n", + " return\n", + " content = f\"{self._name}: Hello x {self._counter}\"\n", + " print(content)\n", + " await self.publish_message(MyMessage(content=content), DefaultTopicId())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can set up the worker agent runtimes.\n", + "We use {py:class}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime`.\n", + "We set up two worker runtimes. Each runtime hosts one agent.\n", + "All agents publish and subscribe to the default topic, so they can see all\n", + "messages being published.\n", + "\n", + "To run the agents, we publishes a message from a worker." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "worker1: Hello x 1\n", + "worker2: Hello x 1\n", + "worker2: Hello x 2\n", + "worker1: Hello x 2\n", + "worker1: Hello x 3\n", + "worker2: Hello x 3\n", + "worker2: Hello x 4\n", + "worker1: Hello x 4\n", + "worker1: Hello x 5\n", + "worker2: Hello x 5\n" + ] + } + ], + "source": [ + "import asyncio\n", + "\n", + "from autogen_ext.runtimes.grpc import GrpcWorkerAgentRuntime\n", + "\n", + "worker1 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n", + "await worker1.start()\n", + "await MyAgent.register(worker1, \"worker1\", lambda: MyAgent(\"worker1\"))\n", + "\n", + "worker2 = GrpcWorkerAgentRuntime(host_address=\"localhost:50051\")\n", + "await worker2.start()\n", + "await MyAgent.register(worker2, \"worker2\", lambda: MyAgent(\"worker2\"))\n", + "\n", + "await worker2.publish_message(MyMessage(content=\"Hello!\"), DefaultTopicId())\n", + "\n", + "# Let the agents run for a while.\n", + "await asyncio.sleep(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see each agent published exactly 5 messages.\n", + "\n", + "To stop the worker runtimes, we can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntime.stop`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "await worker1.stop()\n", + "await worker2.stop()\n", + "\n", + "# To keep the worker running until a termination signal is received (e.g., SIGTERM).\n", + "# await worker1.stop_when_signal()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can call {py:meth}`~autogen_ext.runtimes.grpc.GrpcWorkerAgentRuntimeHost.stop`\n", + "to stop the host service." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "await host.stop()\n", + "\n", + "# To keep the host service running until a termination signal (e.g., SIGTERM)\n", + "# await host.stop_when_signal()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cross-Language Runtimes\n", + "The process described above is largely the same, however all message types MUST use shared protobuf schemas for all cross-agent message types.\n", + "\n", + "# Next Steps\n", + "To see complete examples of using distributed runtime, please take a look at the following samples:\n", + "\n", + "- [Distributed Workers](https://github.com/microsoft/autogen/tree/main/python/samples/core_grpc_worker_runtime) \n", + "- [Distributed Semantic Router](https://github.com/microsoft/autogen/tree/main/python/samples/core_semantic_router) \n", + "- [Distributed Group Chat](https://github.com/microsoft/autogen/tree/main/python/samples/core_distributed-group-chat) \n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index dad6715fd2b7..4035391337a2 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -166,7 +166,7 @@ async def _connect( send_queue: asyncio.Queue[agent_worker_pb2.Message], receive_queue: asyncio.Queue[agent_worker_pb2.Message], client_id: str, - ) -> asyncio.Task: + ) -> Task[None]: from grpc.aio import StreamStreamCall # TODO: where do exceptions from reading the iterable go? How do we recover from those? @@ -184,7 +184,7 @@ async def read_loop() -> None: logger.info("EOF") break logger.info(f"Received a message from host: {message}") - receive_queue.put(message) + await receive_queue.put(message) logger.info("Put message in receive queue") return asyncio.create_task(read_loop()) diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index 7849dd54768f..6545c242907e 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -42,7 +42,7 @@ async def test_agent_types_must_be_unique_single_worker() -> None: host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - worker.start() + await worker.start() await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent) @@ -65,9 +65,9 @@ async def test_agent_types_must_be_unique_multiple_workers() -> None: host.start() worker1 = GrpcWorkerAgentRuntime(host_address=host_address) - worker1.start() + await worker1.start() worker2 = GrpcWorkerAgentRuntime(host_address=host_address) - worker2.start() + await worker2.start() await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent) @@ -91,7 +91,7 @@ async def test_register_receives_publish() -> None: host.start() worker1 = GrpcWorkerAgentRuntime(host_address=host_address) - worker1.start() + await worker1.start() worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType)) await worker1.register_factory( type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent @@ -99,7 +99,7 @@ async def test_register_receives_publish() -> None: await worker1.add_subscription(TypeSubscription("default", "name1")) worker2 = GrpcWorkerAgentRuntime(host_address=host_address) - worker2.start() + await worker2.start() worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType)) await worker2.register_factory( type=AgentType("name2"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent @@ -137,7 +137,7 @@ async def test_register_doesnt_receive_after_removing_subscription() -> None: host.start() worker1 = GrpcWorkerAgentRuntime(host_address=host_address) - worker1.start() + await worker1.start() worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType)) await worker1.register_factory( type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent @@ -177,7 +177,7 @@ async def test_register_receives_publish_cascade_single_worker() -> None: host = GrpcWorkerAgentRuntimeHost(address=host_address) host.start() runtime = GrpcWorkerAgentRuntime(host_address=host_address) - runtime.start() + await runtime.start() num_agents = 5 num_initial_messages = 5 @@ -228,14 +228,14 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None: # Register agents for i in range(num_agents): runtime = GrpcWorkerAgentRuntime(host_address=host_address) - runtime.start() + await runtime.start() await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds)) workers.append(runtime) # Publish messages publisher = GrpcWorkerAgentRuntime(host_address=host_address) publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType)) - publisher.start() + await publisher.start() for _ in range(num_initial_messages): await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId()) @@ -259,10 +259,10 @@ async def test_default_subscription() -> None: host = GrpcWorkerAgentRuntimeHost(address=host_address) host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - worker.start() + await worker.start() publisher = GrpcWorkerAgentRuntime(host_address=host_address) publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) - publisher.start() + await publisher.start() await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription()) @@ -294,10 +294,10 @@ async def test_default_subscription_other_source() -> None: host = GrpcWorkerAgentRuntimeHost(address=host_address) host.start() runtime = GrpcWorkerAgentRuntime(host_address=host_address) - runtime.start() + await runtime.start() publisher = GrpcWorkerAgentRuntime(host_address=host_address) publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) - publisher.start() + await publisher.start() await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription()) @@ -329,10 +329,10 @@ async def test_type_subscription() -> None: host = GrpcWorkerAgentRuntimeHost(address=host_address) host.start() worker = GrpcWorkerAgentRuntime(host_address=host_address) - worker.start() + await worker.start() publisher = GrpcWorkerAgentRuntime(host_address=host_address) publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType)) - publisher.start() + await publisher.start() @type_subscription("Other") class LoopbackAgentWithSubscription(LoopbackAgent): ... @@ -369,10 +369,10 @@ async def test_duplicate_subscription() -> None: worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address) host.start() try: - worker1.start() + await worker1.start() await NoopAgent.register(worker1, "worker1", lambda: NoopAgent()) - worker1_2.start() + await worker1_2.start() # Note: This passes because worker1 is still running with pytest.raises(Exception, match="Agent type worker1 already registered"): @@ -411,7 +411,7 @@ async def get_subscribed_recipients() -> List[AgentId]: return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId()) # type: ignore[reportPrivateUsage] try: - worker1.start() + await worker1.start() await LoopbackAgentWithDefaultSubscription.register( worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription() ) @@ -439,7 +439,7 @@ async def get_subscribed_recipients() -> List[AgentId]: assert len(recipients2) == 0 await asyncio.sleep(1) - worker1_2.start() + await worker1_2.start() await LoopbackAgentWithDefaultSubscription.register( worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription() ) @@ -480,12 +480,12 @@ async def test_proto_payloads() -> None: receiver_runtime = GrpcWorkerAgentRuntime( host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE ) - receiver_runtime.start() + await receiver_runtime.start() publisher_runtime = GrpcWorkerAgentRuntime( host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE ) publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage)) - publisher_runtime.start() + await publisher_runtime.start() await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent) @@ -535,9 +535,9 @@ async def test_grpc_max_message_size() -> None: try: host.start() - worker1.start() - worker2.start() - worker3.start() + await worker1.start() + await worker2.start() + await worker3.start() await LoopbackAgentWithDefaultSubscription.register( worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription() )