Skip to content

feat(server): convert ping messages to be spec compliant #169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions server/internal/gen/request_handler.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func (s *MCPServer) HandleMessage(
JSONRPC string `json:"jsonrpc"`
Method mcp.MCPMethod `json:"method"`
ID any `json:"id,omitempty"`
Result any `json:"result,omitempty"`
}

if err := json.Unmarshal(message, &baseMessage); err != nil {
Expand Down Expand Up @@ -56,6 +57,12 @@ func (s *MCPServer) HandleMessage(
return nil // Return nil for notifications
}

if baseMessage.Result != nil {
// this is a response to a request sent by the server (e.g. from a ping
// sent due to WithKeepAlive option)
return nil
}

switch baseMessage.Method {
{{- range .}}
case mcp.{{.MethodName}}:
Expand Down
7 changes: 7 additions & 0 deletions server/request_handler.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type sseSession struct {
done chan struct{}
eventQueue chan string // Channel for queuing events
sessionID string
requestID atomic.Int64
notificationChannel chan mcp.JSONRPCNotification
initialized atomic.Bool
}
Expand Down Expand Up @@ -282,8 +283,16 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
for {
select {
case <-ticker.C:
//: ping - 2025-03-27 07:44:38.682659+00:00
session.eventQueue <- fmt.Sprintf(":ping - %s\n\n", time.Now().Format(time.RFC3339))
message := mcp.JSONRPCRequest{
JSONRPC: "2.0",
ID: session.requestID.Add(1),
Request: mcp.Request{
Method: "ping",
},
}
messageBytes, _ := json.Marshal(message)
pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes)
session.eventQueue <- pingMsg
case <-session.done:
return
case <-r.Context().Done():
Expand Down
115 changes: 115 additions & 0 deletions server/sse_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package server

import (
"bufio"
"bytes"
"context"
"encoding/json"
"io"
"fmt"
"math/rand"
"net/http"
Expand Down Expand Up @@ -739,4 +741,117 @@ func TestSSEServer(t *testing.T) {
}
}
})

t.Run("Client receives and can respond to ping messages", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
testServer := NewTestServer(mcpServer,
WithKeepAlive(true),
WithKeepAliveInterval(50*time.Millisecond),
)
defer testServer.Close()

sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL))
if err != nil {
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
}
defer sseResp.Body.Close()

reader := bufio.NewReader(sseResp.Body)

var messageURL string
var pingID float64

for {
line, err := reader.ReadString('\n')
if err != nil {
t.Fatalf("Failed to read SSE event: %v", err)
}

if strings.HasPrefix(line, "event: endpoint") {
dataLine, err := reader.ReadString('\n')
if err != nil {
t.Fatalf("Failed to read endpoint data: %v", err)
}
messageURL = strings.TrimSpace(strings.TrimPrefix(dataLine, "data: "))

_, err = reader.ReadString('\n')
if err != nil {
t.Fatalf("Failed to read blank line: %v", err)
}
}

if strings.HasPrefix(line, "event: message") {
dataLine, err := reader.ReadString('\n')
if err != nil {
t.Fatalf("Failed to read message data: %v", err)
}

pingData := strings.TrimSpace(strings.TrimPrefix(dataLine, "data:"))
var pingMsg mcp.JSONRPCRequest
if err := json.Unmarshal([]byte(pingData), &pingMsg); err != nil {
t.Fatalf("Failed to parse ping message: %v", err)
}

if pingMsg.Method == "ping" {
pingID = pingMsg.ID.(float64)
t.Logf("Received ping with ID: %f", pingID)
break // We got the ping, exit the loop
}

_, err = reader.ReadString('\n')
if err != nil {
t.Fatalf("Failed to read blank line: %v", err)
}
}

if messageURL != "" && pingID != 0 {
break
}
}

if messageURL == "" {
t.Fatal("Did not receive message endpoint URL")
}

pingResponse := map[string]any{
"jsonrpc": "2.0",
"id": pingID,
"result": map[string]any{},
}

requestBody, err := json.Marshal(pingResponse)
if err != nil {
t.Fatalf("Failed to marshal ping response: %v", err)
}

resp, err := http.Post(
messageURL,
"application/json",
bytes.NewBuffer(requestBody),
)
if err != nil {
t.Fatalf("Failed to send ping response: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusAccepted {
t.Errorf("Expected status 202 for ping response, got %d", resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}

if len(body) > 0 {
var response map[string]any
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to parse response body: %v", err)
}

if response["error"] != nil {
t.Errorf("Expected no error in response, got %v", response["error"])
}
}
})
}