diff --git a/mcp_bridge/mcp_server/sse.py b/mcp_bridge/mcp_server/sse.py index a32f062..96629d9 100644 --- a/mcp_bridge/mcp_server/sse.py +++ b/mcp_bridge/mcp_server/sse.py @@ -1,38 +1,33 @@ import asyncio from anyio import BrokenResourceError from fastapi.responses import StreamingResponse -from mcp.server.sse import SseServerTransport +from .sse_transport import SseServerTransport from fastapi import APIRouter, Request -from loguru import logger +from pydantic import ValidationError from .server import server, options router = APIRouter(prefix="/sse") -sse = SseServerTransport("/mcp-server/sse") +sse = SseServerTransport("./messages") @router.get("/", response_class=StreamingResponse) async def handle_sse(request: Request): - """sse endpoint for using MCP-Bridge with external clients""" - - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async with sse.connect_sse(request) as streams: try: - await server.run(*streams, initialization_options=options) - + await server.run(streams[0], streams[1], options) except BrokenResourceError: - # Handle gracefully when client disconnects - logger.info("Client disconnected from SSE connection") + pass except asyncio.CancelledError: - logger.info("SSE connection was cancelled") - - -@router.post("/") + pass + except ValidationError: + pass + except Exception: + raise + await request.close() + +@router.post("/messages") async def handle_messages(request: Request): - try: - await sse.handle_post_message(request.scope, request.receive, request._send) - except Exception as e: - logger.exception(e) - # raise e + await sse.handle_post_message(request.scope, request.receive, request._send) + await request.close() diff --git a/mcp_bridge/mcp_server/sse_transport.py b/mcp_bridge/mcp_server/sse_transport.py new file mode 100644 index 0000000..8a341d4 --- /dev/null +++ b/mcp_bridge/mcp_server/sse_transport.py @@ -0,0 +1,149 @@ +""" + +This is a modification of the original code from the mcp sdk +They try to directly contron the ASGI lifespan which causes issues with fastapi + +also switched the logger to loguru since we are vendoring it anyway + +""" + +from contextlib import asynccontextmanager +from typing import Any +from urllib.parse import quote +from uuid import UUID, uuid4 + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from fastapi.requests import Request +from fastapi.responses import Response +from starlette.types import Receive, Scope, Send + +import mcp.types as types + +from loguru import logger +logger.disable("mcp_server.sse_transport") + +class SseServerTransport: + """ + SSE server transport for MCP. This class provides _two_ ASGI applications, + suitable to be used with a framework like Starlette and a server like Hypercorn: + + 1. connect_sse() is an ASGI application which receives incoming GET requests, + and sets up a new SSE stream to send server messages to the client. + 2. handle_post_message() is an ASGI application which receives incoming POST + requests, which should contain client messages that link to a + previously-established SSE session. + """ + + _endpoint: str + _read_stream_writers: dict[ + UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] + ] + + def __init__(self, endpoint: str) -> None: + """ + Creates a new SSE server transport, which will direct the client to POST + messages to the relative or absolute URL given. + """ + + super().__init__() + self._endpoint = endpoint + self._read_stream_writers = {} + logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") + + @asynccontextmanager + async def connect_sse(self, request: Request): + if request.scope["type"] != "http": + logger.error("connect_sse received non-HTTP request") + raise ValueError("connect_sse can only handle HTTP requests") + + logger.debug("Setting up SSE connection") + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + session_id = uuid4() + session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" + self._read_stream_writers[session_id] = read_stream_writer + logger.debug(f"Created new session with ID: {session_id}") + + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream( + 0, dict[str, Any] + ) + + async def sse_writer(): + logger.debug("Starting SSE writer") + async with sse_stream_writer, write_stream_reader: + await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) + logger.debug(f"Sent endpoint event: {session_uri}") + + async for message in write_stream_reader: + logger.debug(f"Sending message via SSE: {message}") + await sse_stream_writer.send( + { + "event": "message", + "data": message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + ) + + async with anyio.create_task_group() as tg: + response = EventSourceResponse( + content=sse_stream_reader, data_sender_callable=sse_writer + ) + logger.debug("Starting SSE response task") + tg.start_soon(response, request.scope, request.receive, request._send) + + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + + async def handle_post_message( + self, scope: Scope, receive: Receive, send: Send + ) -> Response: + logger.debug("Handling POST message") + request = Request(scope, receive) + + session_id_param = request.query_params.get("session_id") + if session_id_param is None: + logger.warning("Received request without session_id") + response = Response("session_id is required", status_code=400) + return response + + try: + session_id = UUID(hex=session_id_param) + logger.debug(f"Parsed session ID: {session_id}") + except ValueError: + logger.warning(f"Received invalid session ID: {session_id_param}") + response = Response("Invalid session ID", status_code=400) + return response + + writer = self._read_stream_writers.get(session_id) + if not writer: + logger.warning(f"Could not find session for ID: {session_id}") + response = Response("Could not find session", status_code=404) + return response + + json = await request.json() + logger.debug(f"Received JSON: {json}") + + try: + message = types.JSONRPCMessage.model_validate(json) + logger.debug(f"Validated client message: {message}") + except ValidationError as err: + logger.error(f"Failed to parse message: {err}") + response = Response("Could not parse message", status_code=400) + await writer.send(err) + return response + + logger.debug(f"Sending message to writer: {message}") + response = Response("Accepted", status_code=202) + await writer.send(message) + return response