Skip to content

Commit

Permalink
fix: Fix type hints, update tests and samples
Browse files Browse the repository at this point in the history
  • Loading branch information
lokitoth committed Feb 12, 2025
1 parent b67affc commit 5e0e0fb
Show file tree
Hide file tree
Showing 12 changed files with 258 additions and 258 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def _connect(
send_queue: asyncio.Queue[agent_worker_pb2.Message],
receive_queue: asyncio.Queue[agent_worker_pb2.Message],
client_id: str,
) -> asyncio.Task:
) -> Task[None]:
from grpc.aio import StreamStreamCall

# TODO: where do exceptions from reading the iterable go? How do we recover from those?
Expand All @@ -184,7 +184,7 @@ async def read_loop() -> None:
logger.info("EOF")
break

Check warning on line 185 in python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py#L184-L185

Added lines #L184 - L185 were not covered by tests
logger.info(f"Received a message from host: {message}")
receive_queue.put(message)
await receive_queue.put(message)
logger.info("Put message in receive queue")

return asyncio.create_task(read_loop())
Expand Down
48 changes: 24 additions & 24 deletions python/packages/autogen-ext/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
host.start()

worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
await worker.start()

await worker.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)

Expand All @@ -65,9 +65,9 @@ async def test_agent_types_must_be_unique_multiple_workers() -> None:
host.start()

worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
await worker1.start()
worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
worker2.start()
await worker2.start()

await worker1.register_factory(type=AgentType("name1"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)

Expand All @@ -91,15 +91,15 @@ async def test_register_receives_publish() -> None:
host.start()

worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
await worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
)
await worker1.add_subscription(TypeSubscription("default", "name1"))

worker2 = GrpcWorkerAgentRuntime(host_address=host_address)
worker2.start()
await worker2.start()
worker2.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker2.register_factory(
type=AgentType("name2"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
Expand Down Expand Up @@ -137,7 +137,7 @@ async def test_register_doesnt_receive_after_removing_subscription() -> None:
host.start()

worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
await worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
Expand Down Expand Up @@ -177,7 +177,7 @@ async def test_register_receives_publish_cascade_single_worker() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
await runtime.start()

num_agents = 5
num_initial_messages = 5
Expand Down Expand Up @@ -228,14 +228,14 @@ async def test_register_receives_publish_cascade_multiple_workers() -> None:
# Register agents
for i in range(num_agents):
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
await runtime.start()
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))
workers.append(runtime)

# Publish messages
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(CascadingMessageType))
publisher.start()
await publisher.start()
for _ in range(num_initial_messages):
await publisher.publish_message(CascadingMessageType(round=1), topic_id=DefaultTopicId())

Expand All @@ -259,10 +259,10 @@ async def test_default_subscription() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
await worker.start()
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
await publisher.start()

await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription())

Expand Down Expand Up @@ -294,10 +294,10 @@ async def test_default_subscription_other_source() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = GrpcWorkerAgentRuntime(host_address=host_address)
runtime.start()
await runtime.start()
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
await publisher.start()

await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription())

Expand Down Expand Up @@ -329,10 +329,10 @@ async def test_type_subscription() -> None:
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
worker.start()
await worker.start()
publisher = GrpcWorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()
await publisher.start()

@type_subscription("Other")
class LoopbackAgentWithSubscription(LoopbackAgent): ...
Expand Down Expand Up @@ -369,10 +369,10 @@ async def test_duplicate_subscription() -> None:
worker1_2 = GrpcWorkerAgentRuntime(host_address=host_address)
host.start()
try:
worker1.start()
await worker1.start()
await NoopAgent.register(worker1, "worker1", lambda: NoopAgent())

worker1_2.start()
await worker1_2.start()

# Note: This passes because worker1 is still running
with pytest.raises(Exception, match="Agent type worker1 already registered"):
Expand Down Expand Up @@ -411,7 +411,7 @@ async def get_subscribed_recipients() -> List[AgentId]:
return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId()) # type: ignore[reportPrivateUsage]

try:
worker1.start()
await worker1.start()
await LoopbackAgentWithDefaultSubscription.register(
worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
)
Expand Down Expand Up @@ -439,7 +439,7 @@ async def get_subscribed_recipients() -> List[AgentId]:
assert len(recipients2) == 0
await asyncio.sleep(1)

worker1_2.start()
await worker1_2.start()
await LoopbackAgentWithDefaultSubscription.register(
worker1_2, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
)
Expand Down Expand Up @@ -480,12 +480,12 @@ async def test_proto_payloads() -> None:
receiver_runtime = GrpcWorkerAgentRuntime(
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
)
receiver_runtime.start()
await receiver_runtime.start()
publisher_runtime = GrpcWorkerAgentRuntime(
host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE
)
publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage))
publisher_runtime.start()
await publisher_runtime.start()

await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent)

Expand Down Expand Up @@ -535,9 +535,9 @@ async def test_grpc_max_message_size() -> None:

try:
host.start()
worker1.start()
worker2.start()
worker3.start()
await worker1.start()
await worker2.start()
await worker3.start()
await LoopbackAgentWithDefaultSubscription.register(
worker1, "worker1", lambda: LoopbackAgentWithDefaultSubscription()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def main(config: AppConfig):
editor_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
await asyncio.sleep(4)
Console().print(Markdown("Starting **`Editor Agent`**"))
editor_agent_runtime.start()
await editor_agent_runtime.start()
editor_agent_type = await BaseGroupChatAgent.register(
editor_agent_runtime,
config.editor_agent.topic_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def main(config: AppConfig):
group_chat_manager_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]
await asyncio.sleep(1)
Console().print(Markdown("Starting **`Group Chat Manager`**"))
group_chat_manager_runtime.start()
await group_chat_manager_runtime.start()
set_all_log_levels(logging.ERROR)

group_chat_manager_type = await GroupChatManager.register(
Expand Down
2 changes: 1 addition & 1 deletion python/samples/core_distributed-group-chat/run_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def main(config: AppConfig):
ui_agent_runtime.add_message_serializer(get_serializers([RequestToSpeak, GroupChatMessage, MessageChunk])) # type: ignore[arg-type]

Console().print(Markdown("Starting **`UI Agent`**"))
ui_agent_runtime.start()
await ui_agent_runtime.start()
set_all_log_levels(logging.ERROR)

ui_agent_type = await UIAgent.register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def main(config: AppConfig) -> None:
await asyncio.sleep(3)
Console().print(Markdown("Starting **`Writer Agent`**"))

writer_agent_runtime.start()
await writer_agent_runtime.start()
writer_agent_type = await BaseGroupChatAgent.register(
writer_agent_runtime,
config.writer_agent.topic_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(CascadingMessage))
runtime.start()
await runtime.start()
await ObserverAgent.register(runtime, "observer_agent", lambda: ObserverAgent())
await runtime.publish_message(CascadingMessage(round=1), topic_id=DefaultTopicId())
await runtime.stop_when_signal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.add_message_serializer(try_get_known_serializers_for_type(ReceiveMessageEvent))
runtime.start()
await runtime.start()
agent_type = f"cascading_agent_{uuid.uuid4()}".replace("-", "_")
await CascadingAgent.register(runtime, agent_type, lambda: CascadingAgent(max_rounds=3))
await runtime.stop_when_signal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoRet

async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.start()
await runtime.start()
for t in [AskToGreet, Greeting, ReturnedGreeting, Feedback, ReturnedFeedback]:
runtime.add_message_serializer(try_get_known_serializers_for_type(t))

Expand Down
2 changes: 1 addition & 1 deletion python/samples/core_grpc_worker_runtime/run_worker_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:

async def main() -> None:
runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")
runtime.start()
await runtime.start()

await ReceiveAgent.register(
runtime,
Expand Down
2 changes: 1 addition & 1 deletion python/samples/core_semantic_router/run_semantic_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def output_result(
async def run_workers():
agent_runtime = GrpcWorkerAgentRuntime(host_address="localhost:50051")

agent_runtime.start()
await agent_runtime.start()

# Create the agents
await WorkerAgent.register(agent_runtime, "finance", lambda: WorkerAgent("finance_agent"))
Expand Down

0 comments on commit 5e0e0fb

Please sign in to comment.