-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Yu Ishikawa <[email protected]>
- Loading branch information
Showing
4 changed files
with
197 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import argparse | ||
import asyncio | ||
|
||
from google import genai | ||
from google.genai import chats, types | ||
from loguru import logger | ||
|
||
from research_agent.mcp_client import MCPClient | ||
from research_agent.utils import to_gemini_tool | ||
|
||
|
||
async def process_query( | ||
chat_client: chats.Chat, | ||
mcp_client: MCPClient, | ||
query: str, | ||
): | ||
"""Process the user query using Gemini and MCP tools.""" | ||
response = chat_client.send_message(message=[query]) | ||
if not response.candidates: | ||
raise RuntimeError("No response from Gemini") | ||
|
||
returned_message = None | ||
for candidate in response.candidates: | ||
if candidate.content: | ||
for part in candidate.content.parts: | ||
# If the candidate is a text, add it to the returned message | ||
if part.text: | ||
returned_message = types.Content( | ||
role="model", parts=[types.Part.from_text(text=part.text)] | ||
) | ||
# If the candidate is a tool call, call the tool | ||
elif part.function_call: | ||
tool_name = part.function_call.name | ||
tool_args = part.function_call.args | ||
logger.debug(f"Tool name: {tool_name}, tool args: {tool_args}") | ||
tool_call = await mcp_client.call_tool(tool_name, tool_args) | ||
if tool_call and tool_call.content: | ||
returned_message = types.Content( | ||
role="model", | ||
parts=[ | ||
types.Part.from_text(text=content.text) | ||
for content in tool_call.content | ||
], | ||
) | ||
else: | ||
raise RuntimeError(f"No tool call content {tool_call}") | ||
else: | ||
raise RuntimeError(f"Unknown part type {part}") | ||
return returned_message | ||
|
||
|
||
async def chat(server_url: str): | ||
""" | ||
Run the chat server. | ||
""" | ||
# Why do we use google-genai, not vertexai? | ||
# Because it is easier to convert MCP tools to GenAI tools in google-genai. | ||
genai_client = genai.Client(vertexai=True, location="us-central1") | ||
mcp_client = MCPClient(name="document-search") | ||
await mcp_client.connect_to_server(server_url=server_url) | ||
|
||
# Collect tools from MCP server | ||
mcp_tools = await mcp_client.list_tools() | ||
# Convert MCP tools to GenAI tools | ||
genai_tools = [to_gemini_tool(tool) for tool in mcp_tools.tools] | ||
|
||
# Create chat client | ||
chat_client = genai_client.chats.create( | ||
model="gemini-2.0-flash", | ||
config=types.GenerateContentConfig( | ||
tools=genai_tools, | ||
system_instruction=""" | ||
You are a helpful assistant to search documents. | ||
You have to pass the query to the tool to search the documents as much natural as possible. | ||
""", | ||
), | ||
) | ||
|
||
print("If you want to quit, please enter 'bye'") | ||
try: | ||
while True: | ||
# Get user query | ||
query = input("Enter your query: ") | ||
if query == "bye": | ||
break | ||
|
||
# Get response from GenAI | ||
response = await process_query(chat_client, mcp_client, query) | ||
print(response) | ||
# pylint: disable=broad-except | ||
except Exception as e: | ||
await mcp_client.cleanup() | ||
raise RuntimeError from e | ||
|
||
|
||
if __name__ == "__main__": | ||
# Parse command line arguments | ||
parser = argparse.ArgumentParser() | ||
# trunk-ignore(bandit/B104) | ||
parser.add_argument("--host", type=str, default="0.0.0.0") | ||
parser.add_argument("--port", type=int, default=8080) | ||
args = parser.parse_args() | ||
# Run the chat server | ||
server_url = f"http://{args.host}:{args.port}/sse" | ||
asyncio.run(chat(server_url)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from contextlib import AsyncExitStack | ||
from typing import Optional | ||
|
||
from mcp.client.session import ClientSession | ||
from mcp.client.sse import sse_client | ||
|
||
|
||
class MCPClient: | ||
def __init__(self, name: str,server_url: Optional[str] = None): | ||
# Initialize session and client objects | ||
self.name = name | ||
self.session: Optional[ClientSession] = None | ||
self.exit_stack = AsyncExitStack() | ||
|
||
if server_url: | ||
self.connect_to_server(server_url) | ||
|
||
async def connect_to_server(self, server_url: str): | ||
"""Connect to an MCP server running with SSE transport""" | ||
# Use AsyncExitStack to manage the contexts | ||
_sse_client = sse_client(url=server_url) | ||
streams = await self.exit_stack.enter_async_context(_sse_client) | ||
|
||
_session_context = ClientSession(*streams) | ||
self.session: ClientSession = await self.exit_stack.enter_async_context( | ||
_session_context | ||
) | ||
|
||
# Initialize | ||
await self.session.initialize() | ||
|
||
async def cleanup(self): | ||
"""Properly clean up the session and streams""" | ||
await self.exit_stack.aclose() | ||
|
||
async def list_tools(self): | ||
return await self.session.list_tools() | ||
|
||
async def call_tool(self, tool_name: str, tool_arguments: Optional[dict] = None): | ||
return await self.session.call_tool(tool_name, tool_arguments) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
async def main(): | ||
client = MCPClient() | ||
await client.connect_to_server(server_url="http://0.0.0.0:8080/sse") | ||
tools = await client.list_tools() | ||
print(tools) | ||
tool_call = await client.call_tool("document-search", {"query": "cpp segment とはなんですか?"}) | ||
print(tool_call) | ||
await client.cleanup() # Ensure cleanup is called | ||
|
||
import asyncio | ||
|
||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from google import genai | ||
from google.genai import types as genai_types | ||
from mcp import types as mcp_types | ||
|
||
|
||
def to_gemini_tool(mcp_tool: mcp_types.Tool) -> genai_types.Tool: | ||
""" | ||
Converts an MCP tool schema to a Gemini tool. | ||
Args: | ||
name: The name of the tool. | ||
description: The description of the tool. | ||
input_schema: The input schema of the tool. | ||
Returns: | ||
A Gemini tool. | ||
""" | ||
required_params: list[str] = mcp_tool.inputSchema.get("required", []) | ||
properties = {} | ||
for key, value in mcp_tool.inputSchema.get("properties", {}).items(): | ||
schema_dict = { | ||
"type": value.get("type", "STRING").upper(), | ||
"description": value.get("description", ""), | ||
} | ||
properties[key] = genai_types.Schema(**schema_dict) | ||
|
||
function = genai.types.FunctionDeclaration( | ||
name=mcp_tool.name, | ||
description=mcp_tool.description, | ||
parameters=genai.types.Schema( | ||
type="OBJECT", | ||
properties=properties, | ||
required=required_params, | ||
), | ||
) | ||
return genai_types.Tool(function_declarations=[function]) |