Skip to content

Commit

Permalink
improved rendering of consecutive/multi-step messages and tool calls/…
Browse files Browse the repository at this point in the history
…results (#327)

See MESSAGE_METADATA.md for more details on how to take advantage of
tools calls/results via special metadata values.
  • Loading branch information
bkrabach authored Feb 16, 2025
1 parent 518c909 commit e2abc74
Show file tree
Hide file tree
Showing 45 changed files with 1,507 additions and 723 deletions.
13 changes: 13 additions & 0 deletions assistants/codespace-assistant/.vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,18 @@
"consoleTitle": "${workspaceFolderBasename}"
//"justMyCode": false, // Set to false to debug external libraries
}
],
"compounds": [
{
"name": "assistants: codespace-assistant (for dev)",
"configurations": [
"assistants: codespace-assistant",
"app: semantic-workbench-app",
"service: semantic-workbench-service",
"mcp-servers: mcp-server-bing-search",
"mcp-servers: mcp-server-giphy",
"mcp-servers: mcp-server-open-deep-research"
]
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def establish_mcp_sessions(tools_config: ToolsConfigModel, stack: AsyncExi
client_session: ClientSession | None = await stack.enter_async_context(connect_to_mcp_server(server_config))
if client_session:
# Create an MCP session with the client session
mcp_session = MCPSession(name=server_config.key, client_session=client_session)
mcp_session = MCPSession(config=server_config, client_session=client_session)
# Initialize the session to load tools, resources, etc.
await mcp_session.initialize()
# Add the session to the list of established sessions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# utils/tool_utils.py
import logging
from typing import List
from textwrap import dedent
from typing import AsyncGenerator, List

import deepmerge
from mcp import Tool
Expand All @@ -23,47 +24,27 @@ def retrieve_tools_from_sessions(mcp_sessions: List[MCPSession], tools_config: T
]


async def handle_tool_call(
def get_mcp_session_and_tool_by_tool_name(
mcp_sessions: List[MCPSession],
tool_call: ToolCall,
method_metadata_key: str,
) -> ToolCallResult:
tool_name: str,
) -> tuple[MCPSession | None, Tool | None]:
"""
Handle the tool call by invoking the appropriate tool and returning a ToolCallResult.
Retrieve the MCP session and tool by tool name.
"""

# Initialize metadata
metadata = {}

# Find the tool and session from the full collection of sessions
mcp_session, tool = next(
(
(mcp_session, tool)
for mcp_session in mcp_sessions
for tool in mcp_session.tools
if tool.name == tool_call.name
),
return next(
((mcp_session, tool) for mcp_session in mcp_sessions for tool in mcp_session.tools if tool.name == tool_name),
(None, None),
)
if not mcp_session or not tool:
return ToolCallResult(
id=tool_call.id,
content=f"Tool '{tool_call.name}' not found in any of the sessions.",
message_type=ToolMessageType.notice,
metadata={},
)

# Update metadata with tool call details
deepmerge.always_merger.merge(
metadata,
{
"debug": {
method_metadata_key: {
"tool_call": tool_call.to_json(),
},
},
},
)

async def execute_tool_call(
mcp_session: MCPSession,
tool_call: ToolCall,
method_metadata_key: str,
) -> ToolCallResult:

# Initialize metadata
metadata = {}

# Initialize tool_result
tool_result = None
Expand All @@ -72,7 +53,7 @@ async def handle_tool_call(

# Invoke the tool
try:
logger.debug(f"Invoking '{mcp_session.name}.{tool_call.name}' with arguments: {tool_call.arguments}")
logger.debug(f"Invoking '{mcp_session.config.key}.{tool_call.name}' with arguments: {tool_call.arguments}")
tool_result = await mcp_session.client_session.call_tool(tool_call.name, tool_call.arguments)
tool_output = tool_result.content
except Exception as e:
Expand Down Expand Up @@ -106,3 +87,63 @@ async def handle_tool_call(
message_type=ToolMessageType.tool_result,
metadata=metadata,
)

async def handle_tool_call(
mcp_sessions: List[MCPSession],
tool_call: ToolCall,
method_metadata_key: str,
) -> ToolCallResult:
"""
Handle the tool call by invoking the appropriate tool and returning a ToolCallResult.
"""

# Find the tool and session from the full collection of sessions
mcp_session, tool = get_mcp_session_and_tool_by_tool_name(mcp_sessions, tool_call.name)

if not mcp_session or not tool:
return ToolCallResult(
id=tool_call.id,
content=f"Tool '{tool_call.name}' not found in any of the sessions.",
message_type=ToolMessageType.notice,
metadata={},
)

return await execute_tool_call(mcp_session, tool_call, method_metadata_key)


async def handle_long_running_tool_call(
mcp_sessions: List[MCPSession],
tool_call: ToolCall,
method_metadata_key: str,
) -> AsyncGenerator[ToolCallResult, None]:
"""
Handle the streaming tool call by invoking the appropriate tool and returning a ToolCallResult.
"""

# Find the tool and session from the full collection of sessions
mcp_session, tool = get_mcp_session_and_tool_by_tool_name(mcp_sessions, tool_call.name)

if not mcp_session or not tool:
yield ToolCallResult(
id=tool_call.id,
content=f"Tool '{tool_call.name}' not found in any of the sessions.",
message_type=ToolMessageType.notice,
metadata={},
)
return

# For now, let's just hack to return an immediate response to indicate that the tool call was received
# and is being processed and that the results will be sent in a separate message.
yield ToolCallResult(
id=tool_call.id,
content=dedent(f"""
Processing tool call '{tool_call.name}'.
Estimated time to completion: {mcp_session.config.task_completion_estimate}
""").strip(),
message_type=ToolMessageType.tool_result,
metadata={},
)

# Perform the tool call
tool_call_result = await execute_tool_call(mcp_session, tool_call, method_metadata_key)
yield tool_call_result
153 changes: 84 additions & 69 deletions assistants/codespace-assistant/assistant/extensions/tools/__model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,6 @@
logger = logging.getLogger(__name__)


class MCPSession:
name: str
client_session: ClientSession
tools: List[Tool] = []

def __init__(self, name: str, client_session: ClientSession) -> None:
self.name = name
self.client_session = client_session

async def initialize(self) -> None:
# Load all tools from the session, later we can do the same for resources, prompts, etc.
tools_result = await self.client_session.list_tools()
self.tools = tools_result.tools
logger.debug(f"Loaded {len(tools_result.tools)} tools from session '{self.name}'")


@dataclass
class ToolCall:
id: str
name: str
arguments: dict[str, Any]

def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"arguments": self.arguments,
}

def to_json(self, **kwargs) -> str:
return json.dumps(self, default=lambda o: o.__dict__, **kwargs)


class ToolMessageType(StrEnum):
notice = "notice"
tool_result = "tool_result"


@dataclass
class ToolCallResult:
id: str
content: str
message_type: ToolMessageType
metadata: dict[str, Any]


class MCPServerEnvConfig(BaseModel):
key: Annotated[str, Field(title="Key", description="Environment variable key.")]
value: Annotated[str, Field(title="Value", description="Environment variable value.")]
Expand All @@ -69,11 +23,7 @@ class MCPServerConfig(BaseModel):
key: Annotated[str, Field(title="Key", description="Unique key for the server configuration.")]

command: Annotated[
str,
Field(
title="Command",
description="Command to run the server, use url if using SSE transport."
)
str, Field(title="Command", description="Command to run the server, use url if using SSE transport.")
]

args: Annotated[List[str], Field(title="Arguments", description="Arguments to pass to the server.")]
Expand All @@ -89,6 +39,19 @@ class MCPServerConfig(BaseModel):
UISchema(widget="textarea"),
] = ""

long_running: Annotated[
bool,
Field(title="Long Running", description="Does this server run long running tasks?"),
] = False

task_completion_estimate: Annotated[
int,
Field(
title="Long Running Task Completion Time Estimate",
description="Estimated time to complete an average long running task (in seconds).",
),
] = 30


class ToolsConfigModel(BaseModel):
enabled: Annotated[
Expand Down Expand Up @@ -150,6 +113,30 @@ class ToolsConfigModel(BaseModel):
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/workspaces/semanticworkbench"],
),
MCPServerConfig(
key="vscode",
command="http://127.0.0.1:6010/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="bing-search",
command="http://127.0.0.1:6030/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="open-deep-research",
command="http://127.0.0.1:6020/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="giphy",
command="http://http://127.0.0.1:6000/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="memory",
command="npx",
Expand Down Expand Up @@ -178,24 +165,6 @@ class ToolsConfigModel(BaseModel):
""").strip(),
enabled=False,
),
MCPServerConfig(
key="open-deep-research",
command="http://127.0.0.1:6020/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="vscode",
command="http://127.0.0.1:6010/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="giphy",
command="http://http://127.0.0.1:6000/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="sequential-thinking",
command="npx",
Expand All @@ -214,3 +183,49 @@ class ToolsConfigModel(BaseModel):
""").strip(),
),
] = ["directory_tree"]


class MCPSession:
config: MCPServerConfig
client_session: ClientSession
tools: List[Tool] = []

def __init__(self, config: MCPServerConfig, client_session: ClientSession) -> None:
self.config = config
self.client_session = client_session

async def initialize(self) -> None:
# Load all tools from the session, later we can do the same for resources, prompts, etc.
tools_result = await self.client_session.list_tools()
self.tools = tools_result.tools
logger.debug(f"Loaded {len(tools_result.tools)} tools from session '{self.config.key}'")


@dataclass
class ToolCall:
id: str
name: str
arguments: dict[str, Any]

def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"arguments": self.arguments,
}

def to_json(self, **kwargs) -> str:
return json.dumps(self, default=lambda o: o.__dict__, **kwargs)


class ToolMessageType(StrEnum):
notice = "notice"
tool_result = "tool_result"


@dataclass
class ToolCallResult:
id: str
content: str
message_type: ToolMessageType
metadata: dict[str, Any]
Loading

0 comments on commit e2abc74

Please sign in to comment.