diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 4e139e17..5c69f5fa 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -24,6 +24,7 @@ func (s *MCPServer) HandleMessage( JSONRPC string `json:"jsonrpc"` Method mcp.MCPMethod `json:"method"` ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` } if err := json.Unmarshal(message, &baseMessage); err != nil { @@ -56,6 +57,12 @@ func (s *MCPServer) HandleMessage( return nil // Return nil for notifications } + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + switch baseMessage.Method { {{- range .}} case mcp.{{.MethodName}}: diff --git a/server/request_handler.go b/server/request_handler.go index 946ca7ab..55d2d19e 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -23,6 +23,7 @@ func (s *MCPServer) HandleMessage( JSONRPC string `json:"jsonrpc"` Method mcp.MCPMethod `json:"method"` ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` } if err := json.Unmarshal(message, &baseMessage); err != nil { @@ -55,6 +56,12 @@ func (s *MCPServer) HandleMessage( return nil // Return nil for notifications } + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + switch baseMessage.Method { case mcp.MethodInitialize: var request mcp.InitializeRequest diff --git a/server/sse.go b/server/sse.go index f69451c6..20405965 100644 --- a/server/sse.go +++ b/server/sse.go @@ -23,6 +23,7 @@ type sseSession struct { done chan struct{} eventQueue chan string // Channel for queuing events sessionID string + requestID atomic.Int64 notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool } @@ -282,8 +283,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { for { select { case <-ticker.C: - //: ping - 2025-03-27 07:44:38.682659+00:00 - session.eventQueue <- fmt.Sprintf(":ping - %s\n\n", time.Now().Format(time.RFC3339)) + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: session.requestID.Add(1), + Request: mcp.Request{ + Method: "ping", + }, + } + messageBytes, _ := json.Marshal(message) + pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) + session.eventQueue <- pingMsg case <-session.done: return case <-r.Context().Done(): diff --git a/server/sse_test.go b/server/sse_test.go index 111c5845..8e0f9264 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1,9 +1,11 @@ package server import ( + "bufio" "bytes" "context" "encoding/json" + "io" "fmt" "math/rand" "net/http" @@ -739,4 +741,117 @@ func TestSSEServer(t *testing.T) { } } }) + + t.Run("Client receives and can respond to ping messages", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + testServer := NewTestServer(mcpServer, + WithKeepAlive(true), + WithKeepAliveInterval(50*time.Millisecond), + ) + defer testServer.Close() + + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + reader := bufio.NewReader(sseResp.Body) + + var messageURL string + var pingID float64 + + for { + line, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read SSE event: %v", err) + } + + if strings.HasPrefix(line, "event: endpoint") { + dataLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read endpoint data: %v", err) + } + messageURL = strings.TrimSpace(strings.TrimPrefix(dataLine, "data: ")) + + _, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read blank line: %v", err) + } + } + + if strings.HasPrefix(line, "event: message") { + dataLine, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read message data: %v", err) + } + + pingData := strings.TrimSpace(strings.TrimPrefix(dataLine, "data:")) + var pingMsg mcp.JSONRPCRequest + if err := json.Unmarshal([]byte(pingData), &pingMsg); err != nil { + t.Fatalf("Failed to parse ping message: %v", err) + } + + if pingMsg.Method == "ping" { + pingID = pingMsg.ID.(float64) + t.Logf("Received ping with ID: %f", pingID) + break // We got the ping, exit the loop + } + + _, err = reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read blank line: %v", err) + } + } + + if messageURL != "" && pingID != 0 { + break + } + } + + if messageURL == "" { + t.Fatal("Did not receive message endpoint URL") + } + + pingResponse := map[string]any{ + "jsonrpc": "2.0", + "id": pingID, + "result": map[string]any{}, + } + + requestBody, err := json.Marshal(pingResponse) + if err != nil { + t.Fatalf("Failed to marshal ping response: %v", err) + } + + resp, err := http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send ping response: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202 for ping response, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if len(body) > 0 { + var response map[string]any + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to parse response body: %v", err) + } + + if response["error"] != nil { + t.Errorf("Expected no error in response, got %v", response["error"]) + } + } + }) }