Skip to content

Commit

Permalink
improvements to mcp client/server support, introduces WIP mcp-server-…
Browse files Browse the repository at this point in the history
…fusion (#342)
  • Loading branch information
bkrabach authored Feb 24, 2025
1 parent 9570e4d commit 3d8e0b7
Show file tree
Hide file tree
Showing 31 changed files with 1,446 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,6 @@ async def handle_completion(
metadata_key: str,
response_start_time: float,
) -> StepResult:
# helper function for handling errors
async def handle_error(error_message: str) -> StepResult:
await context.send_messages(
NewConversationMessage(
content=error_message,
message_type=MessageType.notice,
metadata=step_result.metadata,
)
)
step_result.status = "error"
return step_result

# get service and request configuration for generative model
generative_request_config = request_config

Expand Down Expand Up @@ -166,7 +154,6 @@ async def handle_error(error_message: str) -> StepResult:
tool_call_count = 0
for tool_call in tool_calls:
tool_call_count += 1

tool_call_status = f"using tool `{tool_call.name}`"
async with context.set_status(f"{tool_call_status}..."):

Expand All @@ -181,8 +168,26 @@ async def on_logging_message(msg: str) -> None:
on_logging_message,
)
except Exception as e:
logger.exception(f"Error handling tool call: {e}")
return await handle_error("An error occurred while handling the tool call.")
logger.exception(f"Error handling tool call '{tool_call.name}': {e}")
deepmerge.always_merger.merge(
step_result.metadata,
{
"debug": {
f"{metadata_key}:request:tool_call_{tool_call_count}": {
"error": str(e),
},
},
},
)
await context.send_messages(
NewConversationMessage(
content=f"Error executing tool '{tool_call.name}': {e}",
message_type=MessageType.notice,
metadata=step_result.metadata,
)
)
step_result.status = "error"
return step_result

# Update content and metadata with tool call result metadata
deepmerge.always_merger.merge(step_result.metadata, tool_call_result.metadata)
Expand Down
35 changes: 25 additions & 10 deletions assistants/codespace-assistant/assistant/response/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, List

from assistant_extensions.attachments import AttachmentsExtension
from assistant_extensions.mcp import MCPSession, establish_mcp_sessions, get_mcp_server_prompts
from assistant_extensions.mcp import MCPSession, establish_mcp_sessions, get_mcp_server_prompts, refresh_mcp_sessions
from semantic_workbench_api_model.workbench_model import (
ConversationMessage,
MessageType,
Expand Down Expand Up @@ -33,16 +33,28 @@ async def respond_to_conversation(
async with AsyncExitStack() as stack:
# If tools are enabled, establish connections to the MCP servers
mcp_sessions: List[MCPSession] = []
if config.extensions_config.tools.enabled:
mcp_sessions = await establish_mcp_sessions(config.extensions_config.tools, stack)
if not mcp_sessions:
await context.send_messages(
NewConversationMessage(
content="Unable to connect to any MCP servers. Please ensure the servers are running.",
message_type=MessageType.notice,
metadata=metadata,
)

async def error_handler(server_config, error) -> None:
logger.error(f"Failed to connect to MCP server {server_config.key}: {error}")
# Also notify the user about this server failure here.
await context.send_messages(
NewConversationMessage(
content=f"Failed to connect to MCP server {server_config.key}: {error}",
message_type=MessageType.notice,
metadata=metadata,
)
)

if config.extensions_config.tools.enabled:
mcp_sessions = await establish_mcp_sessions(
tools_config=config.extensions_config.tools,
stack=stack,
error_handler=error_handler
)

if len(config.extensions_config.tools.mcp_servers) > 0 and len(mcp_sessions) == 0:
# No MCP servers are available, so we should not continue
logger.error("No MCP servers are available.")
return

# Retrieve prompts from the MCP servers
Expand Down Expand Up @@ -77,6 +89,9 @@ async def respond_to_conversation(
logger.info("Response interrupted.")
break

# Reconnect to the MCP servers if they were disconnected
mcp_sessions = await refresh_mcp_sessions(mcp_sessions)

step_result = await next_step(
mcp_sessions=mcp_sessions,
mcp_prompts=mcp_prompts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ async def get_completion(
completion_args["tools"] = tools or NotGiven()
if tools is not None:
completion_args["tool_choice"] = "auto"
completion_args["parallel_tool_calls"] = False

logger.debug(
dedent(f"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
MCPSession,
MCPToolsConfigModel,
)
from ._server_utils import establish_mcp_sessions, get_mcp_server_prompts
from ._server_utils import establish_mcp_sessions, get_mcp_server_prompts, refresh_mcp_sessions
from ._tool_utils import handle_mcp_tool_call, retrieve_mcp_tools_from_sessions

__all__ = [
Expand All @@ -15,5 +15,6 @@
"establish_mcp_sessions",
"get_mcp_server_prompts",
"handle_mcp_tool_call",
"refresh_mcp_sessions",
"retrieve_mcp_tools_from_sessions",
]
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,27 @@ class MCPToolsConfigModel(BaseModel):
),
MCPServerConfig(
key="giphy",
command="http://http://127.0.0.1:6000/sse",
command="http://127.0.0.1:6000/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="fusion",
command="http://127.0.0.1:6050/sse",
args=[],
prompt=dedent("""
When creating models, remember the following:
- Z is vertical, X is horizontal, and Y is depth
- The top plane for an entity is an XY plane, at the Z coordinate of the top of the entity
- The bottom plane for an entity is an XY plane, at the Z coordinate of the bottom of the entity
- The front plane for an entity is an XZ plane, at the Y coordinate of the front of the entity
- The back plane for an entity is an XZ plane, at the Y coordinate of the back of the entity
- The left plane for an entity is a YZ plane, at the X coordinate of the left of the entity
- The right plane for an entity is a YZ plane, at the X coordinate of the right of the entity
- Remember to always use the correct plane and consider the amount of adjustment on the 3rd plane necessary
""").strip(),
enabled=False,
),
MCPServerConfig(
key="memory",
command="npx",
Expand Down Expand Up @@ -211,6 +228,7 @@ class MCPSession:
config: MCPServerConfig
client_session: ClientSession
tools: List[Tool] = []
is_connected: bool = True

def __init__(self, config: MCPServerConfig, client_session: ClientSession) -> None:
self.config = config
Expand All @@ -220,6 +238,7 @@ async def initialize(self) -> None:
# Load all tools from the session, later we can do the same for resources, prompts, etc.
tools_result = await self.client_session.list_tools()
self.tools = tools_result.tools
self.is_connected = True
logger.debug(
f"Loaded {len(tools_result.tools)} tools from session '{self.config.key}'"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
from asyncio import CancelledError
from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncIterator, List, Optional

from typing import AsyncIterator, Callable, List, Optional
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
Expand Down Expand Up @@ -70,9 +69,9 @@ async def connect_to_mcp_server_sse(
)
headers = get_env_dict(server_config)

# FIXME: Bumping timeout to 15 minutes, but this should be configurable
# FIXME: Bumping sse_read_timeout to 15 minutes and timeout to 5 minutes, but this should be configurable
async with sse_client(
url=server_config.command, headers=headers, sse_read_timeout=60 * 15
url=server_config.command, headers=headers, timeout=60 * 5, sse_read_timeout=60 * 15
) as (
read_stream,
write_stream,
Expand All @@ -83,9 +82,13 @@ async def connect_to_mcp_server_sse(

except ExceptionGroup as e:
logger.exception(f"TaskGroup failed in SSE client for {server_config.key}: {e}")
for sub_extension in e.exceptions:
logger.error(f"Sub-exception: {server_config.key}: {sub_extension}")
raise
for sub in e.exceptions:
logger.error(f"Sub-exception: {server_config.key}: {sub}")
# If there's exactly one underlying exception, re-raise it
if len(e.exceptions) == 1:
raise e.exceptions[0]
else:
raise
except CancelledError as e:
logger.exception(
f"Task was cancelled in SSE client for {server_config.key}: {e}"
Expand All @@ -98,32 +101,74 @@ async def connect_to_mcp_server_sse(
logger.exception(f"Error connecting to {server_config.key}: {e}")
raise

async def refresh_mcp_sessions(mcp_sessions: list[MCPSession]) -> list[MCPSession]:
"""
Check each MCP session for connectivity. If a session is marked as disconnected,
attempt to reconnect it using reconnect_mcp_session.
"""
active_sessions = []
for session in mcp_sessions:
if not session.is_connected:
logger.info(f"Session {session.config.key} is disconnected. Attempting to reconnect...")
new_session = await reconnect_mcp_session(session.config)
if new_session:
active_sessions.append(new_session)
else:
logger.error(f"Failed to reconnect MCP server {session.config.key}.")
else:
active_sessions.append(session)
return active_sessions

async def establish_mcp_sessions(
tools_config: MCPToolsConfigModel, stack: AsyncExitStack
) -> List[MCPSession]:

async def reconnect_mcp_session(server_config: MCPServerConfig) -> MCPSession | None:
"""
Establish connections to MCP servers using the provided AsyncExitStack.
Attempt to reconnect to the MCP server using the provided configuration.
Returns a new MCPSession if successful, or None otherwise.
This version relies directly on the existing connection context manager
to avoid interfering with cancel scopes.
"""
try:
async with connect_to_mcp_server(server_config) as client_session:
if client_session is None:
logger.error(f"Reconnection returned no client session for {server_config.key}")
return None

new_session = MCPSession(config=server_config, client_session=client_session)
await new_session.initialize()
new_session.is_connected = True
logger.info(f"Successfully reconnected to MCP server {server_config.key}")
return new_session
except Exception as e:
logger.exception(f"Error reconnecting MCP server {server_config.key}: {e}")
return None


async def establish_mcp_sessions(
tools_config: MCPToolsConfigModel, stack: AsyncExitStack, error_handler: Optional[Callable] = None
) -> List[MCPSession]:
mcp_sessions: List[MCPSession] = []
for server_config in tools_config.mcp_servers:
# Check to see if the server is enabled
if not server_config.enabled:
logger.debug(f"Skipping disabled server: {server_config.key}")
continue
try:
client_session: ClientSession | None = await stack.enter_async_context(
connect_to_mcp_server(server_config)
)
except Exception as e:
# Log a cleaner error message for this specific server
logger.error(f"Failed to connect to MCP server {server_config.key}: {e}")
# Also notify the user about this server failure here.
if error_handler:
await error_handler(server_config, e)
# Abort the connection attempt for the servers to avoid only partial server connections
# This could lead to assistant creatively trying to use the other tools to compensate
# for the missing tools, which can sometimes be very problematic.
return []

client_session: ClientSession | None = await stack.enter_async_context(
connect_to_mcp_server(server_config)
)
if client_session:
# Create an MCP session with the client session
mcp_session = MCPSession(
config=server_config, client_session=client_session
)
# Initialize the session to load tools, resources, etc.
mcp_session = MCPSession(config=server_config, client_session=client_session)
await mcp_session.initialize()
# Add the session to the list of established sessions
mcp_sessions.append(mcp_session)
else:
logger.warning(f"Could not establish session with {server_config.key}")
Expand Down
Loading

0 comments on commit 3d8e0b7

Please sign in to comment.