Add the chat example
Signed-off-by: Yu Ishikawa <[email protected]>
yu-iskw committed Feb 19, 2025
1 parent 316232a commit 92aac27
Showing 4 changed files with 197 additions and 0 deletions.
Empty file added src/research_agent/__init__.py
105 changes: 105 additions & 0 deletions src/research_agent/chat.py
@@ -0,0 +1,105 @@
import argparse
import asyncio

from google import genai
from google.genai import chats, types
from loguru import logger

from research_agent.mcp_client import MCPClient
from research_agent.utils import to_gemini_tool


async def process_query(
    chat_client: chats.Chat,
    mcp_client: MCPClient,
    query: str,
):
    """Process the user query using Gemini and MCP tools."""
    response = chat_client.send_message(message=[query])
    if not response.candidates:
        raise RuntimeError("No response from Gemini")

    returned_message = None
    for candidate in response.candidates:
        if candidate.content:
            for part in candidate.content.parts:
                # If the part is text, use it as the returned message
                if part.text:
                    returned_message = types.Content(
                        role="model", parts=[types.Part.from_text(text=part.text)]
                    )
                # If the part is a tool call, call the MCP tool
                elif part.function_call:
                    tool_name = part.function_call.name
                    tool_args = part.function_call.args
                    logger.debug(f"Tool name: {tool_name}, tool args: {tool_args}")
                    tool_call = await mcp_client.call_tool(tool_name, tool_args)
                    if tool_call and tool_call.content:
                        returned_message = types.Content(
                            role="model",
                            parts=[
                                types.Part.from_text(text=content.text)
                                for content in tool_call.content
                            ],
                        )
                    else:
                        raise RuntimeError(f"No tool call content {tool_call}")
                else:
                    raise RuntimeError(f"Unknown part type {part}")
    return returned_message


async def chat(server_url: str):
    """
    Run the chat loop.
    """
    # We use google-genai rather than the vertexai SDK because converting
    # MCP tools to GenAI tools is easier with google-genai.
    genai_client = genai.Client(vertexai=True, location="us-central1")
    mcp_client = MCPClient(name="document-search")
    await mcp_client.connect_to_server(server_url=server_url)

    # Collect tools from the MCP server
    mcp_tools = await mcp_client.list_tools()
    # Convert MCP tools to GenAI tools
    genai_tools = [to_gemini_tool(tool) for tool in mcp_tools.tools]

    # Create the chat client
    chat_client = genai_client.chats.create(
        model="gemini-2.0-flash",
        config=types.GenerateContentConfig(
            tools=genai_tools,
            system_instruction="""
            You are a helpful assistant that searches documents.
            Pass the user's query to the search tool as naturally as possible.
            """,
        ),
    )

    print("If you want to quit, please enter 'bye'")
    try:
        while True:
            # Get the user query
            query = input("Enter your query: ")
            if query == "bye":
                break

            # Get the response from GenAI
            response = await process_query(chat_client, mcp_client, query)
            print(response)
    # pylint: disable=broad-except
    except Exception as e:
        raise RuntimeError("Chat loop failed") from e
    finally:
        # Clean up the MCP session and streams on both normal and error exits
        await mcp_client.cleanup()


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    # trunk-ignore(bandit/B104)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    args = parser.parse_args()
    # Run the chat client against the MCP SSE server
    server_url = f"http://{args.host}:{args.port}/sse"
    asyncio.run(chat(server_url))
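
The chat client above assumes an MCP server exposing an SSE endpoint at /sse with a document-search tool; that server is not part of this commit. A minimal sketch of one built with the mcp Python SDK's FastMCP follows — the tool body, port, and returned text are assumptions for illustration only:

from mcp.server.fastmcp import FastMCP

# Hypothetical stand-in for the real document-search MCP server.
mcp = FastMCP("document-search", port=8080)


@mcp.tool()
def document_search(query: str) -> str:
    """Search documents for the given query (dummy implementation)."""
    return f"Results for: {query}"


if __name__ == "__main__":
    # Serve over SSE so chat.py can connect to http://<host>:8080/sse
    mcp.run(transport="sse")

With this sketch running, `python -m research_agent.chat --host 0.0.0.0 --port 8080` would connect the chat loop to it.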
56 changes: 56 additions & 0 deletions src/research_agent/mcp_client.py
@@ -0,0 +1,56 @@
from contextlib import AsyncExitStack
from typing import Optional

from mcp.client.session import ClientSession
from mcp.client.sse import sse_client


class MCPClient:
    def __init__(self, name: str, server_url: Optional[str] = None):
        # Initialize session and client objects
        self.name = name
        self.server_url = server_url
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        # connect_to_server is a coroutine, so it cannot be awaited from this
        # synchronous constructor; callers must await connect_to_server(),
        # which falls back to the server_url stored here.

    async def connect_to_server(self, server_url: Optional[str] = None):
        """Connect to an MCP server running with SSE transport"""
        # Use AsyncExitStack to manage the contexts
        _sse_client = sse_client(url=server_url or self.server_url)
        streams = await self.exit_stack.enter_async_context(_sse_client)

        _session_context = ClientSession(*streams)
        self.session: ClientSession = await self.exit_stack.enter_async_context(
            _session_context
        )

        # Initialize the MCP session
        await self.session.initialize()

    async def cleanup(self):
        """Properly clean up the session and streams"""
        await self.exit_stack.aclose()

    async def list_tools(self):
        return await self.session.list_tools()

    async def call_tool(self, tool_name: str, tool_arguments: Optional[dict] = None):
        return await self.session.call_tool(tool_name, tool_arguments)


if __name__ == "__main__":
    import asyncio

    async def main():
        client = MCPClient(name="document-search")
        await client.connect_to_server(server_url="http://0.0.0.0:8080/sse")
        tools = await client.list_tools()
        print(tools)
        tool_call = await client.call_tool(
            "document-search", {"query": "What is a cpp segment?"}
        )
        print(tool_call)
        await client.cleanup()  # Ensure cleanup is called

    asyncio.run(main())
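
For callers other than the __main__ demo above, wrapping MCPClient usage in try/finally keeps the SSE streams from leaking when a tool call raises. A minimal sketch — the localhost URL and printed fields are assumptions:

import asyncio

from research_agent.mcp_client import MCPClient


async def demo() -> None:
    client = MCPClient(name="document-search")
    try:
        await client.connect_to_server(server_url="http://localhost:8080/sse")
        tools = await client.list_tools()
        print([tool.name for tool in tools.tools])
    finally:
        # aclose() on the AsyncExitStack tears down the session and SSE streams
        await client.cleanup()


asyncio.run(demo())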
36 changes: 36 additions & 0 deletions src/research_agent/utils.py
@@ -0,0 +1,36 @@
from google import genai
from google.genai import types as genai_types
from mcp import types as mcp_types


def to_gemini_tool(mcp_tool: mcp_types.Tool) -> genai_types.Tool:
    """
    Converts an MCP tool schema to a Gemini tool.

    Args:
        mcp_tool: The MCP tool whose name, description, and input schema
            are mapped onto a Gemini function declaration.

    Returns:
        A Gemini tool.
    """
    required_params: list[str] = mcp_tool.inputSchema.get("required", [])
    properties = {}
    for key, value in mcp_tool.inputSchema.get("properties", {}).items():
        schema_dict = {
            "type": value.get("type", "STRING").upper(),
            "description": value.get("description", ""),
        }
        properties[key] = genai_types.Schema(**schema_dict)

    function = genai.types.FunctionDeclaration(
        name=mcp_tool.name,
        description=mcp_tool.description,
        parameters=genai.types.Schema(
            type="OBJECT",
            properties=properties,
            required=required_params,
        ),
    )
    return genai_types.Tool(function_declarations=[function])
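
A quick way to see what to_gemini_tool produces is to feed it a hand-built MCP tool schema; the tool name and fields below are made up for illustration:

from mcp import types as mcp_types

from research_agent.utils import to_gemini_tool

# Hypothetical MCP tool definition mirroring a document-search tool.
mcp_tool = mcp_types.Tool(
    name="document-search",
    description="Search documents for a query",
    inputSchema={
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "The search query"},
        },
        "required": ["query"],
    },
)

gemini_tool = to_gemini_tool(mcp_tool)
# One FunctionDeclaration named "document-search" with a required string "query"
print(gemini_tool.function_declarations[0].name)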
