Skip to content

Commit d453b0b

Browse files
committed
Export SendNotificationToAllClients
[First PR to project] Export sendNotificationToAllClients so it can be called from the MCP server to send unsolicited notifications to all clients (sessions). I have a use-case for my MCP server where I need to send notifications of state change to all sessions. My MCP server is fronting a device hub, and the device's on the hub would like to send notifications up to the LLM. There is probably a reason why this was not exported, so please forgive this PR if that's the case. I just could not figure out what the spec says about this. In any case, I added a test and tested with the Inspector. The Inspectory sees the unsolicited notifications. Cool! I also tried with Cursor and Claude, but neither seem to recognize notifications, no matter how hard I tried to convince them.
1 parent 71b910b commit d453b0b

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

server/server.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ func (s *MCPServer) UnregisterSession(
211211
s.sessions.Delete(sessionID)
212212
}
213213

214-
// sendNotificationToAllClients sends a notification to all the currently active clients.
215-
func (s *MCPServer) sendNotificationToAllClients(
214+
// SendNotificationToAllClients sends a notification to all the currently active clients.
215+
func (s *MCPServer) SendNotificationToAllClients(
216216
method string,
217217
params map[string]any,
218218
) {
@@ -472,7 +472,7 @@ func (s *MCPServer) AddTools(tools ...ServerTool) {
472472
s.toolsMu.Unlock()
473473

474474
// Send notification to all initialized sessions
475-
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
475+
s.SendNotificationToAllClients("notifications/tools/list_changed", nil)
476476
}
477477

478478
// SetTools replaces all existing tools with the provided list
@@ -492,7 +492,7 @@ func (s *MCPServer) DeleteTools(names ...string) {
492492
s.toolsMu.Unlock()
493493

494494
// Send notification to all initialized sessions
495-
s.sendNotificationToAllClients("notifications/tools/list_changed", nil)
495+
s.SendNotificationToAllClients("notifications/tools/list_changed", nil)
496496
}
497497

498498
// AddNotificationHandler registers a new handler for incoming notifications

server/server_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,75 @@ func TestMCPServer_SendNotificationToClient(t *testing.T) {
573573
}
574574
}
575575

576+
func TestMCPServer_SendNotificationToAllClients(t *testing.T) {
577+
578+
contextPrepare := func(ctx context.Context, srv *MCPServer) context.Context {
579+
// Create 5 active sessions
580+
for i := 0; i < 5; i++ {
581+
err := srv.RegisterSession(ctx, &fakeSession{
582+
sessionID: fmt.Sprintf("test%d", i),
583+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
584+
initialized: true,
585+
})
586+
require.NoError(t, err)
587+
}
588+
return ctx
589+
}
590+
591+
validate := func(t *testing.T, ctx context.Context, srv *MCPServer) {
592+
// Send 10 notifications to all sessions
593+
for i := 0; i < 10; i++ {
594+
srv.SendNotificationToAllClients("method", map[string]any{
595+
"count": i,
596+
})
597+
}
598+
599+
// Verify each session received all 10 notifications
600+
srv.sessions.Range(func(k, v any) bool {
601+
session := v.(ClientSession)
602+
fakeSess := session.(*fakeSession)
603+
notificationCount := 0
604+
605+
// Read all notifications from the channel
606+
for notificationCount < 10 {
607+
select {
608+
case notification := <-fakeSess.notificationChannel:
609+
// Verify notification method
610+
assert.Equal(t, "method", notification.Method)
611+
// Verify count parameter
612+
count, ok := notification.Params.AdditionalFields["count"]
613+
assert.True(t, ok, "count parameter not found")
614+
assert.Equal(t, notificationCount, count.(int), "count should match notification count")
615+
notificationCount++
616+
case <-time.After(100 * time.Millisecond):
617+
t.Errorf("timeout waiting for notification %d for session %s", notificationCount, session.SessionID())
618+
return false
619+
}
620+
}
621+
622+
// Verify no more notifications
623+
select {
624+
case notification := <-fakeSess.notificationChannel:
625+
t.Errorf("unexpected notification received: %v", notification)
626+
default:
627+
// Channel empty as expected
628+
}
629+
return true
630+
})
631+
}
632+
633+
t.Run("all sessions", func(t *testing.T) {
634+
server := NewMCPServer("test-server", "1.0.0")
635+
ctx := contextPrepare(context.Background(), server)
636+
_ = server.HandleMessage(ctx, []byte(`{
637+
"jsonrpc": "2.0",
638+
"id": 1,
639+
"method": "initialize"
640+
}`))
641+
validate(t, ctx, server)
642+
})
643+
}
644+
576645
func TestMCPServer_PromptHandling(t *testing.T) {
577646
server := NewMCPServer("test-server", "1.0.0",
578647
WithPromptCapabilities(true),

0 commit comments

Comments
 (0)