Skip to content

Commit

Permalink
supports building extension, more dev config for tools, better reason…
Browse files Browse the repository at this point in the history
…ing model support (microsoft#318)
  • Loading branch information
bkrabach authored Feb 7, 2025
1 parent 58eb20b commit f80e015
Show file tree
Hide file tree
Showing 13 changed files with 1,273 additions and 183 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client

from .__mcp_server_configs import get_mcp_server_configs
from .__model import MCPServerConfig, MCPSession, ToolsConfigModel

logger = logging.getLogger(__name__)


def get_env_dict(server_config: MCPServerConfig) -> dict[str, str]:
    """Build a plain name→value mapping from the server config's env entries."""
    env_vars: dict[str, str] = {}
    for env_entry in server_config.env:
        env_vars[env_entry.key] = env_entry.value
    return env_vars


@asynccontextmanager
async def connect_to_mcp_server(server_config: MCPServerConfig) -> AsyncIterator[Optional[ClientSession]]:
"""Connect to a single MCP server defined in the config."""
Expand All @@ -27,7 +31,9 @@ async def connect_to_mcp_server(server_config: MCPServerConfig) -> AsyncIterator
async def connect_to_mcp_server_stdio(server_config: MCPServerConfig) -> AsyncIterator[Optional[ClientSession]]:
"""Connect to a single MCP server defined in the config."""

server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
server_params = StdioServerParameters(
command=server_config.command, args=server_config.args, env=get_env_dict(server_config)
)
try:
logger.debug(
f"Attempting to connect to {server_config.key} with command: {server_config.command} {' '.join(server_config.args)}"
Expand All @@ -47,7 +53,8 @@ async def connect_to_mcp_server_sse(server_config: MCPServerConfig) -> AsyncIter

try:
logger.debug(f"Attempting to connect to {server_config.key} with SSE transport: {server_config.command}")
async with sse_client(url=server_config.command, headers=server_config.env) as (read_stream, write_stream):
headers = get_env_dict(server_config)
async with sse_client(url=server_config.command, headers=headers) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as client_session:
await client_session.initialize()
yield client_session # Yield the session for use
Expand All @@ -56,20 +63,15 @@ async def connect_to_mcp_server_sse(server_config: MCPServerConfig) -> AsyncIter
yield None


def is_mcp_server_enabled(server_config: MCPServerConfig, tools_config: ToolsConfigModel) -> bool:
    """Check if an MCP server is enabled.

    Looks up the flag named "{server_config.key}_enabled" in the dumped
    tool_servers_enabled config model; a missing key is treated as disabled (False).
    """
    return tools_config.tool_servers_enabled.model_dump().get(f"{server_config.key}_enabled", False)


async def establish_mcp_sessions(tools_config: ToolsConfigModel, stack: AsyncExitStack) -> List[MCPSession]:
"""
Establish connections to MCP servers using the provided AsyncExitStack.
"""

mcp_sessions: List[MCPSession] = []
for server_config in get_mcp_server_configs(tools_config):
for server_config in tools_config.mcp_servers:
# Check to see if the server is enabled
if not is_mcp_server_enabled(server_config, tools_config):
if not server_config.enabled:
logger.debug(f"Skipping disabled server: {server_config.key}")
continue

Expand All @@ -88,4 +90,4 @@ async def establish_mcp_sessions(tools_config: ToolsConfigModel, stack: AsyncExi

def get_mcp_server_prompts(tools_config: ToolsConfigModel) -> List[str]:
"""Get the prompts for all MCP servers."""
return [server.prompt for server in get_mcp_server_configs(tools_config) if server.prompt]
return [mcp_server.prompt for mcp_server in tools_config.mcp_servers if mcp_server.prompt]
168 changes: 79 additions & 89 deletions assistants/codespace-assistant/assistant/extensions/tools/__model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from enum import StrEnum
from textwrap import dedent
from typing import Annotated, Any, List, Optional
from typing import Annotated, Any, List

from attr import dataclass
from mcp import ClientSession, Tool
Expand All @@ -12,15 +12,6 @@
logger = logging.getLogger(__name__)


@dataclass
class MCPServerConfig:
key: str
command: str
args: List[str]
env: Optional[dict[str, str]] = None
prompt: Optional[str] = None


class MCPSession:
name: str
client_session: ClientSession
Expand Down Expand Up @@ -67,55 +58,32 @@ class ToolCallResult:
metadata: dict[str, Any]


class MCPServersEnabledConfigModel(BaseModel):
# NOTE: create a property for each of the mcp servers following the convention of: {server_key}_enabled
class MCPServerEnvConfig(BaseModel):
    """A single environment variable (key/value pair) to pass to an MCP server."""

    key: Annotated[str, Field(title="Key", description="Environment variable key.")]
    value: Annotated[str, Field(title="Value", description="Environment variable value.")]

filesystem_enabled: Annotated[
bool,
Field(
title="File System Enabled",
description="Enable file system tools, granting access to defined file system paths for read/write.",
),
] = True

memory_enabled: Annotated[
bool,
Field(
title="Memory Enabled",
description="Enable memory tools, allowing for storing and retrieving data in memory.",
),
] = True
class MCPServerConfig(BaseModel):
enabled: Annotated[bool, Field(title="Enabled", description="Enable the server.")] = True

vscode_enabled: Annotated[
bool,
Field(
title="VSCode Enabled",
description=dedent("""
Enable VSCode tools, supporting testing and evaluation of code via VSCode integration.
To use this tool, the project must be running in VSCode (tested in Codespaces, but may work
locally), and the `mcp-server-vscode` VSCode extension must be running.
""").strip(),
),
] = False
key: Annotated[str, Field(title="Key", description="Unique key for the server configuration.")]

sequential_thinking_enabled: Annotated[
bool,
Field(
title="Sequential Thinking Enabled",
description="Enable sequential thinking tools, supporting sequential processing of information.",
),
] = False
command: Annotated[
str, Field(title="Command", description="Command to run the server."), UISchema(widget="textarea")
]

giphy_enabled: Annotated[
bool,
Field(
title="Giphy Enabled",
description=dedent("""
Enable Giphy tools for searching and retrieving GIFs. Must start the Giphy server via the
VSCode Run and Debug panel.
""").strip(),
),
] = False
args: Annotated[List[str], Field(title="Arguments", description="Arguments to pass to the server.")]

env: Annotated[
List[MCPServerEnvConfig],
Field(title="Environment Variables", description="Environment variables to set."),
] = []

prompt: Annotated[
str,
Field(title="Prompt", description="Instructions for using the server."),
UISchema(widget="textarea"),
] = ""


class ToolsConfigModel(BaseModel):
Expand Down Expand Up @@ -166,43 +134,65 @@ class ToolsConfigModel(BaseModel):
- The search tool does not appear to support wildcards, but does work with partial file names.
""").strip()

# instructions_for_non_tool_models: Annotated[
# str,
# Field(
# title="Tools Instructions for Models Without Tools Support",
# description=dedent("""
# Some models don't support tools (like OpenAI reasoning models), so these instructions
# are only used to implement tool support through custom instruction and injection of
# the tool definitions. Make sure to include {{tools}} in the instructions.
# """),
# ),
# UISchema(widget="textarea", enable_markdown_in_description=True),
# ] = dedent("""
# You can perform specific tasks using available tools. When you need to use a tool, respond
# with a strict JSON object containing only the tool's `id` and function name and arguments.
# \n\n
# Available Tools:
# {{tools}}
# \n\n
# ### How to Use Tools:
# - If you need to use a tool to answer the user's query, respond with **ONLY** a JSON object.
# - If you can answer without using a tool, provide the answer directly.
# - **No code, no text, no markdown** within the JSON.
# - Ensure that all values are plain data types (e.g., strings, numbers).
# - **Do not** include any additional characters, functions, or expressions within the JSON.
# """).strip()

file_system_paths: Annotated[
list[str],
mcp_servers: Annotated[
List[MCPServerConfig],
Field(
title="File System Paths",
description="Paths to the file system for tools to use in the container or local.",
title="MCP Servers",
description="Configuration for MCP servers that provide tools to the assistant.",
),
] = ["/workspaces/semanticworkbench"]

tool_servers_enabled: Annotated[MCPServersEnabledConfigModel, Field(title="Tool Servers Enabled")] = (
MCPServersEnabledConfigModel()
)
] = [
MCPServerConfig(
key="filesystem",
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem", "/workspaces/semanticworkbench"],
),
MCPServerConfig(
key="memory",
command="npx",
args=["-y", "@modelcontextprotocol/server-memory"],
prompt=dedent("""
Follow these steps for each interaction:
1. Memory Retrieval:
- Always begin your chat by saying only "Remembering..." and retrieve all relevant information
from your knowledge graph
- Always refer to your knowledge graph as your "memory"
2. Memory
- While conversing with the user, be attentive to any new information that falls into these categories:
a) Basic Identity (age, gender, location, job title, education level, etc.)
b) Behaviors (interests, habits, etc.)
c) Preferences (communication style, preferred language, etc.)
d) Goals (goals, targets, aspirations, etc.)
e) Relationships (personal and professional relationships up to 3 degrees of separation)
3. Memory Update:
- If any new information was gathered during the interaction, update your memory as follows:
a) Create entities for recurring organizations, people, and significant events
b) Connect them to the current entities using relations
c) Store facts about them as observations
""").strip(),
enabled=False,
),
MCPServerConfig(
key="vscode",
command="http://127.0.0.1:6010/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="giphy",
command="http://http://127.0.0.1:6000/sse",
args=[],
enabled=False,
),
MCPServerConfig(
key="sequential_thinking",
command="npx",
args=["-y", "@modelcontextprotocol/server-sequential-thinking"],
enabled=False,
),
]

tools_disabled: Annotated[
list[str],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from assistant_extensions.attachments import AttachmentsConfigModel, AttachmentsExtension
from openai.types.chat import (
ChatCompletionDeveloperMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolParam,
Expand Down Expand Up @@ -68,12 +69,26 @@ async def build_request(
prompts_config, context, participants, silence_token, additional_system_message_content
)

chat_message_params: List[ChatCompletionMessageParam] = [
chat_message_params: List[ChatCompletionMessageParam] = []

if request_config.is_reasoning_model:
# Reasoning models use developer messages instead of system messages
developer_message_content = (
f"Formatting re-enabled\n{system_message_content}"
if request_config.enable_markdown_in_reasoning_response
else system_message_content
)
chat_message_params.append(
ChatCompletionDeveloperMessageParam(
role="developer",
content=developer_message_content,
)
)
else:
ChatCompletionSystemMessageParam(
role="system",
content=system_message_content,
)
]

# Initialize token count to track the number of tokens used
# Add history messages last, as they are what will be truncated if the token limit is reached
Expand Down Expand Up @@ -118,8 +133,7 @@ async def build_request(

# Add room for reasoning tokens if using a reasoning model
if request_config.is_reasoning_model:
adjustment_percent = request_config.reasoning_token_overhead_percentage
available_tokens -= int(request_config.response_tokens * adjustment_percent / 100)
available_tokens -= request_config.reasoning_token_allocation

# Get history messages
history_messages_result = await get_history_messages(
Expand Down
Loading

0 comments on commit f80e015

Please sign in to comment.