-
-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from SecretiveShell/add-mcp-clients
Add mcp clients
- Loading branch information
Showing
8 changed files
with
142 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from contextlib import asynccontextmanager | ||
from mcp_clients.McpClientManager import ClientManager | ||
from loguru import logger | ||
|
||
@asynccontextmanager | ||
async def lifespan(app): | ||
"""Lifespan context manager for fastapi""" | ||
|
||
# startup | ||
logger.log("DEBUG", "Entered fastapi lifespan") | ||
await ClientManager.initialize() | ||
logger.log("DEBUG", "Initialized MCP Client Manager") | ||
|
||
logger.log("DEBUG", "Yielding lifespan") | ||
yield | ||
logger.log("DEBUG", "Returned form lifespan yield") | ||
|
||
# shutdown | ||
|
||
logger.log("DEBUG", "Exiting fastapi lifespan") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
from fastapi import FastAPI | ||
from endpoints import router | ||
from endpoints import router as endpointRouter | ||
from mcp_endpoints import router as mcpRouter | ||
from lifespan import lifespan | ||
|
||
app = FastAPI( | ||
title="MCP Bridge", | ||
description="A middleware application to add MCP support to openai compatible apis", | ||
lifespan=lifespan, | ||
) | ||
|
||
app.include_router(router) | ||
app.include_router(endpointRouter) | ||
app.include_router(mcpRouter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Any | ||
import asyncio | ||
from loguru import logger | ||
from mcp import ClientSession | ||
|
||
class ClientInstance: | ||
name: str | ||
lock: asyncio.Lock | ||
session: None | ||
|
||
def __init__(self, name: str, client): | ||
logger.log("DEBUG", f"Creating client instance for {name}") | ||
self.name = name | ||
self._client = client | ||
self.lock = asyncio.Lock() | ||
|
||
async def start(self): | ||
asyncio.create_task(self._maintain_session()) | ||
|
||
async def _maintain_session(self): | ||
async with self._client as client: | ||
async with ClientSession(*client) as session: | ||
await session.initialize() | ||
logger.debug(f"finished initialise session for {self.name}") | ||
self.session = session | ||
await asyncio.Future() | ||
|
||
|
||
async def __aenter__(self): | ||
await self.lock.acquire() | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb): | ||
self.lock.release() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import Any | ||
from config import settings | ||
from mcp import StdioServerParameters | ||
from mcp_clients.StdioClientFactory import construct_stdio_client | ||
from .ClientInstance import ClientInstance | ||
from loguru import logger | ||
|
||
|
||
class MCPClientManager: | ||
clients: dict[str, ClientInstance] = {} | ||
|
||
async def initialize(self): | ||
logger.log("DEBUG", "Initializing MCP Client Manager") | ||
for server_name, server_config in settings.mcp_servers.items(): | ||
self.clients[server_name] = await self.construct_client( | ||
server_name, server_config | ||
) | ||
await self.clients[server_name].start() # TODO: make these sessions start async? | ||
|
||
async def construct_client(self, name, server_config) -> ClientInstance: | ||
logger.log("DEBUG", f"Constructing client for {server_config}") | ||
|
||
if isinstance(server_config, StdioServerParameters): | ||
client = await construct_stdio_client(server_config) | ||
|
||
else: | ||
raise NotImplementedError( | ||
"Only StdioServerParameters are supported for now" | ||
) | ||
|
||
return ClientInstance(name, client) | ||
|
||
def get_client(self, server_name: str): | ||
return self.clients[server_name] | ||
|
||
def get_clients(self): | ||
return list(self.clients.items()) | ||
|
||
|
||
ClientManager = MCPClientManager() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from mcp import ClientSession, StdioServerParameters, stdio_client | ||
from .ClientInstance import ClientInstance | ||
from loguru import logger | ||
import shutil | ||
import os | ||
|
||
async def construct_stdio_client(config: StdioServerParameters): | ||
logger.log("DEBUG", "Constructing Stdio Server") | ||
|
||
env = dict(os.environ.copy()) | ||
|
||
if config.env is not None: | ||
env.update(config.env) | ||
|
||
server_parameters = StdioServerParameters( | ||
command=shutil.which(config.command), | ||
args=config.args, | ||
env=env, | ||
) | ||
|
||
return stdio_client(server_parameters) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from fastapi import APIRouter | ||
from mcp_clients.McpClientManager import ClientManager | ||
|
||
router = APIRouter( | ||
prefix="/mcp" | ||
) | ||
|
||
@router.get("/tools") | ||
async def get_tools() : | ||
tools = {} | ||
|
||
for name, client in ClientManager.get_clients() : | ||
tools[name] = await client.session.list_tools() | ||
|
||
return tools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters