Skip to content

Commit

Permalink
fix issue with calling tool, and add test (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsabran authored Dec 29, 2024
1 parent ca552ec commit a436420
Show file tree
Hide file tree
Showing 17 changed files with 490 additions and 83 deletions.
13 changes: 13 additions & 0 deletions ExampleMCPServer/Sources/Tools.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

import AppKit
import JSONSchemaBuilder
import MCPServer

// MARK: - EmptyInput

@Schemable
struct EmptyInput { }

let testTool = Tool(name: "test") { (_: EmptyInput) async throws in
[]
}
1 change: 1 addition & 0 deletions ExampleMCPServer/Sources/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ let server = try await MCPServer(
Tool(name: "repeat") { (input: RepeatToolInput) in
[.text(.init(text: input.text))]
},
testTool,
]),
transport: proxy(transport))

Expand Down
8 changes: 4 additions & 4 deletions MCPClient/Sources/MCPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public actor MCPClient: MCPClientInterface {
async throws -> CallToolResult
{
guard serverInfo.capabilities.tools != nil else {
throw MCPError.notSupported
throw MCPError.capabilityNotSupported
}
var progressToken: String? = nil
if let progressHandler {
Expand All @@ -107,14 +107,14 @@ public actor MCPClient: MCPClientInterface {

public func getPrompt(named name: String, arguments: JSON? = nil) async throws -> GetPromptResult {
guard serverInfo.capabilities.prompts != nil else {
throw MCPError.notSupported
throw MCPError.capabilityNotSupported
}
return try await connection.getPrompt(.init(name: name, arguments: arguments))
}

public func readResource(uri: String) async throws -> ReadResourceResult {
guard serverInfo.capabilities.resources != nil else {
throw MCPError.notSupported
throw MCPError.capabilityNotSupported
}
return try await connection.readResource(.init(uri: uri))
}
Expand All @@ -137,7 +137,7 @@ public actor MCPClient: MCPClientInterface {
private static func connectToServer(connection: MCPClientConnectionInterface) async throws -> ServerInfo {
let response = try await connection.initialize()
guard response.protocolVersion == MCP.protocolVersion else {
throw MCPClientError.versionMismatch
throw MCPClientError.versionMismatch(received: response.protocolVersion, expected: MCP.protocolVersion)
}

try await connection.acknowledgeInitialization()
Expand Down
17 changes: 16 additions & 1 deletion MCPClient/Sources/MCPClientInterface.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Foundation
import JSONRPC
import MCPInterface
import MemberwiseInit
Expand Down Expand Up @@ -53,6 +54,20 @@ public struct ClientCapabilityHandlers {
// MARK: - MCPClientError

public enum MCPClientError: Error {
case versionMismatch
case versionMismatch(received: String, expected: String)
case toolCallError(executionErrors: [CallToolResult.ExecutionError])
}

// MARK: LocalizedError

extension MCPClientError: LocalizedError {

public var errorDescription: String? {
switch self {
case .versionMismatch(let received, let expected):
return "Version mismatch between server and client. Received: \(received), Expected: \(expected)"
case .toolCallError(let executionErrors):
return "Error executing tool:\n\(executionErrors.map { $0.errorDescription ?? "unknown error" }.joined(separator: "\n\n"))"
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

import MCPClient
import MCPInterface

#if DEBUG
// TODO: move to a test helper package
// MARK: - MockMCPClientConnection

/// A mock `MCPClientConnection` that can be used in tests.
class MockMCPClientConnection: MCPClientConnectionInterface {
Expand Down Expand Up @@ -202,7 +202,8 @@ class MockMCPClientConnection: MCPClientConnectionInterface {

}

// MARK: - MockMCPClientConnectionError

enum MockMCPClientConnectionError: Error {
case notImplemented(function: String)
}
#endif
17 changes: 15 additions & 2 deletions MCPInterface/Sources/Interfaces.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Foundation
import JSONRPC
import MemberwiseInit

Expand Down Expand Up @@ -50,15 +51,27 @@ extension CapabilityStatus {
case .supported(let capability):
return capability
case .notSupported:
throw MCPError.notSupported
throw MCPError.capabilityNotSupported
}
}
}

// MARK: - MCPError

public enum MCPError: Error {
case notSupported
case capabilityNotSupported
}

// MARK: LocalizedError

extension MCPError: LocalizedError {

public var errorDescription: String? {
switch self {
case .capabilityNotSupported:
return "The requested capability is not supported"
}
}
}

public typealias HandleServerRequest = (ServerRequest, (AnyJRPCResponse) -> Void)
Expand Down
27 changes: 27 additions & 0 deletions MCPInterface/Sources/mcp_interfaces/Interface+extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -642,3 +642,30 @@ extension PromptReference {
name = try container.decode(String.self, forKey: "name")
}
}

// MARK: - CallToolResult.ExecutionError + LocalizedError

extension CallToolResult.ExecutionError: LocalizedError {

public var errorDescription: String? {
text
}
}

// MARK: - JRPCError + LocalizedError

extension JRPCError: LocalizedError {

public var errorDescription: String? {
if let data {
do {
if let dataStr = String(data: try JSONEncoder().encode(data), encoding: .utf8) {
return "JRPC error \(code): \(message)\n\(dataStr)"
}
} catch {
// will fall back to the default error description
}
}
return "JRPC error \(code): \(message)"
}
}
3 changes: 3 additions & 0 deletions MCPServer/Sources/Convenience/JSONSchema+typealias.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import JSONSchema

typealias JSONSchema_JSONValue = JSONSchema.JSONValue
51 changes: 14 additions & 37 deletions MCPServer/Sources/Convenience/Schemable+extensions.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Foundation
import JSONRPC
import JSONSchema
import JSONSchemaBuilder
import MCPInterface
Expand All @@ -17,7 +18,7 @@ import MCPInterface

/// Definition for a tool the client can call.
public protocol CallableTool {
associatedtype Input: Decodable
associatedtype Input
/// A JSON Schema object defining the expected parameters for the tool.
var inputSchema: JSON { get }
/// The name of the tool.
Expand Down Expand Up @@ -85,44 +86,22 @@ extension Tool where Input: Schemable {
description: description,
inputSchema: Input.schema.schemaValue.json,
decodeInput: { data in
let json = try JSONDecoder().decode(JSONValue.self, from: data)
let json = try JSONDecoder().decode(JSONSchema_JSONValue.self, from: data)

switch Input.schema.parse(json) {
case .valid(let value):
return value
case .invalid(let errors):
throw errors.first ?? MCPServerError.toolCallError(errors)
case .invalid:
throw MCPServerError.decodingError(input: data, schema: Input.schema.schemaValue.json)
}
},
call: call)
}
}

extension Tool where Input: Decodable {
public init(
name: String,
description: String? = nil,
inputSchema: JSON,
call: @escaping (Input) async throws -> [TextContentOrImageContentOrEmbeddedResource])
{
self.init(
name: name,
description: description,
inputSchema: inputSchema,
decodeInput: { data in
try JSONDecoder().decode(Input.self, from: data)
},
call: call)
}
}

extension CallableTool {
public func decodeInput(_ input: JSON?) throws -> Input {
let data = try JSONEncoder().encode(input)
return try JSONDecoder().decode(Input.self, from: data)
}

public func call(_ input: JSON?) async throws -> [TextContentOrImageContentOrEmbeddedResource] {
let input = try decodeInput(input)
public func call(json: JSON?) async throws -> [TextContentOrImageContentOrEmbeddedResource] {
let input: Input = try decodeInput(json)
return try await call(input)
}
}
Expand All @@ -138,11 +117,13 @@ extension Array where Element == any CallableTool {
handler: { request in
let name = request.name
guard let tool = toolsByName[name] else {
throw MCPError.notSupported
throw JSONRPCResponseError<JSONRPC.JSONValue>(
code: JRPCErrorCodes.invalidParams.rawValue,
message: "Unknown tool: \(name)")
}
let arguments = request.arguments
do {
let content = try await tool.call(arguments)
let content = try await tool.call(json: arguments)
return CallToolResult(content: content)
} catch {
return CallToolResult(content: [.text(.init(text: error.localizedDescription))], isError: true)
Expand All @@ -158,13 +139,13 @@ extension Array where Element == any CallableTool {
}

/// Convert between the JSON representation from `JSONSchema` and ours
extension [KeywordIdentifier: JSONValue] {
extension [KeywordIdentifier: JSONSchema_JSONValue] {
fileprivate var json: JSON {
.object(mapValues { $0.value })
}
}

extension JSONValue {
extension JSONSchema_JSONValue {
fileprivate var value: JSON.Value {
switch self {
case .null:
Expand All @@ -184,7 +165,3 @@ extension JSONValue {
}
}
}

// MARK: - ParseIssue + Error

extension ParseIssue: @retroactive Error { }
Loading

0 comments on commit a436420

Please sign in to comment.