Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support server to client requests #11

Merged
merged 5 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package mcp

import (
"context"
"strconv"
)

type base struct {
router *router
stream Stream
interceptors []Interceptor
}

func (b *base) listen(ctx context.Context, handler func(ctx context.Context, msg *Message) error) error {
for {
msg, err := b.stream.Recv()
if err != nil {
return err
}
if msg.Method != nil {
go func() {
handler(ctx, msg)
}()
} else {
id, err := strconv.ParseUint(msg.ID.String(), 10, 64)
if err != nil {
continue
}
if inbox, ok := b.router.Remove(id); ok {
inbox <- msg
}
}
}
}
77 changes: 77 additions & 0 deletions call.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package mcp

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
)

func call[P any, R any](ctx context.Context, c *base, method string, req *Request[P]) (*Response[R], error) {
id, inbox := c.router.Add()

var interceptor Interceptor
if len(c.interceptors) > 0 {
interceptor = newStack(c.interceptors)
} else {
interceptor = UnaryInterceptorFunc(
func(next UnaryFunc) UnaryFunc {
return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
return next(ctx, request)
})
},
)
}

inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
rawmsg, err := json.Marshal(req.Params)
if err != nil {
return nil, err
}

msgID := json.Number(request.ID())
msgVersion := "2.0"
msgParams := json.RawMessage(rawmsg)

msg := &Message{
ID: &msgID,
JsonRPC: &msgVersion,
Method: &method,
Params: &msgParams,
}

if err := c.stream.Send(msg); err != nil {
return nil, err
}

var result R

select {
case resp := <-inbox:
if resp.Error != nil {
return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message))
}
if resp.Result == nil {
return nil, fmt.Errorf("no result")
}
if err := json.Unmarshal(*resp.Result, &result); err != nil {
return nil, err
}
case <-ctx.Done():
return nil, ctx.Err()
}

return NewResponse(&result), nil
})

req.id = strconv.FormatUint(id, 10)
req.method = method

resp, err := interceptor.WrapUnary(inner)(ctx, req)
if err != nil {
return nil, err
}

return resp.(*Response[R]), nil
}
201 changes: 65 additions & 136 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,178 +1,107 @@
package mcp

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"

"github.com/riza-io/mcp-go/internal/jsonrpc"
)

type Client interface {
Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error)
ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error)
ListTools(ctx context.Context, req *Request[ListToolsRequest]) (*Response[ListToolsResponse], error)
CallTool(ctx context.Context, req *Request[CallToolRequest]) (*Response[CallToolResponse], error)
ListPrompts(ctx context.Context, req *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error)
GetPrompt(ctx context.Context, req *Request[GetPromptRequest]) (*Response[GetPromptResponse], error)
ReadResource(ctx context.Context, req *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error)
ListResourceTemplates(ctx context.Context, req *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error)
Completion(ctx context.Context, req *Request[CompletionRequest]) (*Response[CompletionResponse], error)
Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error)
SetLogLevel(ctx context.Context, req *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error)
}

type StdioClient struct {
in io.Reader
out io.Writer
scanner *bufio.Scanner
next int
lock sync.Mutex
type ClientHandler interface {
Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error)
Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error)
LogMessage(ctx context.Context, request *Request[LogMessageRequest])
}

type UnimplementedClient struct{}

func (u *UnimplementedClient) Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error) {
return nil, fmt.Errorf("not implemented")
}

func (u *UnimplementedClient) LogMessage(ctx context.Context, request *Request[LogMessageRequest]) {
}

func (c *UnimplementedClient) Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error) {
return NewResponse(&PingResponse{}), nil
}

type Client struct {
handler ClientHandler
interceptors []Interceptor
base *base
}

func NewStdioClient(stdin io.Reader, stdout io.Writer, opts ...Option) Client {
c := &StdioClient{
in: stdin,
out: stdout,
scanner: bufio.NewScanner(stdin),
func NewClient(stream Stream, handler ClientHandler, opts ...Option) *Client {
c := &Client{
handler: handler,
}

for _, opt := range opts {
opt.applyToClient(c)
}

c.base = &base{
router: newRouter(),
interceptors: c.interceptors,
stream: stream,
}
return c
}

func clientCallUnary[P any, R any](ctx context.Context, c *StdioClient, method string, req *Request[P]) (*Response[R], error) {
// Ensure that we are not sending multiple requests at the same time
c.lock.Lock()
defer c.lock.Unlock()

defer func() {
// Increment the ID counter
c.next++
}()

var interceptor Interceptor
if len(c.interceptors) > 0 {
interceptor = newStack(c.interceptors)
} else {
interceptor = UnaryInterceptorFunc(
func(next UnaryFunc) UnaryFunc {
return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
return next(ctx, request)
})
},
)
}

inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
rawmsg, err := json.Marshal(req.Params)
if err != nil {
return nil, err
}

msg := jsonrpc.Request{
ID: json.Number(request.ID()),
JsonRPC: "2.0",
Method: request.Method(),
Params: json.RawMessage(rawmsg),
}

bs, err := json.Marshal(msg)
if err != nil {
return nil, err
}

fmt.Fprintln(c.out, string(bs))

var result R

for c.scanner.Scan() {
line := c.scanner.Bytes()

var resp jsonrpc.Response

if err := json.Unmarshal(line, &resp); err != nil {
return nil, err
}

if resp.Error != nil {
return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message))
}

if err := json.Unmarshal(resp.Result, &result); err != nil {
return nil, err
}

break
}

if err := c.scanner.Err(); err != nil {
return nil, err
}

return NewResponse(&result), nil
})

req.id = strconv.Itoa(c.next)
req.method = method
// sync.Once?
func (c *Client) Listen(ctx context.Context) error {
return c.base.listen(ctx, c.processMessage)
}

resp, err := interceptor.WrapUnary(inner)(ctx, req)
if err != nil {
return nil, err
func (c *Client) processMessage(ctx context.Context, msg *Message) error {
srv := c.handler
switch m := *msg.Method; m {
case "ping":
return process(ctx, c.base, msg, srv.Ping)
case "notifications/message":
return process(ctx, c.base, msg, noop(srv.LogMessage))
default:
return fmt.Errorf("unknown method: %s", m)
}

return resp.(*Response[R]), nil
}

func (c *StdioClient) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) {
return clientCallUnary[InitializeRequest, InitializeResponse](ctx, c, "initialize", request)
func (c *Client) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) {
return call[InitializeRequest, InitializeResponse](ctx, c.base, "initialize", request)
}

func (c *StdioClient) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) {
return clientCallUnary[ListResourcesRequest, ListResourcesResponse](ctx, c, "resources/list", request)
func (c *Client) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) {
return call[ListResourcesRequest, ListResourcesResponse](ctx, c.base, "resources/list", request)
}

func (c *StdioClient) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) {
return clientCallUnary[ListToolsRequest, ListToolsResponse](ctx, c, "tools/list", request)
func (c *Client) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) {
return call[ListToolsRequest, ListToolsResponse](ctx, c.base, "tools/list", request)
}

func (c *StdioClient) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) {
return clientCallUnary[CallToolRequest, CallToolResponse](ctx, c, "tools/call", request)
func (c *Client) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) {
return call[CallToolRequest, CallToolResponse](ctx, c.base, "tools/call", request)
}

func (c *StdioClient) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) {
return clientCallUnary[ListPromptsRequest, ListPromptsResponse](ctx, c, "prompts/list", request)
func (c *Client) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) {
return call[ListPromptsRequest, ListPromptsResponse](ctx, c.base, "prompts/list", request)
}

func (c *StdioClient) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) {
return clientCallUnary[GetPromptRequest, GetPromptResponse](ctx, c, "prompts/get", request)
func (c *Client) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) {
return call[GetPromptRequest, GetPromptResponse](ctx, c.base, "prompts/get", request)
}

func (c *StdioClient) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) {
return clientCallUnary[ReadResourceRequest, ReadResourceResponse](ctx, c, "resources/read", request)
func (c *Client) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) {
return call[ReadResourceRequest, ReadResourceResponse](ctx, c.base, "resources/read", request)
}

func (c *StdioClient) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) {
return clientCallUnary[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c, "resources/templates/list", request)
func (c *Client) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) {
return call[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c.base, "resources/templates/list", request)
}

func (c *StdioClient) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) {
return clientCallUnary[CompletionRequest, CompletionResponse](ctx, c, "completion", request)
func (c *Client) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) {
return call[CompletionRequest, CompletionResponse](ctx, c.base, "completion", request)
}

func (c *StdioClient) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) {
return clientCallUnary[PingRequest, PingResponse](ctx, c, "ping", request)
func (c *Client) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) {
return call[PingRequest, PingResponse](ctx, c.base, "ping", request)
}

func (c *StdioClient) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) {
return clientCallUnary[SetLogLevelRequest, SetLogLevelResponse](ctx, c, "logging/setLevel", request)
func (c *Client) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) {
return call[SetLogLevelRequest, SetLogLevelResponse](ctx, c.base, "logging/setLevel", request)
}
5 changes: 3 additions & 2 deletions examples/fs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/riza-io/mcp-go"
"github.com/riza-io/mcp-go/stdio"
)

type FSServer struct {
Expand Down Expand Up @@ -73,11 +74,11 @@ func main() {
root = "/"
}

server := mcp.NewStdioServer(&FSServer{
server := mcp.NewServer(stdio.NewStream(os.Stdin, os.Stdout), &FSServer{
fs: os.DirFS(root),
})

if err := server.Listen(context.Background(), os.Stdin, os.Stdout); err != nil {
if err := server.Listen(context.Background()); err != nil {
log.Fatal(err)
}
}
Loading
Loading