diff --git a/client/client.go b/client/client.go index 7689633c..dd0e31a0 100644 --- a/client/client.go +++ b/client/client.go @@ -104,7 +104,7 @@ func (c *Client) sendRequest( request := transport.JSONRPCRequest{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Method: method, Params: params, } diff --git a/client/transport/interface.go b/client/transport/interface.go index 8ac75d74..2fba4abf 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -27,15 +27,15 @@ type Interface interface { } type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` - Method string `json:"method"` - Params any `json:"params,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` } type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id"` + ID mcp.RequestId `json:"id"` Result json.RawMessage `json:"result"` Error *struct { Code int `json:"code"` diff --git a/client/transport/sse.go b/client/transport/sse.go index eda9446e..24c5ce35 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -25,7 +25,7 @@ type SSE struct { baseURL *url.URL endpoint *url.URL httpClient *http.Client - responses map[int64]chan *JSONRPCResponse + responses map[string]chan *JSONRPCResponse mu sync.RWMutex onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -62,7 +62,7 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { smc := &SSE{ baseURL: parsedURL, httpClient: &http.Client{}, - responses: make(map[int64]chan *JSONRPCResponse), + responses: make(map[string]chan *JSONRPCResponse), endpointChan: make(chan struct{}), headers: make(map[string]string), } @@ -200,7 +200,7 @@ func (c *SSE) handleSSEEvent(event, data string) { } // Handle notification - if baseMessage.ID == nil { + if baseMessage.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { return @@ -213,14 +213,17 @@ func (c *SSE) handleSSEEvent(event, data string) { return } + // Create string key for map lookup + idKey := baseMessage.ID.String() + c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] + ch, exists := c.responses[idKey] c.mu.RUnlock() - if ok { + if exists { ch <- &baseMessage c.mu.Lock() - delete(c.responses, *baseMessage.ID) + delete(c.responses, idKey) c.mu.Unlock() } } @@ -267,14 +270,17 @@ func (c *SSE) SendRequest( req.Header.Set(k, v) } + // Create string key for map lookup + idKey := request.ID.String() + // Register response channel responseChan := make(chan *JSONRPCResponse, 1) c.mu.Lock() - c.responses[request.ID] = responseChan + c.responses[idKey] = responseChan c.mu.Unlock() deleteResponseChan := func() { c.mu.Lock() - delete(c.responses, request.ID) + delete(c.responses, idKey) c.mu.Unlock() } @@ -327,7 +333,7 @@ func (c *SSE) Close() error { for _, ch := range c.responses { close(ch) } - c.responses = make(map[int64]chan *JSONRPCResponse) + c.responses = make(map[string]chan *JSONRPCResponse) c.mu.Unlock() return nil diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index 230157d2..82074b11 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -160,7 +160,7 @@ func TestSSE(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo", Params: params, } @@ -174,7 +174,7 @@ func TestSSE(t *testing.T) { // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -187,8 +187,11 @@ func TestSSE(t *testing.T) { if result.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) } - if result.ID != 1 { - t.Errorf("Expected ID 1, got %d", result.ID) + idValue, ok := result.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", result.ID.Value()) + } else if idValue != 1 { + t.Errorf("Expected ID 1, got %d", idValue) } if result.Method != "debug/echo" { t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) @@ -211,7 +214,7 @@ func TestSSE(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 3, + ID: mcp.NewRequestId(int64(3)), Method: "debug/echo", } @@ -292,7 +295,7 @@ func TestSSE(t *testing.T) { // Each request has a unique ID and payload request := JSONRPCRequest{ JSONRPC: "2.0", - ID: int64(100 + idx), + ID: mcp.NewRequestId(int64(100 + idx)), Method: "debug/echo", Params: map[string]any{ "requestIndex": idx, @@ -317,15 +320,25 @@ func TestSSE(t *testing.T) { continue } - if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { - t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + if responses[i] == nil { + t.Errorf("Request %d: Response is nil", i) + continue + } + + expectedId := int64(100 + i) + idValue, ok := responses[i].ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, responses[i].ID.Value()) + continue + } else if idValue != expectedId { + t.Errorf("Request %d: Expected ID %d, got %d", i, expectedId, idValue) continue } // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -336,8 +349,11 @@ func TestSSE(t *testing.T) { } // Verify data matches what was sent - if result.ID != int64(100+i) { - t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + idValue, ok = result.ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, result.ID.Value()) + } else if idValue != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, idValue) } if result.Method != "debug/echo" { @@ -356,7 +372,7 @@ func TestSSE(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 100, + ID: mcp.NewRequestId(int64(100)), Method: "debug/echo_error_string", } @@ -378,8 +394,11 @@ func TestSSE(t *testing.T) { if responseError.Method != "debug/echo_error_string" { t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) } - if responseError.ID != 100 { - t.Errorf("Expected ID 100, got %d", responseError.ID) + idValue, ok := responseError.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", responseError.ID.Value()) + } else if idValue != 100 { + t.Errorf("Expected ID 100, got %d", idValue) } if responseError.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) @@ -453,7 +472,7 @@ func TestSSEErrors(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 99, + ID: mcp.NewRequestId(int64(99)), Method: "ping", } @@ -492,7 +511,7 @@ func TestSSEErrors(t *testing.T) { // Try to send a request after close request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "ping", } diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 3d9d832a..c300c405 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -26,7 +26,7 @@ type Stdio struct { stdin io.WriteCloser stdout *bufio.Reader stderr io.ReadCloser - responses map[int64]chan *JSONRPCResponse + responses map[string]chan *JSONRPCResponse mu sync.RWMutex done chan struct{} onNotification func(mcp.JSONRPCNotification) @@ -42,7 +42,7 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio stdout: bufio.NewReader(input), stderr: logging, - responses: make(map[int64]chan *JSONRPCResponse), + responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), } } @@ -61,7 +61,7 @@ func NewStdio( args: args, env: env, - responses: make(map[int64]chan *JSONRPCResponse), + responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), } @@ -181,7 +181,7 @@ func (c *Stdio) readResponses() { } // Handle notification - if baseMessage.ID == nil { + if baseMessage.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(line), ¬ification); err != nil { continue @@ -194,14 +194,17 @@ func (c *Stdio) readResponses() { continue } + // Create string key for map lookup + idKey := baseMessage.ID.String() + c.mu.RLock() - ch, ok := c.responses[*baseMessage.ID] + ch, exists := c.responses[idKey] c.mu.RUnlock() - if ok { + if exists { ch <- &baseMessage c.mu.Lock() - delete(c.responses, *baseMessage.ID) + delete(c.responses, idKey) c.mu.Unlock() } } @@ -227,14 +230,17 @@ func (c *Stdio) SendRequest( } requestBytes = append(requestBytes, '\n') + // Create string key for map lookup + idKey := request.ID.String() + // Register response channel responseChan := make(chan *JSONRPCResponse, 1) c.mu.Lock() - c.responses[request.ID] = responseChan + c.responses[idKey] = responseChan c.mu.Unlock() deleteResponseChan := func() { c.mu.Lock() - delete(c.responses, request.ID) + delete(c.responses, idKey) c.mu.Unlock() } diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 155859e1..3eea5b23 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -70,7 +70,7 @@ func TestStdio(t *testing.T) { defer stdio.Close() t.Run("SendRequest", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5000000000*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() params := map[string]any{ @@ -80,7 +80,7 @@ func TestStdio(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo", Params: params, } @@ -94,7 +94,7 @@ func TestStdio(t *testing.T) { // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -107,8 +107,11 @@ func TestStdio(t *testing.T) { if result.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) } - if result.ID != 1 { - t.Errorf("Expected ID 1, got %d", result.ID) + idValue, ok := result.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", result.ID.Value()) + } else if idValue != 1 { + t.Errorf("Expected ID 1, got %d", idValue) } if result.Method != "debug/echo" { t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) @@ -131,7 +134,7 @@ func TestStdio(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 3, + ID: mcp.NewRequestId(int64(3)), Method: "debug/echo", } @@ -211,7 +214,7 @@ func TestStdio(t *testing.T) { // Each request has a unique ID and payload request := JSONRPCRequest{ JSONRPC: "2.0", - ID: int64(100 + idx), + ID: mcp.NewRequestId(int64(100 + idx)), Method: "debug/echo", Params: map[string]any{ "requestIndex": idx, @@ -236,15 +239,25 @@ func TestStdio(t *testing.T) { continue } - if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { - t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + if responses[i] == nil { + t.Errorf("Request %d: Response is nil", i) + continue + } + + expectedId := int64(100 + i) + idValue, ok := responses[i].ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, responses[i].ID.Value()) + continue + } else if idValue != expectedId { + t.Errorf("Request %d: Expected ID %d, got %d", i, expectedId, idValue) continue } // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -255,8 +268,11 @@ func TestStdio(t *testing.T) { } // Verify data matches what was sent - if result.ID != int64(100+i) { - t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + idValue, ok = result.ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, result.ID.Value()) + } else if idValue != int64(100+i) { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, idValue) } if result.Method != "debug/echo" { @@ -271,11 +287,10 @@ func TestStdio(t *testing.T) { }) t.Run("ResponseError", func(t *testing.T) { - // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 100, + ID: mcp.NewRequestId(int64(100)), Method: "debug/echo_error_string", } @@ -297,14 +312,75 @@ func TestStdio(t *testing.T) { if responseError.Method != "debug/echo_error_string" { t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) } - if responseError.ID != 100 { - t.Errorf("Expected ID 100, got %d", responseError.ID) + idValue, ok := responseError.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", responseError.ID.Value()) + } else if idValue != 100 { + t.Errorf("Expected ID 100, got %d", idValue) } if responseError.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) } }) + t.Run("SendRequestWithStringID", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + params := map[string]any{ + "string": "string id test", + "array": []any{4, 5, 6}, + } + + // Use a string ID instead of an integer + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId("request-123"), + Method: "debug/echo", + Params: params, + } + + response, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed: %v", err) + } + + var result struct { + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result: %v", err) + } + + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + + // Verify the ID is a string and has the expected value + idValue, ok := result.ID.Value().(string) + if !ok { + t.Errorf("Expected ID to be string, got %T", result.ID.Value()) + } else if idValue != "request-123" { + t.Errorf("Expected ID 'request-123', got '%s'", idValue) + } + + if result.Method != "debug/echo" { + t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) + } + + if str, ok := result.Params["string"].(string); !ok || str != "string id test" { + t.Errorf("Expected string 'string id test', got %v", result.Params["string"]) + } + + if arr, ok := result.Params["array"].([]any); !ok || len(arr) != 3 { + t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) + } + }) + } func TestStdioErrors(t *testing.T) { @@ -346,7 +422,7 @@ func TestStdioErrors(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 99, + ID: mcp.NewRequestId(int64(99)), Method: "ping", } @@ -398,7 +474,7 @@ func TestStdioErrors(t *testing.T) { // Try to send a request after close request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "ping", } diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 98719bd0..34677031 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -217,7 +217,7 @@ func (c *StreamableHTTP) SendRequest( } // should not be a notification - if response.ID == nil { + if response.ID.IsNil() { return nil, fmt.Errorf("response should contain RPC id: %v", response) } @@ -258,7 +258,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl } // Handle notification - if message.ID == nil { + if message.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { fmt.Printf("failed to unmarshal notification: %v\n", err) diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index addddd20..de3cddff 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -147,7 +147,7 @@ func TestStreamableHTTP(t *testing.T) { initRequest := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(0)), Method: "initialize", } @@ -168,7 +168,7 @@ func TestStreamableHTTP(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo", Params: params, } @@ -182,7 +182,7 @@ func TestStreamableHTTP(t *testing.T) { // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -195,8 +195,11 @@ func TestStreamableHTTP(t *testing.T) { if result.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) } - if result.ID != 1 { - t.Errorf("Expected ID 1, got %d", result.ID) + idValue, ok := result.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", result.ID.Value()) + } else if idValue != 1 { + t.Errorf("Expected ID 1, got %d", idValue) } if result.Method != "debug/echo" { t.Errorf("Expected method 'debug/echo', got '%s'", result.Method) @@ -219,7 +222,7 @@ func TestStreamableHTTP(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 3, + ID: mcp.NewRequestId(int64(3)), Method: "debug/echo", } @@ -247,7 +250,7 @@ func TestStreamableHTTP(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "debug/echo_notification", } @@ -266,7 +269,7 @@ func TestStreamableHTTP(t *testing.T) { if got == nil { t.Errorf("Notification handler did not send the expected notification: got nil") } - if int64(got["id"].(float64)) != request.ID || + if int64(got["id"].(float64)) != request.ID.Value().(int64) || got["jsonrpc"] != request.JSONRPC || got["method"] != request.Method { @@ -302,7 +305,7 @@ func TestStreamableHTTP(t *testing.T) { // Each request has a unique ID and payload request := JSONRPCRequest{ JSONRPC: "2.0", - ID: int64(100 + idx), + ID: mcp.NewRequestId(int64(100 + idx)), Method: "debug/echo", Params: map[string]any{ "requestIndex": idx, @@ -327,15 +330,25 @@ func TestStreamableHTTP(t *testing.T) { continue } - if responses[i] == nil || responses[i].ID == nil || *responses[i].ID != int64(100+i) { - t.Errorf("Request %d: Expected ID %d, got %v", i, 100+i, responses[i]) + if responses[i] == nil { + t.Errorf("Request %d: Response is nil", i) + continue + } + + expectedId := int64(100 + i) + idValue, ok := responses[i].ID.Value().(int64) + if !ok { + t.Errorf("Request %d: Expected ID to be int64, got %T", i, responses[i].ID.Value()) + continue + } else if idValue != expectedId { + t.Errorf("Request %d: Expected ID %d, got %d", i, expectedId, idValue) continue } // Parse the result to verify echo var result struct { JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` + ID mcp.RequestId `json:"id"` Method string `json:"method"` Params map[string]any `json:"params"` } @@ -346,8 +359,8 @@ func TestStreamableHTTP(t *testing.T) { } // Verify data matches what was sent - if result.ID != int64(100+i) { - t.Errorf("Request %d: Expected echoed ID %d, got %d", i, 100+i, result.ID) + if result.ID.Value().(int64) != expectedId { + t.Errorf("Request %d: Expected echoed ID %d, got %d", i, expectedId, result.ID.Value().(int64)) } if result.Method != "debug/echo" { @@ -368,7 +381,7 @@ func TestStreamableHTTP(t *testing.T) { // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 100, + ID: mcp.NewRequestId(int64(100)), Method: "debug/echo_error_string", } @@ -390,8 +403,11 @@ func TestStreamableHTTP(t *testing.T) { if responseError.Method != "debug/echo_error_string" { t.Errorf("Expected method 'debug/echo_error_string', got '%s'", responseError.Method) } - if responseError.ID != 100 { - t.Errorf("Expected ID 100, got %d", responseError.ID) + idValue, ok := responseError.ID.Value().(int64) + if !ok { + t.Errorf("Expected ID to be int64, got %T", responseError.ID.Value()) + } else if idValue != 100 { + t.Errorf("Expected ID 100, got %d", idValue) } if responseError.JSONRPC != "2.0" { t.Errorf("Expected JSONRPC '2.0', got '%s'", responseError.JSONRPC) @@ -421,7 +437,7 @@ func TestStreamableHTTPErrors(t *testing.T) { request := JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Method: "initialize", } diff --git a/mcp/types.go b/mcp/types.go index e7fdb6f0..d086ac90 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -4,6 +4,8 @@ package mcp import ( "encoding/json" + "fmt" + "strconv" "maps" "github.com/yosida95/uritemplate/v3" @@ -222,7 +224,75 @@ type Result struct { // RequestId is a uniquely identifying ID for a request in JSON-RPC. // It can be any JSON-serializable value, typically a number or string. -type RequestId any +type RequestId struct { + value any +} + +// NewRequestId creates a new RequestId with the given value +func NewRequestId(value any) RequestId { + return RequestId{value: value} +} + +// Value returns the underlying value of the RequestId +func (r RequestId) Value() any { + return r.value +} + +// String returns a string representation of the RequestId +func (r RequestId) String() string { + switch v := r.value.(type) { + case string: + return "string:" + v + case int64: + return "int64:" + strconv.FormatInt(v, 10) + case float64: + if v == float64(int64(v)) { + return "int64:" + strconv.FormatInt(int64(v), 10) + } + return "float64:" + strconv.FormatFloat(v, 'f', -1, 64) + case nil: + return "" + default: + return "unknown:" + fmt.Sprintf("%v", v) + } +} + +// IsNil returns true if the RequestId is nil +func (r RequestId) IsNil() bool { + return r.value == nil +} + +func (r RequestId) MarshalJSON() ([]byte, error) { + return json.Marshal(r.value) +} + +func (r *RequestId) UnmarshalJSON(data []byte) error { + + if string(data) == "null" { + r.value = nil + return nil + } + + // Try unmarshaling as string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + r.value = s + return nil + } + + // JSON numbers are unmarshaled as float64 in Go + var f float64 + if err := json.Unmarshal(data, &f); err == nil { + if f == float64(int64(f)) { + r.value = int64(f) + } else { + r.value = f + } + return nil + } + + return fmt.Errorf("invalid request id: %s", string(data)) +} // JSONRPCRequest represents a request that expects a response. type JSONRPCRequest struct { diff --git a/server/server.go b/server/server.go index b31b4865..e5b48a5e 100644 --- a/server/server.go +++ b/server/server.go @@ -101,7 +101,7 @@ func (e *requestError) Error() string { func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, - ID: e.id, + ID: mcp.NewRequestId(e.id), Error: struct { Code int `json:"code"` Message string `json:"message"` @@ -937,7 +937,7 @@ func (s *MCPServer) handleNotification( func createResponse(id any, result any) mcp.JSONRPCMessage { return mcp.JSONRPCResponse{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Result: result, } } @@ -949,7 +949,7 @@ func createErrorResponse( ) mcp.JSONRPCMessage { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, - ID: id, + ID: mcp.NewRequestId(id), Error: struct { Code int `json:"code"` Message string `json:"message"` diff --git a/server/server_test.go b/server/server_test.go index 4615b0fb..5c2bff4e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -134,7 +134,7 @@ func TestMCPServer_Capabilities(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", tt.options...) message := mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "initialize", }, @@ -388,7 +388,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { name: "Initialize request", message: mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "initialize", }, @@ -413,7 +413,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { name: "Ping request", message: mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "ping", }, @@ -430,7 +430,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { name: "List resources", message: mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "resources/list", }, @@ -1127,7 +1127,7 @@ func TestMCPServer_Instructions(t *testing.T) { message := mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: 1, + ID: mcp.NewRequestId(int64(1)), Request: mcp.Request{ Method: "initialize", }, diff --git a/server/sse.go b/server/sse.go index 630927d1..d51a8979 100644 --- a/server/sse.go +++ b/server/sse.go @@ -338,7 +338,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { case <-ticker.C: message := mcp.JSONRPCRequest{ JSONRPC: "2.0", - ID: session.requestID.Add(1), + ID: mcp.NewRequestId(session.requestID.Add(1)), Request: mcp.Request{ Method: "ping", }, diff --git a/server/sse_test.go b/server/sse_test.go index aebf69d0..62bd616b 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -810,7 +810,16 @@ func TestSSEServer(t *testing.T) { } if pingMsg.Method == "ping" { - pingID = pingMsg.ID.(float64) + idValue, ok := pingMsg.ID.Value().(int64) + if ok { + pingID = float64(idValue) + } else { + floatValue, ok := pingMsg.ID.Value().(float64) + if !ok { + t.Fatalf("Expected ping ID to be number, got %T: %v", pingMsg.ID.Value(), pingMsg.ID.Value()) + } + pingID = floatValue + } t.Logf("Received ping with ID: %f", pingID) break // We got the ping, exit the loop } diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index 63f7835d..f561285e 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -6,19 +6,21 @@ import ( "fmt" "log/slog" "os" + + "github.com/mark3labs/mcp-go/mcp" ) type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` + ID *mcp.RequestId `json:"id,omitempty"` Method string `json:"method"` Params json.RawMessage `json:"params"` } type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID *int64 `json:"id,omitempty"` - Result any `json:"result,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Result any `json:"result,omitempty"` Error *struct { Code int `json:"code"` Message string `json:"message"`