Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved rendering of consecutive/multi-step messages and tool calls/results #327

Merged
merged 17 commits into from
Feb 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions assistants/codespace-assistant/.vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,18 @@
"consoleTitle": "${workspaceFolderBasename}"
//"justMyCode": false, // Set to false to debug external libraries
}
],
"compounds": [
{
"name": "assistants: codespace-assistant (for dev)",
"configurations": [
"assistants: codespace-assistant",
"app: semantic-workbench-app",
"service: semantic-workbench-service",
"mcp-servers: mcp-server-bing-search",
"mcp-servers: mcp-server-giphy",
"mcp-servers: mcp-server-open-deep-research"
]
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def establish_mcp_sessions(tools_config: ToolsConfigModel, stack: AsyncExi
client_session: ClientSession | None = await stack.enter_async_context(connect_to_mcp_server(server_config))
if client_session:
# Create an MCP session with the client session
mcp_session = MCPSession(name=server_config.key, client_session=client_session)
mcp_session = MCPSession(config=server_config, client_session=client_session)
# Initialize the session to load tools, resources, etc.
await mcp_session.initialize()
# Add the session to the list of established sessions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# utils/tool_utils.py
import logging
from typing import List
from textwrap import dedent
from typing import AsyncGenerator, List

import deepmerge
from mcp import Tool
Expand All @@ -23,47 +24,27 @@ def retrieve_tools_from_sessions(mcp_sessions: List[MCPSession], tools_config: T
]


async def handle_tool_call(
def get_mcp_session_and_tool_by_tool_name(
mcp_sessions: List[MCPSession],
tool_call: ToolCall,
method_metadata_key: str,
) -> ToolCallResult:
tool_name: str,
) -> tuple[MCPSession | None, Tool | None]:
"""
Handle the tool call by invoking the appropriate tool and returning a ToolCallResult.
Retrieve the MCP session and tool by tool name.
"""

# Initialize metadata
metadata = {}

# Find the tool and session from the full collection of sessions
mcp_session, tool = next(
(
(mcp_session, tool)
for mcp_session in mcp_sessions
for tool in mcp_session.tools
if tool.name == tool_call.name
),
return next(
((mcp_session, tool) for mcp_session in mcp_sessions for tool in mcp_session.tools if tool.name == tool_name),
(None, None),
)
if not mcp_session or not tool:
return ToolCallResult(
id=tool_call.id,
content=f"Tool '{tool_call.name}' not found in any of the sessions.",
message_type=ToolMessageType.notice,
metadata={},
)

# Update metadata with tool call details
deepmerge.always_merger.merge(
metadata,
{
"debug": {
method_metadata_key: {
"tool_call": tool_call.to_json(),
},
},
},
)

async def execute_tool_call(
mcp_session: MCPSession,
tool_call: ToolCall,
method_metadata_key: str,
) -> ToolCallResult:

# Initialize metadata
metadata = {}

# Initialize tool_result
tool_result = None
Expand All @@ -72,7 +53,7 @@ async def handle_tool_call(

# Invoke the tool
try:
logger.debug(f"Invoking '{mcp_session.name}.{tool_call.name}' with arguments: {tool_call.arguments}")
logger.debug(f"Invoking '{mcp_session.config.key}.{tool_call.name}' with arguments: {tool_call.arguments}")
tool_result = await mcp_session.client_session.call_tool(tool_call.name, tool_call.arguments)
tool_output = tool_result.content
except Exception as e:
Expand Down Expand Up @@ -106,3 +87,63 @@ async def handle_tool_call(
message_type=ToolMessageType.tool_result,
metadata=metadata,
)

async def handle_tool_call(
mcp_sessions: List[MCPSession],
tool_call: ToolCall,
method_metadata_key: str,
) -> ToolCallResult:
"""
Handle the tool call by invoking the appropriate tool and returning a ToolCallResult.
"""

# Find the tool and session from the full collection of sessions
mcp_session, tool = get_mcp_session_and_tool_by_tool_name(mcp_sessions, tool_call.name)

if not mcp_session or not tool:
return ToolCallResult(
id=tool_call.id,
content=f"Tool '{tool_call.name}' not found in any of the sessions.",
message_type=ToolMessageType.notice,
metadata={},
)

return await execute_tool_call(mcp_session, tool_call, method_metadata_key)


async def handle_long_running_tool_call(
mcp_sessions: List[MCPSession],
tool_call: ToolCall,
method_metadata_key: str,
) -> AsyncGenerator[ToolCallResult, None]:
"""
Handle the streaming tool call by invoking the appropriate tool and returning a ToolCallResult.
"""

# Find the tool and session from the full collection of sessions
mcp_session, tool = get_mcp_session_and_tool_by_tool_name(mcp_sessions, tool_call.name)

if not mcp_session or not tool:
yield ToolCallResult(
id=tool_call.id,
content=f"Tool '{tool_call.name}' not found in any of the sessions.",
message_type=ToolMessageType.notice,
metadata={},
)
return

# For now, let's just hack to return an immediate response to indicate that the tool call was received
# and is being processed and that the results will be sent in a separate message.
yield ToolCallResult(
id=tool_call.id,
content=dedent(f"""
Processing tool call '{tool_call.name}'.
Estimated time to completion: {mcp_session.config.task_completion_estimate}
""").strip(),
message_type=ToolMessageType.tool_result,
metadata={},
)

# Perform the tool call
tool_call_result = await execute_tool_call(mcp_session, tool_call, method_metadata_key)
yield tool_call_result
153 changes: 84 additions & 69 deletions assistants/codespace-assistant/assistant/extensions/tools/__model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,6 @@
logger = logging.getLogger(__name__)


class MCPSession:
name: str
client_session: ClientSession
tools: List[Tool] = []

def __init__(self, name: str, client_session: ClientSession) -> None:
self.name = name
self.client_session = client_session

async def initialize(self) -> None:
# Load all tools from the session, later we can do the same for resources, prompts, etc.
tools_result = await self.client_session.list_tools()
self.tools = tools_result.tools
logger.debug(f"Loaded {len(tools_result.tools)} tools from session '{self.name}'")


@dataclass
class ToolCall:
id: str
name: str
arguments: dict[str, Any]

def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"arguments": self.arguments,
}

def to_json(self, **kwargs) -> str:
return json.dumps(self, default=lambda o: o.__dict__, **kwargs)


class ToolMessageType(StrEnum):
notice = "notice"
tool_result = "tool_result"


@dataclass
class ToolCallResult:
id: str
content: str
message_type: ToolMessageType
metadata: dict[str, Any]


class MCPServerEnvConfig(BaseModel):
key: Annotated[str, Field(title="Key", description="Environment variable key.")]
value: Annotated[str, Field(title="Value", description="Environment variable value.")]
Expand All @@ -69,11 +23,7 @@ class MCPServerConfig(BaseModel):
key: Annotated[str, Field(title="Key", description="Unique key for the server configuration.")]

command: Annotated[
str,
Field(
title="Command",
description="Command to run the server, use url if using SSE transport."
)
str, Field(title="Command", description="Command to run the server, use url if using SSE transport.")
]

args: Annotated[List[str], Field(title="Arguments", description="Arguments to pass to the server.")]
Expand All @@ -89,6 +39,19 @@ class MCPServerConfig(BaseModel):
UISchema(widget="textarea"),
] = ""

long_running: Annotated[
bool,
Field(title="Long Running", description="Does this server run long running tasks?"),
] = False

task_completion_estimate: Annotated[
int,
Field(
title="Long Running Task Completion Time Estimate",
description="Estimated time to complete an average long running task (in seconds).",
),
] = 30


class ToolsConfigModel(BaseModel):
enabled: Annotated[
Expand Down Expand Up @@ -150,6 +113,30 @@ class ToolsConfigModel(BaseModel):
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/workspaces/semanticworkbench"],
),
MCPServerConfig(
key="vscode",
command="http://127.0.0.1:6010/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="bing-search",
command="http://127.0.0.1:6030/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="open-deep-research",
command="http://127.0.0.1:6020/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="giphy",
command="http://http://127.0.0.1:6000/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="memory",
command="npx",
Expand Down Expand Up @@ -178,24 +165,6 @@ class ToolsConfigModel(BaseModel):
""").strip(),
enabled=False,
),
MCPServerConfig(
key="open-deep-research",
command="http://127.0.0.1:6020/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="vscode",
command="http://127.0.0.1:6010/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="giphy",
command="http://http://127.0.0.1:6000/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="sequential-thinking",
command="npx",
Expand All @@ -214,3 +183,49 @@ class ToolsConfigModel(BaseModel):
""").strip(),
),
] = ["directory_tree"]


class MCPSession:
config: MCPServerConfig
client_session: ClientSession
tools: List[Tool] = []

def __init__(self, config: MCPServerConfig, client_session: ClientSession) -> None:
self.config = config
self.client_session = client_session

async def initialize(self) -> None:
# Load all tools from the session, later we can do the same for resources, prompts, etc.
tools_result = await self.client_session.list_tools()
self.tools = tools_result.tools
logger.debug(f"Loaded {len(tools_result.tools)} tools from session '{self.config.key}'")


@dataclass
class ToolCall:
id: str
name: str
arguments: dict[str, Any]

def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"arguments": self.arguments,
}

def to_json(self, **kwargs) -> str:
return json.dumps(self, default=lambda o: o.__dict__, **kwargs)


class ToolMessageType(StrEnum):
notice = "notice"
tool_result = "tool_result"


@dataclass
class ToolCallResult:
id: str
content: str
message_type: ToolMessageType
metadata: dict[str, Any]
Loading