diff --git a/coagent/core/agent.py b/coagent/core/agent.py index 5f2a8b3..4be29e7 100644 --- a/coagent/core/agent.py +++ b/coagent/core/agent.py @@ -161,26 +161,31 @@ async def start(self) -> None: """Start the current agent.""" # Subscribe the agent to its own address. - self._sub = await self.channel.subscribe(self.address, handler=self.receive) - - self._handle_data_task = asyncio.create_task(self._handle_data()) + self._sub = await self._create_subscription() # Send a `Started` message to the current agent. await self.channel.publish(self.address, Started().encode(), probe=False) + if not self._handle_data_task: + self._handle_data_task = asyncio.create_task(self._handle_data()) + + async def _create_subscription(self) -> Subscription: + # Subscribe the agent's receive method to its own address. + return await self.channel.subscribe(self.address, handler=self.receive) + async def stop(self) -> None: """Stop the current agent.""" # Send a `Stopped` message to the current agent. await self.channel.publish(self.address, Stopped().encode(), probe=False) - if self._handle_data_task: - self._handle_data_task.cancel() - # Unsubscribe the agent from its own address. if self._sub: await self._sub.unsubscribe() + if self._handle_data_task: + self._handle_data_task.cancel() + async def started(self) -> None: """This handler is called after the agent is started.""" pass diff --git a/coagent/core/discovery.py b/coagent/core/discovery.py index 94d7ad1..93bb1c1 100644 --- a/coagent/core/discovery.py +++ b/coagent/core/discovery.py @@ -11,6 +11,7 @@ Address, AgentSpec, RawMessage, + Subscription, ) from .util import Trie @@ -125,13 +126,7 @@ def __init__(self): async def start(self) -> None: """Since discovery is a special agent, we need to start it in a different way.""" - - # Each query message can only be received and handled by one discovery aggregator. - self._sub = await self.channel.subscribe( - self.address, - handler=self.receive, - queue=f"{self.address.topic}_workers", - ) + await super().start() # Create and start the local discovery server. self._server = DiscoveryServer() @@ -139,13 +134,20 @@ async def start(self) -> None: self._server.init(self.channel, Address(name=f"{self.address.name}.server")) await self._server.start() + async def _create_subscription(self) -> Subscription: + # Each query message can only be received and handled by one discovery aggregator. + return await self.channel.subscribe( + self.address, + handler=self.receive, + queue=f"{self.address.topic}_workers", + ) + async def stop(self) -> None: """Since discovery is a special agent, we need to stop it in a different way.""" if self._server: await self._server.stop() - if self._sub: - await self._sub.unsubscribe() + await super().stop() async def register(self, spec: AgentSpec) -> None: if spec.name == self.address.name: @@ -250,9 +252,7 @@ def __init__(self): async def start(self) -> None: """Since discovery server is a special agent, we need to start it in a different way.""" - - # Subscribe the agent to its own address. - self._sub = await self.channel.subscribe(self.address, handler=self.receive) + await super().start() # Upon startup, the current discovery server has no agent-subscriptions. # Therefore, it's necessary to synchronize the existing agent-subscriptions @@ -283,13 +283,6 @@ async def receive(raw: RawMessage) -> None: finally: await sub.unsubscribe() - async def stop(self) -> None: - """Since discovery server is a special agent, we need to stop it in a different way.""" - - # Unsubscribe the agent from its own address. - if self._sub: - await self._sub.unsubscribe() - async def register(self, spec: AgentSpec) -> None: if spec.name == self.address.name: raise ValueError(f"Agent type '{self.address.name}' is reserved") diff --git a/coagent/core/factory.py b/coagent/core/factory.py index 973ae13..5b78cf4 100644 --- a/coagent/core/factory.py +++ b/coagent/core/factory.py @@ -8,6 +8,7 @@ Agent, AgentSpec, State, + Subscription, ) @@ -42,25 +43,24 @@ def __init__(self, spec: AgentSpec) -> None: async def start(self) -> None: """Since factory is a special agent, we need to start it in a different way.""" - # Subscribe the factory to the given address. + await super().start() + + # Start the recycle loop. + self._recycle_task = asyncio.create_task(self._recycle()) + + async def _create_subscription(self) -> Subscription: + # Each CreateAgent message can only be received and handled by one factory agent. # # Note that we specify a queue parameter to distribute requests among # multiple factory agents of the same type of primitive agent. - self._sub = await self.channel.subscribe( + return await self.channel.subscribe( self.address, handler=self.receive, queue=f"{self.address.topic}_workers", ) - # Start the recycle loop. - self._recycle_task = asyncio.create_task(self._recycle()) - async def stop(self) -> None: """Since factory is a special agent, we need to stop it in a different way.""" - # Unsubscribe the factory from the address. - if self._sub: - await self._sub.unsubscribe() - # Stop all agents. for agent in self._agents.values(): await agent.stop() @@ -70,6 +70,8 @@ async def stop(self) -> None: if self._recycle_task: self._recycle_task.cancel() + await super().stop() + async def _recycle(self) -> None: """The recycle loop for deleting idle agents.""" while True: