From 751188fdcfd0815cf185a8f0683deae4adbb260c Mon Sep 17 00:00:00 2001 From: Nasr Date: Fri, 6 Dec 2024 16:32:13 +0700 Subject: [PATCH] mcp fixes --- Cargo.lock | 47 ++++- crates/torii/server/Cargo.toml | 4 +- crates/torii/server/src/handlers/mcp.rs | 270 +++++++++++++++--------- crates/torii/server/src/handlers/mod.rs | 2 +- crates/torii/server/src/proxy.rs | 4 +- 5 files changed, 213 insertions(+), 114 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 16ea7a1ed6..04cd6aa600 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7389,6 +7389,19 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tungstenite" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9" +dependencies = [ + "hyper 0.14.30", + "pin-project-lite", + "tokio", + "tokio-tungstenite 0.20.1", + "tungstenite 0.20.1", +] + [[package]] name = "hyper-util" version = "0.1.8" @@ -15248,26 +15261,26 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.21.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.20.1", ] [[package]] name = "tokio-tungstenite" -version = "0.22.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d46baf930138837d65e25e3b33be49c9228579a6135dbf756b5cb9e4283e7cef" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.21.0", ] [[package]] @@ -15739,6 +15752,7 @@ dependencies = [ "http-body 0.4.6", "hyper 0.14.30", "hyper-reverse-proxy", + "hyper-tungstenite", "image", "indexmap 2.5.0", "lazy_static", @@ -15748,7 +15762,7 @@ dependencies = [ "serde_json", "sqlx", "tokio", - "tokio-tungstenite 0.22.0", + "tokio-tungstenite 0.20.1", "tokio-util", "torii-core", "tower 0.4.13", @@ -16036,6 +16050,25 @@ dependencies = [ "syn 2.0.77", ] +[[package]] +name = "tungstenite" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 0.2.12", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.21.0" diff --git a/crates/torii/server/Cargo.toml b/crates/torii/server/Cargo.toml index 63bc842926..2b2e136248 100644 --- a/crates/torii/server/Cargo.toml +++ b/crates/torii/server/Cargo.toml @@ -31,6 +31,6 @@ tracing.workspace = true warp.workspace = true form_urlencoded = "1.2.1" async-trait = "0.1.83" -tokio-tungstenite = "0.22.0" -hyper-tungstenite = "0.22.0" +tokio-tungstenite = "0.20.0" +hyper-tungstenite = "0.11.1" futures-util.workspace = true \ No newline at end of file diff --git a/crates/torii/server/src/handlers/mcp.rs b/crates/torii/server/src/handlers/mcp.rs index 41d7f8f25f..ff53f1e422 100644 --- a/crates/torii/server/src/handlers/mcp.rs +++ b/crates/torii/server/src/handlers/mcp.rs @@ -1,9 +1,12 @@ use std::sync::Arc; + +use base64::engine::general_purpose::STANDARD; +use base64::Engine; use futures_util::{SinkExt, StreamExt}; use hyper::{Body, Request, Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use sqlx::{SqlitePool, Row}; +use sqlx::{Column, Row, SqlitePool, TypeInfo}; use tokio_tungstenite::tungstenite::Message; use super::Handler; @@ -28,9 +31,9 @@ struct JsonRpcRequest { #[derive(Debug, Deserialize)] struct JsonRpcNotification { - jsonrpc: String, - method: String, - params: Option, + _jsonrpc: String, + _method: String, + _params: Option, } #[derive(Debug, Serialize)] @@ -73,6 +76,7 @@ struct ResourceCapabilities { list_changed: bool, } +#[derive(Clone)] pub struct McpHandler { pool: Arc, } @@ -83,27 +87,39 @@ impl McpHandler { } async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse { + if request.jsonrpc != JSONRPC_VERSION { + return JsonRpcResponse { + jsonrpc: JSONRPC_VERSION.to_string(), + id: request.id, + result: None, + error: Some(JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }), + }; + } + match request.method.as_str() { - "initialize" => { - JsonRpcResponse { - jsonrpc: JSONRPC_VERSION.to_string(), - id: request.id, - result: Some(json!({ - "protocolVersion": MCP_VERSION, - "serverInfo": Implementation { - name: "torii-mcp".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), + "initialize" => JsonRpcResponse { + jsonrpc: JSONRPC_VERSION.to_string(), + id: request.id, + result: Some(json!({ + "protocolVersion": MCP_VERSION, + "serverInfo": Implementation { + name: "torii-mcp".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }, + "capabilities": ServerCapabilities { + tools: ToolCapabilities { + list_changed: true, }, - "capabilities": ServerCapabilities { - tools: ToolCapabilities { - list_changed: true, - }, - resources: ResourceCapabilities { - subscribe: true, - list_changed: true, - }, + resources: ResourceCapabilities { + subscribe: true, + list_changed: true, }, - "instructions": r#" + }, + "instructions": r#" Torii - Dojo Game Indexer for Starknet Torii is a specialized indexer designed for Dojo games running on Starknet. It indexes and tracks Entity Component System (ECS) data, providing a comprehensive view of game state and history. @@ -167,47 +183,44 @@ The database is optimized for querying game state and history, allowing clients - Monitor state changes - Generate game statistics "# - })), - error: None, - } + })), + error: None, }, - "tools/list" => { - JsonRpcResponse { - jsonrpc: JSONRPC_VERSION.to_string(), - id: request.id, - result: Some(json!({ - "tools": [ - { - "name": "query", - "description": "Execute a SQL query on the database", - "inputSchema": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "SQL query to execute" - } - }, - "required": ["query"] - } - }, - { - "name": "schema", - "description": "Retrieve the database schema including tables, columns, and their types", - "inputSchema": { - "type": "object", - "properties": { - "table": { - "type": "string", - "description": "Optional table name to get schema for. If omitted, returns schema for all tables." - } + "tools/list" => JsonRpcResponse { + jsonrpc: JSONRPC_VERSION.to_string(), + id: request.id, + result: Some(json!({ + "tools": [ + { + "name": "query", + "description": "Execute a SQL query on the database", + "inputSchema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "SQL query to execute" + } + }, + "required": ["query"] + } + }, + { + "name": "schema", + "description": "Retrieve the database schema including tables, columns, and their types", + "inputSchema": { + "type": "object", + "properties": { + "table": { + "type": "string", + "description": "Optional table name to get schema for. If omitted, returns schema for all tables." } } } - ] - })), - error: None, - } + } + ] + })), + error: None, }, "tools/call" => { if let Some(params) = &request.params { @@ -237,7 +250,7 @@ The database is optimized for querying game state and history, allowing clients }), } } - }, + } _ => JsonRpcResponse { jsonrpc: JSONRPC_VERSION.to_string(), id: request.id, @@ -260,13 +273,11 @@ The database is optimized for querying game state and history, allowing clients while let Some(msg) = read.next().await { if let Ok(Message::Text(text)) = msg { let response = match serde_json::from_str::(&text) { - Ok(JsonRpcMessage::Request(request)) => { - self.handle_request(request).await - }, + Ok(JsonRpcMessage::Request(request)) => self.handle_request(request).await, Ok(JsonRpcMessage::Notification(_notification)) => { // Handle notifications if needed continue; - }, + } Err(e) => JsonRpcResponse { jsonrpc: JSONRPC_VERSION.to_string(), id: Value::Null, @@ -279,9 +290,8 @@ The database is optimized for querying game state and history, allowing clients }, }; - if let Err(e) = write - .send(Message::Text(serde_json::to_string(&response).unwrap())) - .await + if let Err(e) = + write.send(Message::Text(serde_json::to_string(&response).unwrap())).await { eprintln!("Error sending message: {}", e); break; @@ -291,14 +301,15 @@ The database is optimized for querying game state and history, allowing clients } async fn handle_schema_tool(&self, request: JsonRpcRequest) -> JsonRpcResponse { - let table_filter = request.params + let table_filter = request + .params .as_ref() .and_then(|p| p.get("arguments")) .and_then(|args| args.get("table")) .and_then(Value::as_str); let schema_query = match table_filter { - Some(table) => format!( + Some(_table) => format!( "SELECT m.name as table_name, p.* @@ -320,19 +331,14 @@ The database is optimized for querying game state and history, allowing clients }; let rows = match table_filter { - Some(table) => sqlx::query(&schema_query) - .bind(table) - .fetch_all(&*self.pool) - .await, - None => sqlx::query(&schema_query) - .fetch_all(&*self.pool) - .await, + Some(table) => sqlx::query(&schema_query).bind(table).fetch_all(&*self.pool).await, + None => sqlx::query(&schema_query).fetch_all(&*self.pool).await, }; match rows { Ok(rows) => { let mut schema = serde_json::Map::new(); - + for row in rows { let table_name: String = row.try_get("table_name").unwrap(); let column_name: String = row.try_get("name").unwrap(); @@ -341,18 +347,24 @@ The database is optimized for querying game state and history, allowing clients let pk: bool = row.try_get::("pk").unwrap(); let default_value: Option = row.try_get("dflt_value").unwrap(); - let table_entry = schema.entry(table_name) - .or_insert_with(|| json!({ + let table_entry = schema.entry(table_name).or_insert_with(|| { + json!({ "columns": serde_json::Map::new() - })); - - if let Some(columns) = table_entry.get_mut("columns").and_then(|v| v.as_object_mut()) { - columns.insert(column_name, json!({ - "type": column_type, - "nullable": !not_null, - "primary_key": pk, - "default": default_value - })); + }) + }); + + if let Some(columns) = + table_entry.get_mut("columns").and_then(|v| v.as_object_mut()) + { + columns.insert( + column_name, + json!({ + "type": column_type, + "nullable": !not_null, + "primary_key": pk, + "default": default_value + }), + ); } } @@ -367,7 +379,7 @@ The database is optimized for querying game state and history, allowing clients })), error: None, } - }, + } Err(e) => JsonRpcResponse { jsonrpc: JSONRPC_VERSION.to_string(), id: request.id, @@ -387,13 +399,63 @@ The database is optimized for querying game state and history, allowing clients match sqlx::query(query).fetch_all(&*self.pool).await { Ok(rows) => { // Convert rows to JSON using the same logic as SqlHandler - let result = rows.iter().map(|row| { - let mut obj = serde_json::Map::new(); - for (i, column) in row.columns().iter().enumerate() { - // ... row conversion logic from SqlHandler ... - } - Value::Object(obj) - }).collect::>(); + let result = rows + .iter() + .map(|row| { + let mut obj = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value: serde_json::Value = match column.type_info().name() { + "TEXT" => row.get::, _>(i).map_or( + serde_json::Value::Null, + serde_json::Value::String, + ), + "INTEGER" => row + .get::, _>(i) + .map_or(serde_json::Value::Null, |n| { + serde_json::Value::Number(n.into()) + }), + "REAL" => row.get::, _>(i).map_or( + serde_json::Value::Null, + |f| { + serde_json::Number::from_f64(f).map_or( + serde_json::Value::Null, + serde_json::Value::Number, + ) + }, + ), + "BLOB" => row.get::>, _>(i).map_or( + serde_json::Value::Null, + |bytes| { + serde_json::Value::String(STANDARD.encode(bytes)) + }, + ), + _ => { + // Try different types in order + if let Ok(val) = row.try_get::(i) { + serde_json::Value::Number(val.into()) + } else if let Ok(val) = row.try_get::(i) { + // Handle floating point numbers + serde_json::json!(val) + } else if let Ok(val) = row.try_get::(i) { + serde_json::Value::Bool(val) + } else if let Ok(val) = row.try_get::(i) { + serde_json::Value::String(val) + } else { + // Handle or fallback to BLOB as base64 + let val = row.get::>, _>(i); + val.map_or(serde_json::Value::Null, |bytes| { + serde_json::Value::String( + STANDARD.encode(bytes), + ) + }) + } + } + }; + obj.insert(column.name().to_string(), value); + } + Value::Object(obj) + }) + .collect::>(); JsonRpcResponse { jsonrpc: JSONRPC_VERSION.to_string(), @@ -406,7 +468,7 @@ The database is optimized for querying game state and history, allowing clients })), error: None, } - }, + } Err(e) => JsonRpcResponse { jsonrpc: JSONRPC_VERSION.to_string(), id: request.id, @@ -448,11 +510,13 @@ The database is optimized for querying game state and history, allowing clients #[async_trait::async_trait] impl Handler for McpHandler { fn should_handle(&self, req: &Request) -> bool { - req.uri().path().starts_with("/mcp") && req.headers() - .get("upgrade") - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("websocket")) - .unwrap_or(false) + req.uri().path().starts_with("/mcp") + && req + .headers() + .get("upgrade") + .and_then(|h| h.to_str().ok()) + .map(|h| h.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) } async fn handle(&self, req: Request) -> Response { @@ -475,4 +539,4 @@ impl Handler for McpHandler { .unwrap() } } -} \ No newline at end of file +} diff --git a/crates/torii/server/src/handlers/mod.rs b/crates/torii/server/src/handlers/mod.rs index eb56474a1e..d40ece8ad0 100644 --- a/crates/torii/server/src/handlers/mod.rs +++ b/crates/torii/server/src/handlers/mod.rs @@ -1,8 +1,8 @@ pub mod graphql; pub mod grpc; +pub mod mcp; pub mod sql; pub mod static_files; -pub mod mcp; use hyper::{Body, Request, Response}; diff --git a/crates/torii/server/src/proxy.rs b/crates/torii/server/src/proxy.rs index 7f276aedaf..bf1c57ae12 100644 --- a/crates/torii/server/src/proxy.rs +++ b/crates/torii/server/src/proxy.rs @@ -19,6 +19,7 @@ use tower_http::cors::{AllowOrigin, CorsLayer}; use crate::handlers::graphql::GraphQLHandler; use crate::handlers::grpc::GrpcHandler; +use crate::handlers::mcp::McpHandler; use crate::handlers::sql::SqlHandler; use crate::handlers::static_files::StaticHandler; use crate::handlers::Handler; @@ -172,10 +173,11 @@ async fn handle( req: Request, ) -> Result, Infallible> { let handlers: Vec> = vec![ - Box::new(SqlHandler::new(pool)), + Box::new(SqlHandler::new(pool.clone())), Box::new(GraphQLHandler::new(client_ip, graphql_addr)), Box::new(GrpcHandler::new(client_ip, grpc_addr)), Box::new(StaticHandler::new(client_ip, artifacts_addr)), + Box::new(McpHandler::new(pool.clone())), ]; for handler in handlers {