Skip to content

Commit

Permalink
Add support for client capabilities (#5)
Browse files Browse the repository at this point in the history
* nit: rename

* move files

* pass capability in initialization

* handle server requests

* make API more robust

* lint

---------

Co-authored-by: Gui Sabran <gsabran@www.com>
  • Loading branch information
gsabran and Gui Sabran authored Dec 26, 2024
1 parent ef2617b commit 4cbe50a
Show file tree
Hide file tree
Showing 29 changed files with 351 additions and 129 deletions.
96 changes: 77 additions & 19 deletions MCPClient/Sources/MCPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ import Combine
import Foundation
import MCPShared

public typealias SamplingRequestHandler = ((CreateMessageRequest.Params) async throws -> CreateMessageRequest.Result)
public typealias ListRootsRequestHandler = ((ListRootsRequest.Params?) async throws -> ListRootsRequest.Result)

// MARK: - MCPClient

public actor MCPClient: MCPClientInterface {
Expand All @@ -11,26 +14,30 @@ public actor MCPClient: MCPClientInterface {

public init(
info: Implementation,
capabilities: ClientCapabilities,
transport: Transport)
transport: Transport,
capabilities: ClientCapabilityHandlers = .init())
async throws {
try await self.init(
info: info,
capabilities: capabilities,
getMcpConnection: { try MCPConnection(
samplingRequestHandler: capabilities.sampling?.handler,
listRootRequestHandler: capabilities.roots?.handler,
connection: try MCPClientConnection(
info: info,
capabilities: capabilities,
transport: transport) })
capabilities: ClientCapabilities(
experimental: nil, // TODO: support experimental requests
roots: capabilities.roots?.info,
sampling: capabilities.sampling?.info),
transport: transport))
}

init(
info _: Implementation,
capabilities _: ClientCapabilities,
getMcpConnection: @escaping () throws -> MCPConnectionInterface)
samplingRequestHandler: SamplingRequestHandler? = nil,
listRootRequestHandler: ListRootsRequestHandler? = nil,
connection: MCPClientConnectionInterface)
async throws {
self.getMcpConnection = getMcpConnection

// Initialize the connection, and then update server capabilities.
self.connection = connection
self.samplingRequestHandler = samplingRequestHandler
self.listRootRequestHandler = listRootRequestHandler
try await connect()
Task { try await self.updateTools() }
Task { try await self.updatePrompts() }
Expand Down Expand Up @@ -111,23 +118,28 @@ public actor MCPClient: MCPClientInterface {
return try await connectionInfo.connection.readResource(.init(uri: uri))
}

// MARK: Internal

let connection: MCPClientConnectionInterface

// MARK: Private

private struct ConnectionInfo {
let connection: MCPConnectionInterface
let connection: MCPClientConnectionInterface
let serverInfo: Implementation
let serverCapabilities: ServerCapabilities
}

private let samplingRequestHandler: SamplingRequestHandler?
private let listRootRequestHandler: ListRootsRequestHandler?

private var connectionInfo: ConnectionInfo?

private let _tools = CurrentValueSubject<ServerCapabilityState<[Tool]>?, Never>(nil)
private let _prompts = CurrentValueSubject<ServerCapabilityState<[Prompt]>?, Never>(nil)
private let _resources = CurrentValueSubject<ServerCapabilityState<[Resource]>?, Never>(nil)
private let _resourceTemplates = CurrentValueSubject<ServerCapabilityState<[ResourceTemplate]>?, Never>(nil)

private let getMcpConnection: () throws -> MCPConnectionInterface

private var progressHandlers = [String: (progress: Double, total: Double?) -> Void]()

private func startListeningToNotifications() async throws {
Expand Down Expand Up @@ -163,6 +175,52 @@ public actor MCPClient: MCPClientInterface {
}
}

private func startListeningToRequests() async throws {
let connectionInfo = try getConnectionInfo()
let requests = await connectionInfo.connection.requestsToHandle
Task { [weak self] in
for await(request, completion) in requests {
guard let self else {
completion(.failure(.init(
code: JRPCErrorCodes.internalError.rawValue,
message: "The client disconnected")))
return
}
switch request {
case .createMessage(let params):
if let handler = await self.samplingRequestHandler {
do {
completion(.success(try await handler(params)))
} catch {
completion(.failure(.init(
code: JRPCErrorCodes.internalError.rawValue,
message: error.localizedDescription)))
}
} else {
completion(.failure(.init(
code: JRPCErrorCodes.invalidRequest.rawValue,
message: "Sampling is not supported by this client")))
}

case .listRoots(let params):
if let handler = await self.listRootRequestHandler {
do {
completion(.success(try await handler(params)))
} catch {
completion(.failure(.init(
code: JRPCErrorCodes.internalError.rawValue,
message: error.localizedDescription)))
}
} else {
completion(.failure(.init(
code: JRPCErrorCodes.invalidRequest.rawValue,
message: "Listing roots is not supported by this client")))
}
}
}
}
}

private func startPinging() {
// TODO
}
Expand Down Expand Up @@ -212,19 +270,19 @@ public actor MCPClient: MCPClientInterface {
}

private func connect() async throws {
let mcpConnection = try getMcpConnection()
let response = try await mcpConnection.initialize()
let response = try await connection.initialize()
guard response.protocolVersion == MCP.protocolVersion else {
throw MCPClientError.versionMismatch
}

connectionInfo = ConnectionInfo(
connection: mcpConnection,
connection: connection,
serverInfo: response.serverInfo,
serverCapabilities: response.capabilities)

try await mcpConnection.acknowledgeInitialization()
try await connection.acknowledgeInitialization()
try await startListeningToNotifications()
try await startListeningToRequests()
startPinging()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import OSLog

private let mcpLogger = Logger(subsystem: Bundle.main.bundleIdentifier.map { "\($0).mcp" } ?? "com.app.mcp", category: "mcp")

// MARK: - MCPConnection
// MARK: - MCPClientConnection

public actor MCPConnection: MCPConnectionInterface {
public actor MCPClientConnection: MCPClientConnectionInterface {

// MARK: Lifecycle

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ public typealias AnyJRPCResponse = Swift.Result<Encodable & Sendable, AnyJSONRPC

public typealias HandleServerRequest = (ServerRequest, (AnyJRPCResponse) -> Void)

// MARK: - MCPConnectionInterface
// MARK: - MCPClientConnectionInterface

/// The MCP JRPC Bridge is a stateless interface to the MCP server that provides a higher level Swift interface.
/// It does not implement any of the stateful behaviors of the MCP server, such as subscribing to changes, detecting connection health,
/// ensuring that the connection has been initialized before being used etc.
///
/// For most use cases, `MCPClient` should be a preferred interface.
public protocol MCPConnectionInterface {
public protocol MCPClientConnectionInterface {
/// The notifications received by the server.
var notifications: AsyncStream<ServerNotification> { get async }
// TODO: look at moving the request handler to the init
Expand Down
13 changes: 13 additions & 0 deletions MCPClient/Sources/MCPClientInterface.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
import JSONRPC
import MCPShared
import MemberwiseInit

// MARK: - MCPClientInterface

public protocol MCPClientInterface { }

public typealias Transport = DataChannel

// MARK: - ClientCapabilityHandlers

/// Describes the supported capabilities of an MCP client, and how to handle each of the supported ones.
///
/// Note: This is similar to `ClientCapabilities`, with the addition of the handler function.
@MemberwiseInit(.public, _optionalsDefaultNil: true)
public struct ClientCapabilityHandlers {
public let roots: CapabilityHandler<ListChangedCapability, ListRootsRequestHandler>?
public let sampling: CapabilityHandler<EmptyObject, SamplingRequestHandler>?
// TODO: add experimental
}

// MARK: - MCPClientError

public enum MCPClientError: Error {
Expand Down
36 changes: 18 additions & 18 deletions MCPClient/Sources/MockMCPConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import MCPShared
#if DEBUG
// TODO: move to a test helper package

/// A mock `MCPConnection` that can be used in tests.
class MockMCPConnection: MCPConnectionInterface {
/// A mock `MCPClientConnection` that can be used in tests.
class MockMCPClientConnection: MCPClientConnectionInterface {

// MARK: Lifecycle

Expand Down Expand Up @@ -89,109 +89,109 @@ class MockMCPConnection: MCPConnectionInterface {
if let initializeStub {
return try await initializeStub()
}
throw MockMCPConnectionError.notImplemented(function: "initialize")
throw MockMCPClientConnectionError.notImplemented(function: "initialize")
}

func acknowledgeInitialization() async throws {
if let acknowledgeInitializationStub {
return try await acknowledgeInitializationStub()
}
throw MockMCPConnectionError.notImplemented(function: "acknowledgeInitialization")
throw MockMCPClientConnectionError.notImplemented(function: "acknowledgeInitialization")
}

func ping() async throws {
if let pingStub {
return try await pingStub()
}
throw MockMCPConnectionError.notImplemented(function: "ping")
throw MockMCPClientConnectionError.notImplemented(function: "ping")
}

func listPrompts() async throws -> [Prompt] {
if let listPromptsStub {
return try await listPromptsStub()
}
throw MockMCPConnectionError.notImplemented(function: "listPrompts")
throw MockMCPClientConnectionError.notImplemented(function: "listPrompts")
}

func getPrompt(_ params: GetPromptRequest.Params) async throws -> GetPromptRequest.Result {
if let getPromptStub {
return try await getPromptStub(params)
}
throw MockMCPConnectionError.notImplemented(function: "getPrompt")
throw MockMCPClientConnectionError.notImplemented(function: "getPrompt")
}

func listResources() async throws -> [Resource] {
if let listResourcesStub {
return try await listResourcesStub()
}
throw MockMCPConnectionError.notImplemented(function: "listResources")
throw MockMCPClientConnectionError.notImplemented(function: "listResources")
}

func readResource(_ params: ReadResourceRequest.Params) async throws -> ReadResourceRequest.Result {
if let readResourceStub {
return try await readResourceStub(params)
}
throw MockMCPConnectionError.notImplemented(function: "readResource")
throw MockMCPClientConnectionError.notImplemented(function: "readResource")
}

func subscribeToUpdateToResource(_ params: SubscribeRequest.Params) async throws {
if let subscribeToUpdateToResourceStub {
return try await subscribeToUpdateToResourceStub(params)
}
throw MockMCPConnectionError.notImplemented(function: "subscribeToUpdateToResource")
throw MockMCPClientConnectionError.notImplemented(function: "subscribeToUpdateToResource")
}

func unsubscribeToUpdateToResource(_ params: UnsubscribeRequest.Params) async throws {
if let unsubscribeToUpdateToResourceStub {
return try await unsubscribeToUpdateToResourceStub(params)
}
throw MockMCPConnectionError.notImplemented(function: "unsubscribeToUpdateToResource")
throw MockMCPClientConnectionError.notImplemented(function: "unsubscribeToUpdateToResource")
}

func listResourceTemplates() async throws -> [ResourceTemplate] {
if let listResourceTemplatesStub {
return try await listResourceTemplatesStub()
}
throw MockMCPConnectionError.notImplemented(function: "listResourceTemplates")
throw MockMCPClientConnectionError.notImplemented(function: "listResourceTemplates")
}

func listTools() async throws -> [Tool] {
if let listToolsStub {
return try await listToolsStub()
}
throw MockMCPConnectionError.notImplemented(function: "listTools")
throw MockMCPClientConnectionError.notImplemented(function: "listTools")
}

func call(toolName: String, arguments: JSON?, progressToken: ProgressToken?) async throws -> CallToolRequest.Result {
if let callToolStub {
return try await callToolStub(toolName, arguments, progressToken)
}
throw MockMCPConnectionError.notImplemented(function: "callTool")
throw MockMCPClientConnectionError.notImplemented(function: "callTool")
}

func requestCompletion(_ params: CompleteRequest.Params) async throws -> CompleteRequest.Result {
if let requestCompletionStub {
return try await requestCompletionStub(params)
}
throw MockMCPConnectionError.notImplemented(function: "requestCompletion")
throw MockMCPClientConnectionError.notImplemented(function: "requestCompletion")
}

func setLogLevel(_ params: SetLevelRequest.Params) async throws -> SetLevelRequest.Result {
if let setLogLevelStub {
return try await setLogLevelStub(params)
}
throw MockMCPConnectionError.notImplemented(function: "setLogLevel")
throw MockMCPClientConnectionError.notImplemented(function: "setLogLevel")
}

func log(_ params: LoggingMessageNotification.Params) async throws {
if let logStub {
return try await logStub(params)
}
throw MockMCPConnectionError.notImplemented(function: "log")
throw MockMCPClientConnectionError.notImplemented(function: "log")
}
}

enum MockMCPConnectionError: Error {
enum MockMCPClientConnectionError: Error {
case notImplemented(function: String)
}

Expand Down
Loading

0 comments on commit 4cbe50a

Please sign in to comment.