Skip to content

Commit

Permalink
Merge StreamChatAgent into ChatAgent
Browse files Browse the repository at this point in the history
  • Loading branch information
RussellLuo committed Jan 22, 2025
1 parent 516127e commit f4a4b32
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 49 deletions.
2 changes: 1 addition & 1 deletion coagent/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ruff: noqa: F401
from .chat_agent import ChatAgent, confirm, submit, RunContext, StreamChatAgent, tool
from .chat_agent import ChatAgent, confirm, submit, RunContext, tool
from .dynamic_triage import DynamicTriage
from .mcp_agent import MCPAgent
from .messages import ChatHistory, ChatMessage
Expand Down
41 changes: 2 additions & 39 deletions coagent/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,6 @@ def __init__(self, host_agent: ChatAgent, agent_type: str):
self.host_agent: ChatAgent = host_agent
self.agent_type: str = agent_type

async def handle(self, msg: ChatHistory) -> ChatMessage:
addr = Address(name=self.agent_type, id=self.host_agent.address.id)
result = await self.host_agent.channel.publish(addr, msg.encode(), request=True)
return ChatMessage.decode(result)


class StreamDelegate:
"""A streaming delegate agent that helps to handle a specific task."""

def __init__(self, host_agent: StreamChatAgent, agent_type: str):
self.host_agent: StreamChatAgent = host_agent
self.agent_type: str = agent_type

async def handle(self, msg: ChatHistory) -> AsyncIterator[ChatMessage]:
addr = Address(name=self.agent_type, id=self.host_agent.address.id)
result = self.host_agent.channel.publish_multi(addr, msg.encode())
Expand Down Expand Up @@ -186,7 +173,7 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
return run


class StreamChatAgent(BaseAgent):
class ChatAgent(BaseAgent):
def __init__(
self,
name: str = "",
Expand Down Expand Up @@ -239,7 +226,7 @@ async def get_swarm_agent(self) -> SwarmAgent:

async def agent(self, agent_type: str) -> AsyncIterator[ChatMessage]:
"""The candidate agent to delegate the conversation to."""
async for chunk in StreamDelegate(self, agent_type).handle(self._history):
async for chunk in Delegate(self, agent_type).handle(self._history):
yield chunk

@handler
Expand Down Expand Up @@ -311,27 +298,3 @@ async def _is_submit_message(self, history: ChatHistory) -> bool:
return False
last_msg = history.messages[-1]
return last_msg.role == "user" and last_msg.type == "submit"


class ChatAgent(StreamChatAgent):
"""Non-streaming ChatAgent."""

async def agent(self, agent_type: str) -> ChatMessage:
"""The candidate agent to delegate the conversation to."""
return await Delegate(self, agent_type).handle(self._history)

@handler
async def handle_history(self, msg: ChatHistory, ctx: Context) -> ChatMessage:
accumulated_response = ChatMessage(role="assistant", content="")
response = super().handle_history(msg, ctx)
async for chunk in response:
accumulated_response.content += chunk.content
return accumulated_response

@handler
async def handle_message(self, msg: ChatMessage, ctx: Context) -> ChatMessage:
accumulated_response = ChatMessage(role="assistant", content="")
response = super().handle_message(msg, ctx)
async for chunk in response:
accumulated_response.content += chunk.content
return accumulated_response
4 changes: 2 additions & 2 deletions coagent/agents/dynamic_triage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

from .aswarm import Agent as SwarmAgent, Swarm
from .chat_agent import ChatHistory, ChatMessage, StreamDelegate
from .chat_agent import ChatHistory, ChatMessage, Delegate
from .model_client import default_model_client, ModelClient


Expand Down Expand Up @@ -104,7 +104,7 @@ async def _update_swarm_agent(self) -> None:

def _transfer_to_agent(self, agent_type: str):
async def run() -> AsyncIterator[ChatMessage]:
async for chunk in StreamDelegate(self, agent_type).handle(self._history):
async for chunk in Delegate(self, agent_type).handle(self._history):
yield chunk

return run
Expand Down
9 changes: 2 additions & 7 deletions coagent/agents/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@ class ChatMessage(Message):
)

def __add__(self, other: ChatMessage) -> ChatMessage:
return ChatMessage(
role=self.role,
content=self.content + other.content,
type=self.type,
sender=self.sender,
to_user=self.to_user,
)
self.content += other.content
return self

def model_dump(self, **kwargs) -> dict[str, Any]:
return super().model_dump(include={"role", "content"}, **kwargs)
Expand Down

0 comments on commit f4a4b32

Please sign in to comment.