diff --git a/base.go b/base.go new file mode 100644 index 0000000..a2dc6d3 --- /dev/null +++ b/base.go @@ -0,0 +1,34 @@ +package mcp + +import ( + "context" + "strconv" +) + +type base struct { + router *router + stream Stream + interceptors []Interceptor +} + +func (b *base) listen(ctx context.Context, handler func(ctx context.Context, msg *Message) error) error { + for { + msg, err := b.stream.Recv() + if err != nil { + return err + } + if msg.Method != nil { + go func() { + handler(ctx, msg) + }() + } else { + id, err := strconv.ParseUint(msg.ID.String(), 10, 64) + if err != nil { + continue + } + if inbox, ok := b.router.Remove(id); ok { + inbox <- msg + } + } + } +} diff --git a/call.go b/call.go new file mode 100644 index 0000000..eb51278 --- /dev/null +++ b/call.go @@ -0,0 +1,77 @@ +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" +) + +func call[P any, R any](ctx context.Context, c *base, method string, req *Request[P]) (*Response[R], error) { + id, inbox := c.router.Add() + + var interceptor Interceptor + if len(c.interceptors) > 0 { + interceptor = newStack(c.interceptors) + } else { + interceptor = UnaryInterceptorFunc( + func(next UnaryFunc) UnaryFunc { + return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + return next(ctx, request) + }) + }, + ) + } + + inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + rawmsg, err := json.Marshal(req.Params) + if err != nil { + return nil, err + } + + msgID := json.Number(request.ID()) + msgVersion := "2.0" + msgParams := json.RawMessage(rawmsg) + + msg := &Message{ + ID: &msgID, + JsonRPC: &msgVersion, + Method: &method, + Params: &msgParams, + } + + if err := c.stream.Send(msg); err != nil { + return nil, err + } + + var result R + + select { + case resp := <-inbox: + if resp.Error != nil { + return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message)) + } + if resp.Result == nil { + return nil, fmt.Errorf("no result") + } + if err := json.Unmarshal(*resp.Result, &result); err != nil { + return nil, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + + return NewResponse(&result), nil + }) + + req.id = strconv.FormatUint(id, 10) + req.method = method + + resp, err := interceptor.WrapUnary(inner)(ctx, req) + if err != nil { + return nil, err + } + + return resp.(*Response[R]), nil +} diff --git a/client.go b/client.go index 196c452..1d2f78e 100644 --- a/client.go +++ b/client.go @@ -1,178 +1,107 @@ package mcp import ( - "bufio" "context" - "encoding/json" - "errors" "fmt" - "io" - "strconv" - "sync" - - "github.com/riza-io/mcp-go/internal/jsonrpc" ) -type Client interface { - Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) - ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) - ListTools(ctx context.Context, req *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) - CallTool(ctx context.Context, req *Request[CallToolRequest]) (*Response[CallToolResponse], error) - ListPrompts(ctx context.Context, req *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) - GetPrompt(ctx context.Context, req *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) - ReadResource(ctx context.Context, req *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) - ListResourceTemplates(ctx context.Context, req *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) - Completion(ctx context.Context, req *Request[CompletionRequest]) (*Response[CompletionResponse], error) - Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error) - SetLogLevel(ctx context.Context, req *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) -} - -type StdioClient struct { - in io.Reader - out io.Writer - scanner *bufio.Scanner - next int - lock sync.Mutex +type ClientHandler interface { + Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error) + Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) + LogMessage(ctx context.Context, request *Request[LogMessageRequest]) +} + +type UnimplementedClient struct{} + +func (u *UnimplementedClient) Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error) { + return nil, fmt.Errorf("not implemented") +} + +func (u *UnimplementedClient) LogMessage(ctx context.Context, request *Request[LogMessageRequest]) { +} + +func (c *UnimplementedClient) Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error) { + return NewResponse(&PingResponse{}), nil +} + +type Client struct { + handler ClientHandler interceptors []Interceptor + base *base } -func NewStdioClient(stdin io.Reader, stdout io.Writer, opts ...Option) Client { - c := &StdioClient{ - in: stdin, - out: stdout, - scanner: bufio.NewScanner(stdin), +func NewClient(stream Stream, handler ClientHandler, opts ...Option) *Client { + c := &Client{ + handler: handler, } - for _, opt := range opts { opt.applyToClient(c) } - + c.base = &base{ + router: newRouter(), + interceptors: c.interceptors, + stream: stream, + } return c } -func clientCallUnary[P any, R any](ctx context.Context, c *StdioClient, method string, req *Request[P]) (*Response[R], error) { - // Ensure that we are not sending multiple requests at the same time - c.lock.Lock() - defer c.lock.Unlock() - - defer func() { - // Increment the ID counter - c.next++ - }() - - var interceptor Interceptor - if len(c.interceptors) > 0 { - interceptor = newStack(c.interceptors) - } else { - interceptor = UnaryInterceptorFunc( - func(next UnaryFunc) UnaryFunc { - return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - return next(ctx, request) - }) - }, - ) - } - - inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - rawmsg, err := json.Marshal(req.Params) - if err != nil { - return nil, err - } - - msg := jsonrpc.Request{ - ID: json.Number(request.ID()), - JsonRPC: "2.0", - Method: request.Method(), - Params: json.RawMessage(rawmsg), - } - - bs, err := json.Marshal(msg) - if err != nil { - return nil, err - } - - fmt.Fprintln(c.out, string(bs)) - - var result R - - for c.scanner.Scan() { - line := c.scanner.Bytes() - - var resp jsonrpc.Response - - if err := json.Unmarshal(line, &resp); err != nil { - return nil, err - } - - if resp.Error != nil { - return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message)) - } - - if err := json.Unmarshal(resp.Result, &result); err != nil { - return nil, err - } - - break - } - - if err := c.scanner.Err(); err != nil { - return nil, err - } - - return NewResponse(&result), nil - }) - - req.id = strconv.Itoa(c.next) - req.method = method +// sync.Once? +func (c *Client) Listen(ctx context.Context) error { + return c.base.listen(ctx, c.processMessage) +} - resp, err := interceptor.WrapUnary(inner)(ctx, req) - if err != nil { - return nil, err +func (c *Client) processMessage(ctx context.Context, msg *Message) error { + srv := c.handler + switch m := *msg.Method; m { + case "ping": + return process(ctx, c.base, msg, srv.Ping) + case "notifications/message": + return process(ctx, c.base, msg, noop(srv.LogMessage)) + default: + return fmt.Errorf("unknown method: %s", m) } - - return resp.(*Response[R]), nil } -func (c *StdioClient) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) { - return clientCallUnary[InitializeRequest, InitializeResponse](ctx, c, "initialize", request) +func (c *Client) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) { + return call[InitializeRequest, InitializeResponse](ctx, c.base, "initialize", request) } -func (c *StdioClient) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) { - return clientCallUnary[ListResourcesRequest, ListResourcesResponse](ctx, c, "resources/list", request) +func (c *Client) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) { + return call[ListResourcesRequest, ListResourcesResponse](ctx, c.base, "resources/list", request) } -func (c *StdioClient) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) { - return clientCallUnary[ListToolsRequest, ListToolsResponse](ctx, c, "tools/list", request) +func (c *Client) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) { + return call[ListToolsRequest, ListToolsResponse](ctx, c.base, "tools/list", request) } -func (c *StdioClient) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) { - return clientCallUnary[CallToolRequest, CallToolResponse](ctx, c, "tools/call", request) +func (c *Client) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) { + return call[CallToolRequest, CallToolResponse](ctx, c.base, "tools/call", request) } -func (c *StdioClient) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) { - return clientCallUnary[ListPromptsRequest, ListPromptsResponse](ctx, c, "prompts/list", request) +func (c *Client) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) { + return call[ListPromptsRequest, ListPromptsResponse](ctx, c.base, "prompts/list", request) } -func (c *StdioClient) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) { - return clientCallUnary[GetPromptRequest, GetPromptResponse](ctx, c, "prompts/get", request) +func (c *Client) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) { + return call[GetPromptRequest, GetPromptResponse](ctx, c.base, "prompts/get", request) } -func (c *StdioClient) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) { - return clientCallUnary[ReadResourceRequest, ReadResourceResponse](ctx, c, "resources/read", request) +func (c *Client) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) { + return call[ReadResourceRequest, ReadResourceResponse](ctx, c.base, "resources/read", request) } -func (c *StdioClient) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) { - return clientCallUnary[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c, "resources/templates/list", request) +func (c *Client) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) { + return call[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c.base, "resources/templates/list", request) } -func (c *StdioClient) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) { - return clientCallUnary[CompletionRequest, CompletionResponse](ctx, c, "completion", request) +func (c *Client) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) { + return call[CompletionRequest, CompletionResponse](ctx, c.base, "completion", request) } -func (c *StdioClient) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) { - return clientCallUnary[PingRequest, PingResponse](ctx, c, "ping", request) +func (c *Client) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) { + return call[PingRequest, PingResponse](ctx, c.base, "ping", request) } -func (c *StdioClient) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) { - return clientCallUnary[SetLogLevelRequest, SetLogLevelResponse](ctx, c, "logging/setLevel", request) +func (c *Client) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) { + return call[SetLogLevelRequest, SetLogLevelResponse](ctx, c.base, "logging/setLevel", request) } diff --git a/examples/fs/main.go b/examples/fs/main.go index a5625c7..42b919d 100644 --- a/examples/fs/main.go +++ b/examples/fs/main.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/riza-io/mcp-go" + "github.com/riza-io/mcp-go/stdio" ) type FSServer struct { @@ -73,11 +74,11 @@ func main() { root = "/" } - server := mcp.NewStdioServer(&FSServer{ + server := mcp.NewServer(stdio.NewStream(os.Stdin, os.Stdout), &FSServer{ fs: os.DirFS(root), }) - if err := server.Listen(context.Background(), os.Stdin, os.Stdout); err != nil { + if err := server.Listen(context.Background()); err != nil { log.Fatal(err) } } diff --git a/examples/weather/main.go b/examples/weather/main.go index 7da0bf9..b04cdac 100644 --- a/examples/weather/main.go +++ b/examples/weather/main.go @@ -6,6 +6,7 @@ import ( "os" "github.com/riza-io/mcp-go" + "github.com/riza-io/mcp-go/stdio" ) type WeatherServer struct { @@ -32,12 +33,12 @@ func main() { log.Fatal("OPENWEATHER_API_KEY environment variable required") } - server := mcp.NewStdioServer(&WeatherServer{ + server := mcp.NewServer(stdio.NewStream(os.Stdin, os.Stdout), &WeatherServer{ defaultCity: "London", key: os.Getenv("OPENWEATHER_API_KEY"), }) - if err := server.Listen(ctx, os.Stdin, os.Stdout); err != nil { + if err := server.Listen(ctx); err != nil { log.Fatal(err) } } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index fc46109..98bba34 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -2,16 +2,22 @@ package endtoend import ( "context" + "encoding/json" "io" "testing" "github.com/riza-io/mcp-go" + "github.com/riza-io/mcp-go/stdio" ) type server struct { mcp.UnimplementedServer } +type client struct { + mcp.UnimplementedClient +} + func (s *server) Initialize(ctx context.Context, req *mcp.Request[mcp.InitializeRequest]) (*mcp.Response[mcp.InitializeResponse], error) { return mcp.NewResponse(&mcp.InitializeResponse{ ProtocolVersion: req.Params.ProtocolVersion, @@ -44,17 +50,24 @@ func TestEndToEnd(t *testing.T) { stdinr, stdinw := io.Pipe() stdoutr, stdoutw := io.Pipe() - client := mcp.NewStdioClient(stdinr, stdoutw) - srv := mcp.NewStdioServer(&server{}, mcp.WithInterceptors(loggingInterceptor)) + c := mcp.NewClient(stdio.NewStream(stdinr, stdoutw), &client{}, + mcp.WithInterceptors(loggingInterceptor)) + s := mcp.NewServer(stdio.NewStream(stdoutr, stdinw), &server{}) + + go func() { + if err := s.Listen(ctx); err != nil { + t.Fatalf("failed to listen: %v", err) + } + }() go func() { - if err := srv.Listen(ctx, stdoutr, stdinw); err != nil { + if err := c.Listen(ctx); err != nil { t.Fatalf("failed to listen: %v", err) } }() t.Run("initialize", func(t *testing.T) { - resp, err := client.Initialize(ctx, mcp.NewRequest(&mcp.InitializeRequest{ + resp, err := c.Initialize(ctx, mcp.NewRequest(&mcp.InitializeRequest{ ProtocolVersion: "1.0.0", })) if err != nil { @@ -65,15 +78,33 @@ func TestEndToEnd(t *testing.T) { } }) - t.Run("ping", func(t *testing.T) { - _, err := client.Ping(ctx, mcp.NewRequest(&mcp.PingRequest{})) + t.Run("client/ping", func(t *testing.T) { + _, err := c.Ping(ctx, mcp.NewRequest(&mcp.PingRequest{})) if err != nil { t.Fatalf("failed to ping server: %v", err) } }) + t.Run("server/ping", func(t *testing.T) { + _, err := s.Ping(ctx, mcp.NewRequest(&mcp.PingRequest{})) + if err != nil { + t.Fatalf("failed to ping client: %v", err) + } + }) + + t.Run("server/sendLogMessage", func(t *testing.T) { + err := s.LogMessage(ctx, mcp.NewRequest(&mcp.LogMessageRequest{ + Level: mcp.LevelInfo, + Logger: "test", + Data: json.RawMessage(`{"message": "test"}`), + })) + if err != nil { + t.Fatalf("failed to send log message: %v", err) + } + }) + t.Run("set log level", func(t *testing.T) { - _, err := client.SetLogLevel(ctx, mcp.NewRequest(&mcp.SetLogLevelRequest{ + _, err := c.SetLogLevel(ctx, mcp.NewRequest(&mcp.SetLogLevelRequest{ Level: mcp.LevelInfo, })) if err != nil { diff --git a/internal/transport/stdio/stdio.go b/internal/transport/stdio/stdio.go deleted file mode 100644 index 873c926..0000000 --- a/internal/transport/stdio/stdio.go +++ /dev/null @@ -1 +0,0 @@ -package stdio diff --git a/messages.go b/messages.go index da73f21..f0e1316 100644 --- a/messages.go +++ b/messages.go @@ -122,11 +122,11 @@ type GetPromptRequest struct { } type GetPromptResponse struct { - Description string `json:"description"` - Messages []Message `json:"messages"` + Description string `json:"description"` + Messages []PromptMessage `json:"messages"` } -type Message struct { +type PromptMessage struct { Role string `json:"role"` Content Content `json:"content"` } @@ -228,3 +228,17 @@ type SetLogLevelRequest struct { type SetLogLevelResponse struct { } + +type LogMessageRequest struct { + Level Level `json:"level"` + Logger string `json:"logger"` + Data json.RawMessage `json:"data"` +} + +type SamplingRequest struct { + MaxTokens int `json:"maxTokens"` +} + +type SamplingResponse struct { + Role string `json:"role"` +} diff --git a/notify.go b/notify.go new file mode 100644 index 0000000..8c90953 --- /dev/null +++ b/notify.go @@ -0,0 +1,48 @@ +package mcp + +import ( + "context" + "encoding/json" +) + +func notify[P any](ctx context.Context, c *base, method string, req *Request[P]) error { + var interceptor Interceptor + if len(c.interceptors) > 0 { + interceptor = newStack(c.interceptors) + } else { + interceptor = UnaryInterceptorFunc( + func(next UnaryFunc) UnaryFunc { + return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + return next(ctx, request) + }) + }, + ) + } + + inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + rawmsg, err := json.Marshal(req.Params) + if err != nil { + return nil, err + } + + msgVersion := "2.0" + msgParams := json.RawMessage(rawmsg) + + msg := &Message{ + JsonRPC: &msgVersion, + Method: &method, + Params: &msgParams, + } + + return nil, c.stream.Send(msg) + }) + + req.method = method + + _, err := interceptor.WrapUnary(inner)(ctx, req) + if err != nil { + return err + } + + return nil +} diff --git a/option.go b/option.go index 55bdc8a..f8dcf07 100644 --- a/option.go +++ b/option.go @@ -6,7 +6,7 @@ type Option interface { } type ClientOption interface { - applyToClient(c *StdioClient) + applyToClient(c *Client) } type ServerOption interface { @@ -21,7 +21,7 @@ type interceptorsOption struct { Interceptors []Interceptor } -func (o *interceptorsOption) applyToClient(c *StdioClient) { +func (o *interceptorsOption) applyToClient(c *Client) { c.interceptors = o.Interceptors } diff --git a/process.go b/process.go new file mode 100644 index 0000000..552b7db --- /dev/null +++ b/process.go @@ -0,0 +1,93 @@ +package mcp + +import ( + "context" + "encoding/json" +) + +type empty struct{} + +func noop[T any](method func(ctx context.Context, req *Request[T])) func(context.Context, *Request[T]) (*Response[empty], error) { + return func(ctx context.Context, req *Request[T]) (*Response[empty], error) { + method(ctx, req) + return NewResponse(&empty{}), nil + } +} + +func process[T, V any](ctx context.Context, cfg *base, msg *Message, method func(ctx context.Context, req *Request[T]) (*Response[V], error)) error { + var interceptor Interceptor + if len(cfg.interceptors) > 0 { + interceptor = newStack(cfg.interceptors) + } else { + interceptor = UnaryInterceptorFunc( + func(next UnaryFunc) UnaryFunc { + return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + return next(ctx, request) + }) + }, + ) + } + + var params T + + if msg.Params != nil && len(*msg.Params) > 0 { + if err := json.Unmarshal(*msg.Params, ¶ms); err != nil { + return err + } + } + + req := NewRequest(¶ms) + if msg.ID != nil { + req.id = msg.ID.String() + } + req.method = *msg.Method + + inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { + req := request.(*Request[T]) + resp, rerr := method(ctx, req) + if rerr != nil { + return nil, rerr + } + resp.id = req.id + return resp, nil + }) + + rr, err := interceptor.WrapUnary(inner)(ctx, req) + + // If the incoming message has no ID, we don't need to send a response + if msg.ID == nil { + return nil + } + + if err != nil { + return cfg.stream.Send(&Message{ + ID: msg.ID, + JsonRPC: msg.JsonRPC, + Error: &ErrorDetail{ + Code: 9, + Message: err.Error(), + }, + }) + } + + resp := rr.(*Response[V]) + + rawresult, err := json.Marshal(resp.Result) + if err != nil { + return cfg.stream.Send(&Message{ + ID: msg.ID, + JsonRPC: msg.JsonRPC, + Error: &ErrorDetail{ + Code: 9, + Message: err.Error(), + }, + }) + } + + rawmsg := json.RawMessage(rawresult) + return cfg.stream.Send(&Message{ + ID: msg.ID, + JsonRPC: msg.JsonRPC, + Result: &rawmsg, + }) +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..0453de7 --- /dev/null +++ b/router.go @@ -0,0 +1,35 @@ +package mcp + +import "sync" + +type router struct { + lock sync.Mutex + next uint64 + boxes map[uint64]chan *Message +} + +func newRouter() *router { + return &router{ + boxes: make(map[uint64]chan *Message), + } +} + +func (r *router) Add() (uint64, chan *Message) { + r.lock.Lock() + id := r.next + r.next++ + inbox := make(chan *Message, 1) + r.boxes[id] = inbox + r.lock.Unlock() + return id, inbox +} + +func (r *router) Remove(id uint64) (chan *Message, bool) { + r.lock.Lock() + inbox, ok := r.boxes[id] + if ok { + delete(r.boxes, id) + } + r.lock.Unlock() + return inbox, ok +} diff --git a/rpc.go b/rpc.go new file mode 100644 index 0000000..2cca0a5 --- /dev/null +++ b/rpc.go @@ -0,0 +1,23 @@ +package mcp + +import "encoding/json" + +type Stream interface { + Recv() (*Message, error) + Send(msg *Message) error +} + +type Message struct { + ID *json.Number `json:"id"` + JsonRPC *string `json:"jsonrpc"` + Method *string `json:"method"` + Params *json.RawMessage `json:"params"` + Result *json.RawMessage `json:"result,omitempty"` + Error *ErrorDetail `json:"error,omitempty"` +} + +type ErrorDetail struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` +} diff --git a/server.go b/server.go index 09c8174..c84bb58 100644 --- a/server.go +++ b/server.go @@ -1,17 +1,11 @@ package mcp import ( - "bufio" - "bytes" "context" - "encoding/json" "fmt" - "io" - - "github.com/riza-io/mcp-go/internal/jsonrpc" ) -type Server interface { +type ServerHandler interface { Initialize(ctx context.Context, req *Request[InitializeRequest]) (*Response[InitializeResponse], error) ListTools(ctx context.Context, req *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) CallTool(ctx context.Context, req *Request[CallToolRequest]) (*Response[CallToolResponse], error) @@ -71,166 +65,68 @@ func (s *UnimplementedServer) SetLogLevel(ctx context.Context, req *Request[SetL return nil, fmt.Errorf("unimplemented") } -func process[T, V any](ctx context.Context, cfg *serverConfig, msg jsonrpc.Request, params *T, method func(ctx context.Context, req *Request[T]) (*Response[V], error)) (any, error) { - var interceptor Interceptor - if len(cfg.interceptors) > 0 { - interceptor = newStack(cfg.interceptors) - } else { - interceptor = UnaryInterceptorFunc( - func(next UnaryFunc) UnaryFunc { - return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - return next(ctx, request) - }) - }, - ) - } - - if len(msg.Params) > 0 { - if err := json.Unmarshal(msg.Params, ¶ms); err != nil { - return nil, err - } - } - req := NewRequest(params) - req.id = msg.ID.String() - req.method = msg.Method - - inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { - req := request.(*Request[T]) - resp, rerr := method(ctx, req) - if rerr != nil { - return nil, rerr - } - resp.id = req.id - return resp, nil - }) - - rr, err := interceptor.WrapUnary(inner)(ctx, req) - if err != nil { - return nil, err - } - - resp := rr.(*Response[V]) - return resp.Result, nil - -} - type serverConfig struct { interceptors []Interceptor } -type StdioServer struct { - cfg *serverConfig - srv Server +type Server struct { + handler ServerHandler + base *base } -func NewStdioServer(srv Server, opts ...Option) *StdioServer { +func NewServer(stream Stream, handler ServerHandler, opts ...Option) *Server { cfg := &serverConfig{} for _, opt := range opts { opt.applyToServer(cfg) } - return &StdioServer{ - cfg: cfg, - srv: srv, + return &Server{ + handler: handler, + base: &base{ + router: newRouter(), + interceptors: cfg.interceptors, + stream: stream, + }, } } -func (s StdioServer) Listen(ctx context.Context, r io.Reader, w io.Writer) error { - scanner := bufio.NewScanner(r) - - for scanner.Scan() { - - // Recover from panics when processing a message - bs, err := s.processMessage(ctx, scanner.Bytes()) - if err == nil { - fmt.Fprintln(w, string(bs)) - } - } - - if err := scanner.Err(); err != nil { - return err - } - return nil +func (s *Server) Listen(ctx context.Context) error { + return s.base.listen(ctx, s.processMessage) } -func (s StdioServer) processMessage(ctx context.Context, line []byte) ([]byte, error) { - cfg := s.cfg - srv := s.srv - - dec := json.NewDecoder(bytes.NewReader(line)) - - var msg jsonrpc.Request - if err := dec.Decode(&msg); err != nil { - return nil, err - } +func (s *Server) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) { + return call[PingRequest, PingResponse](ctx, s.base, "ping", request) +} - var result any - var err error - code := 9 +func (s *Server) LogMessage(ctx context.Context, request *Request[LogMessageRequest]) error { + return notify[LogMessageRequest](ctx, s.base, "notifications/message", request) +} - switch msg.Method { +func (s *Server) processMessage(ctx context.Context, msg *Message) error { + h := s.handler + switch m := *msg.Method; m { case "initialize": - params := &InitializeRequest{} - result, err = process(ctx, cfg, msg, params, srv.Initialize) + return process(ctx, s.base, msg, h.Initialize) case "completion/complete": - params := &CompletionRequest{} - result, err = process(ctx, cfg, msg, params, srv.Completion) + return process(ctx, s.base, msg, h.Completion) case "tools/list": - params := &ListToolsRequest{} - result, err = process(ctx, cfg, msg, params, srv.ListTools) + return process(ctx, s.base, msg, h.ListTools) case "tools/call": - params := &CallToolRequest{} - result, err = process(ctx, cfg, msg, params, srv.CallTool) + return process(ctx, s.base, msg, h.CallTool) case "prompts/list": - params := &ListPromptsRequest{} - result, err = process(ctx, cfg, msg, params, srv.ListPrompts) + return process(ctx, s.base, msg, h.ListPrompts) case "prompts/get": - params := &GetPromptRequest{} - result, err = process(ctx, cfg, msg, params, srv.GetPrompt) + return process(ctx, s.base, msg, h.GetPrompt) case "resources/list": - params := &ListResourcesRequest{} - result, err = process(ctx, cfg, msg, params, srv.ListResources) + return process(ctx, s.base, msg, h.ListResources) case "resources/read": - params := &ReadResourceRequest{} - result, err = process(ctx, cfg, msg, params, srv.ReadResource) + return process(ctx, s.base, msg, h.ReadResource) case "resources/templates/list": - params := &ListResourceTemplatesRequest{} - result, err = process(ctx, cfg, msg, params, srv.ListResourceTemplates) + return process(ctx, s.base, msg, h.ListResourceTemplates) case "ping": - params := &PingRequest{} - result, err = process(ctx, cfg, msg, params, srv.Ping) + return process(ctx, s.base, msg, h.Ping) case "logging/setLevel": - params := &SetLogLevelRequest{} - result, err = process(ctx, cfg, msg, params, srv.SetLogLevel) + return process(ctx, s.base, msg, h.SetLogLevel) default: - if msg.ID == "" { - // Ignore notifications - return nil, nil - } - code = -32601 - err = fmt.Errorf("unknown method: %s", msg.Method) + return fmt.Errorf("unknown method: %s", m) } - - var resp jsonrpc.Response - if err != nil { - resp = jsonrpc.Response{ - ID: msg.ID, - JsonRPC: msg.JsonRPC, - Error: &jsonrpc.ErrorDetail{ - Code: code, - Message: err.Error(), - }, - } - } else { - rawresult, err := json.Marshal(result) - if err != nil { - return nil, err - } - resp = jsonrpc.Response{ - ID: msg.ID, - JsonRPC: msg.JsonRPC, - Result: rawresult, - } - } - - return json.Marshal(resp) } diff --git a/stdio/stdio.go b/stdio/stdio.go new file mode 100644 index 0000000..36c58fd --- /dev/null +++ b/stdio/stdio.go @@ -0,0 +1,48 @@ +package stdio + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "sync" + + "github.com/riza-io/mcp-go" +) + +type Stream struct { + rlock sync.Mutex + scan *bufio.Scanner + w io.Writer + wlock sync.Mutex +} + +func NewStream(r io.Reader, w io.Writer) *Stream { + return &Stream{ + scan: bufio.NewScanner(r), + w: w, + } +} + +func (s *Stream) Recv() (*mcp.Message, error) { + if !s.scan.Scan() { + return nil, s.scan.Err() + } + line := s.scan.Bytes() + var msg mcp.Message + if err := json.Unmarshal(line, &msg); err != nil { + return nil, err + } + return &msg, nil +} + +func (s *Stream) Send(msg *mcp.Message) error { + bs, err := json.Marshal(msg) + if err != nil { + return err + } + s.wlock.Lock() + _, err = fmt.Fprintln(s.w, string(bs)) + s.wlock.Unlock() + return err +}