Skip to content

Commit 90bd877

Browse files
authored
fix: write back error message if the response marshal failed (#235)
1 parent f3fef81 commit 90bd877

File tree

2 files changed

+90
-14
lines changed

2 files changed

+90
-14
lines changed

server/sse.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,15 +457,13 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
457457
go func() {
458458
// Process message through MCPServer
459459
response := s.server.HandleMessage(ctx, rawMessage)
460-
461460
// Only send response if there is one (not for notifications)
462461
if response != nil {
463462
var message string
464463
if eventData, err := json.Marshal(response); err != nil {
465464
// If there is an error marshalling the response, send a generic error response
466465
log.Printf("failed to marshal response: %v", err)
467466
message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n")
468-
return
469467
} else {
470468
message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
471469
}

server/sse_test.go

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func TestSSEServer(t *testing.T) {
6262
defer sseResp.Body.Close()
6363

6464
// Read the endpoint event
65-
endpointEvent, err := readSeeEvent(sseResp)
65+
endpointEvent, err := readSSEEvent(sseResp)
6666
if err != nil {
6767
t.Fatalf("Failed to read SSE response: %v", err)
6868
}
@@ -195,7 +195,7 @@ func TestSSEServer(t *testing.T) {
195195
}
196196
defer resp.Body.Close()
197197

198-
endpointEvent, err = readSeeEvent(sseResp)
198+
endpointEvent, err = readSSEEvent(sseResp)
199199
if err != nil {
200200
t.Fatalf("Failed to read SSE response: %v", err)
201201
}
@@ -590,7 +590,7 @@ func TestSSEServer(t *testing.T) {
590590
defer sseResp.Body.Close()
591591

592592
// Read the endpoint event
593-
endpointEvent, err := readSeeEvent(sseResp)
593+
endpointEvent, err := readSSEEvent(sseResp)
594594
if err != nil {
595595
t.Fatalf("Failed to read SSE response: %v", err)
596596
}
@@ -632,16 +632,16 @@ func TestSSEServer(t *testing.T) {
632632
}
633633

634634
// Verify response
635-
endpointEvent, err = readSeeEvent(sseResp)
635+
endpointEvent, err = readSSEEvent(sseResp)
636636
if err != nil {
637637
t.Fatalf("Failed to read SSE response: %v", err)
638638
}
639-
respFromSee := strings.TrimSpace(
639+
respFromSSE := strings.TrimSpace(
640640
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
641641
)
642642

643643
var response map[string]interface{}
644-
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
644+
if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil {
645645
t.Fatalf("Failed to decode response: %v", err)
646646
}
647647

@@ -680,17 +680,17 @@ func TestSSEServer(t *testing.T) {
680680
}
681681
defer resp.Body.Close()
682682

683-
endpointEvent, err = readSeeEvent(sseResp)
683+
endpointEvent, err = readSSEEvent(sseResp)
684684
if err != nil {
685685
t.Fatalf("Failed to read SSE response: %v", err)
686686
}
687687

688-
respFromSee = strings.TrimSpace(
688+
respFromSSE = strings.TrimSpace(
689689
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
690690
)
691691

692692
response = make(map[string]interface{})
693-
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
693+
if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil {
694694
t.Fatalf("Failed to decode response: %v", err)
695695
}
696696

@@ -1140,7 +1140,7 @@ func TestSSEServer(t *testing.T) {
11401140
registeredSession = s
11411141
}
11421142
})
1143-
1143+
11441144
mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
11451145
testServer := NewTestServer(mcpServer)
11461146
defer testServer.Close()
@@ -1153,7 +1153,7 @@ func TestSSEServer(t *testing.T) {
11531153
defer sseResp.Body.Close()
11541154

11551155
// Read the endpoint event to ensure session is established
1156-
_, err = readSeeEvent(sseResp)
1156+
_, err = readSSEEvent(sseResp)
11571157
if err != nil {
11581158
t.Fatalf("Failed to read SSE response: %v", err)
11591159
}
@@ -1240,9 +1240,87 @@ func TestSSEServer(t *testing.T) {
12401240
t.Error("Expected final_tool to exist")
12411241
}
12421242
})
1243+
1244+
t.Run("TestServerResponseMarshalError", func(t *testing.T) {
1245+
mcpServer := NewMCPServer("test", "1.0.0",
1246+
WithResourceCapabilities(true, true),
1247+
WithHooks(&Hooks{
1248+
OnAfterInitialize: []OnAfterInitializeFunc{
1249+
func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
1250+
result.Result.Meta = map[string]interface{}{"invalid": func() {}} // marshal will fail
1251+
},
1252+
},
1253+
}),
1254+
)
1255+
testServer := NewTestServer(mcpServer)
1256+
defer testServer.Close()
1257+
1258+
// Connect to SSE endpoint
1259+
sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL))
1260+
if err != nil {
1261+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
1262+
}
1263+
defer sseResp.Body.Close()
1264+
1265+
// Read the endpoint event
1266+
endpointEvent, err := readSSEEvent(sseResp)
1267+
if err != nil {
1268+
t.Fatalf("Failed to read SSE response: %v", err)
1269+
}
1270+
if !strings.Contains(endpointEvent, "event: endpoint") {
1271+
t.Fatalf("Expected endpoint event, got: %s", endpointEvent)
1272+
}
1273+
1274+
// Extract message endpoint URL
1275+
messageURL := strings.TrimSpace(
1276+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
1277+
)
1278+
1279+
// Send initialize request
1280+
initRequest := map[string]interface{}{
1281+
"jsonrpc": "2.0",
1282+
"id": 1,
1283+
"method": "initialize",
1284+
"params": map[string]interface{}{
1285+
"protocolVersion": "2024-11-05",
1286+
"clientInfo": map[string]interface{}{
1287+
"name": "test-client",
1288+
"version": "1.0.0",
1289+
},
1290+
},
1291+
}
1292+
1293+
requestBody, err := json.Marshal(initRequest)
1294+
if err != nil {
1295+
t.Fatalf("Failed to marshal request: %v", err)
1296+
}
1297+
1298+
resp, err := http.Post(
1299+
messageURL,
1300+
"application/json",
1301+
bytes.NewBuffer(requestBody),
1302+
)
1303+
if err != nil {
1304+
t.Fatalf("Failed to send message: %v", err)
1305+
}
1306+
defer resp.Body.Close()
1307+
1308+
if resp.StatusCode != http.StatusAccepted {
1309+
t.Errorf("Expected status 202, got %d", resp.StatusCode)
1310+
}
1311+
1312+
endpointEvent, err = readSSEEvent(sseResp)
1313+
if err != nil {
1314+
t.Fatalf("Failed to read SSE response: %v", err)
1315+
}
1316+
1317+
if !strings.Contains(endpointEvent, "\"id\": null") {
1318+
t.Errorf("Expected id to be null")
1319+
}
1320+
})
12431321
}
12441322

1245-
func readSeeEvent(sseResp *http.Response) (string, error) {
1323+
func readSSEEvent(sseResp *http.Response) (string, error) {
12461324
buf := make([]byte, 1024)
12471325
n, err := sseResp.Body.Read(buf)
12481326
if err != nil {

0 commit comments

Comments
 (0)