Skip to content

Commit 8ae18ed

Browse files
committed
Merge branch 'main' of github.com:mark3labs/mcp-go into docs
2 parents f68ecb5 + 563a9c7 commit 8ae18ed

File tree

19 files changed

+1797
-273
lines changed

19 files changed

+1797
-273
lines changed

README.md

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ func main() {
5858
}
5959

6060
func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
61-
name, ok := request.Params.Arguments["name"].(string)
62-
if !ok {
63-
return nil, errors.New("name must be a string")
61+
name, err := request.RequireString("name")
62+
if err != nil {
63+
return mcp.NewToolResultError(err.Error()), nil
6464
}
6565

6666
return mcp.NewToolResultText(fmt.Sprintf("Hello, %s!", name)), nil
@@ -94,10 +94,15 @@ MCP Go handles all the complex protocol details and server management, so you ca
9494
- [Prompts](#prompts)
9595
- [Examples](#examples)
9696
- [Extras](#extras)
97+
- [Transports](#transports)
9798
- [Session Management](#session-management)
99+
- [Basic Session Handling](#basic-session-handling)
100+
- [Per-Session Tools](#per-session-tools)
101+
- [Tool Filtering](#tool-filtering)
102+
- [Working with Context](#working-with-context)
98103
- [Request Hooks](#request-hooks)
99104
- [Tool Handler Middleware](#tool-handler-middleware)
100-
- [Contributing](/CONTRIBUTING.md)
105+
- [Regenerating Server Code](#regenerating-server-code)
101106

102107
## Installation
103108

@@ -527,10 +532,14 @@ Prompts can include:
527532

528533
## Examples
529534

530-
For examples, see the `examples/` directory.
535+
For examples, see the [`examples/`](examples/) directory.
531536

532537
## Extras
533538

539+
### Transports
540+
541+
MCP-Go supports stdio, SSE and streamable-HTTP transport layers.
542+
534543
### Session Management
535544

536545
MCP-Go provides a robust session management system that allows you to:
@@ -756,3 +765,14 @@ Add middleware to tool call handlers using the `server.WithToolHandlerMiddleware
756765

757766
A recovery middleware option is available to recover from panics in a tool call and can be added to the server with the `server.WithRecovery` option.
758767

768+
### Regenerating Server Code
769+
770+
Server hooks and request handlers are generated. Regenerate them by running:
771+
772+
```bash
773+
go generate ./...
774+
```
775+
776+
You need `go` installed and the `goimports` tool available. The generator runs
777+
`goimports` automatically to format and fix imports.
778+

client/sse.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ func WithHeaders(headers map[string]string) transport.ClientOption {
1212
return transport.WithHeaders(headers)
1313
}
1414

15+
func WithHeaderFunc(headerFunc transport.HTTPHeaderFunc) transport.ClientOption {
16+
return transport.WithHeaderFunc(headerFunc)
17+
}
18+
1519
func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
1620
return transport.WithHTTPClient(httpClient)
1721
}

client/sse_test.go

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"net/http"
56
"testing"
67
"time"
78

@@ -11,6 +12,13 @@ import (
1112
"github.com/mark3labs/mcp-go/server"
1213
)
1314

15+
type contextKey string
16+
17+
const (
18+
testHeaderKey contextKey = "X-Test-Header"
19+
testHeaderFuncKey contextKey = "X-Test-Header-Func"
20+
)
21+
1422
func TestSSEMCPClient(t *testing.T) {
1523
// Create MCP server with capabilities
1624
mcpServer := server.NewMCPServer(
@@ -41,9 +49,29 @@ func TestSSEMCPClient(t *testing.T) {
4149
},
4250
}, nil
4351
})
52+
mcpServer.AddTool(mcp.NewTool(
53+
"test-tool-for-http-header",
54+
mcp.WithDescription("Test tool for http header"),
55+
), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
56+
// , X-Test-Header-Func
57+
return &mcp.CallToolResult{
58+
Content: []mcp.Content{
59+
mcp.TextContent{
60+
Type: "text",
61+
Text: "context from header: " + ctx.Value(testHeaderKey).(string) + ", " + ctx.Value(testHeaderFuncKey).(string),
62+
},
63+
},
64+
}, nil
65+
})
4466

4567
// Initialize
46-
testServer := server.NewTestServer(mcpServer)
68+
testServer := server.NewTestServer(mcpServer,
69+
server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context {
70+
ctx = context.WithValue(ctx, testHeaderKey, r.Header.Get("X-Test-Header"))
71+
ctx = context.WithValue(ctx, testHeaderFuncKey, r.Header.Get("X-Test-Header-Func"))
72+
return ctx
73+
}),
74+
)
4775
defer testServer.Close()
4876

4977
t.Run("Can create client", func(t *testing.T) {
@@ -250,4 +278,56 @@ func TestSSEMCPClient(t *testing.T) {
250278
t.Errorf("Expected 1 content item, got %d", len(result.Content))
251279
}
252280
})
281+
282+
t.Run("CallTool with customized header", func(t *testing.T) {
283+
client, err := NewSSEMCPClient(testServer.URL+"/sse",
284+
WithHeaders(map[string]string{
285+
"X-Test-Header": "test-header-value",
286+
}),
287+
WithHeaderFunc(func(ctx context.Context) map[string]string {
288+
return map[string]string{
289+
"X-Test-Header-Func": "test-header-func-value",
290+
}
291+
}),
292+
)
293+
if err != nil {
294+
t.Fatalf("Failed to create client: %v", err)
295+
}
296+
defer client.Close()
297+
298+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
299+
defer cancel()
300+
301+
if err := client.Start(ctx); err != nil {
302+
t.Fatalf("Failed to start client: %v", err)
303+
}
304+
305+
// Initialize
306+
initRequest := mcp.InitializeRequest{}
307+
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
308+
initRequest.Params.ClientInfo = mcp.Implementation{
309+
Name: "test-client",
310+
Version: "1.0.0",
311+
}
312+
313+
_, err = client.Initialize(ctx, initRequest)
314+
if err != nil {
315+
t.Fatalf("Failed to initialize: %v", err)
316+
}
317+
318+
request := mcp.CallToolRequest{}
319+
request.Params.Name = "test-tool-for-http-header"
320+
321+
result, err := client.CallTool(ctx, request)
322+
if err != nil {
323+
t.Fatalf("CallTool failed: %v", err)
324+
}
325+
326+
if len(result.Content) != 1 {
327+
t.Errorf("Expected 1 content item, got %d", len(result.Content))
328+
}
329+
if result.Content[0].(mcp.TextContent).Text != "context from header: test-header-value, test-header-func-value" {
330+
t.Errorf("Got %q, want %q", result.Content[0].(mcp.TextContent).Text, "context from header: test-header-value, test-header-func-value")
331+
}
332+
})
253333
}

client/transport/interface.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ import (
77
"github.com/mark3labs/mcp-go/mcp"
88
)
99

10+
// HTTPHeaderFunc is a function that extracts header entries from the given context
11+
// and returns them as key-value pairs. This is typically used to add context values
12+
// as HTTP headers in outgoing requests.
13+
type HTTPHeaderFunc func(context.Context) map[string]string
14+
1015
// Interface for the transport layer.
1116
type Interface interface {
1217
// Start the connection. Start should only be called once.

client/transport/sse.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type SSE struct {
3131
notifyMu sync.RWMutex
3232
endpointChan chan struct{}
3333
headers map[string]string
34+
headerFunc HTTPHeaderFunc
3435

3536
started atomic.Bool
3637
closed atomic.Bool
@@ -45,6 +46,12 @@ func WithHeaders(headers map[string]string) ClientOption {
4546
}
4647
}
4748

49+
func WithHeaderFunc(headerFunc HTTPHeaderFunc) ClientOption {
50+
return func(sc *SSE) {
51+
sc.headerFunc = headerFunc
52+
}
53+
}
54+
4855
func WithHTTPClient(httpClient *http.Client) ClientOption {
4956
return func(sc *SSE) {
5057
sc.httpClient = httpClient
@@ -99,6 +106,11 @@ func (c *SSE) Start(ctx context.Context) error {
99106
for k, v := range c.headers {
100107
req.Header.Set(k, v)
101108
}
109+
if c.headerFunc != nil {
110+
for k, v := range c.headerFunc(ctx) {
111+
req.Header.Set(k, v)
112+
}
113+
}
102114

103115
resp, err := c.httpClient.Do(req)
104116
if err != nil {
@@ -269,6 +281,11 @@ func (c *SSE) SendRequest(
269281
for k, v := range c.headers {
270282
req.Header.Set(k, v)
271283
}
284+
if c.headerFunc != nil {
285+
for k, v := range c.headerFunc(ctx) {
286+
req.Header.Set(k, v)
287+
}
288+
}
272289

273290
// Create string key for map lookup
274291
idKey := request.ID.String()
@@ -368,6 +385,11 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
368385
for k, v := range c.headers {
369386
req.Header.Set(k, v)
370387
}
388+
if c.headerFunc != nil {
389+
for k, v := range c.headerFunc(ctx) {
390+
req.Header.Set(k, v)
391+
}
392+
}
371393

372394
resp, err := c.httpClient.Do(req)
373395
if err != nil {

client/transport/streamable_http.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
2626
}
2727
}
2828

29+
func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
30+
return func(sc *StreamableHTTP) {
31+
sc.headerFunc = headerFunc
32+
}
33+
}
34+
2935
// WithHTTPTimeout sets the timeout for a HTTP request and stream.
3036
func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
3137
return func(sc *StreamableHTTP) {
@@ -52,6 +58,7 @@ type StreamableHTTP struct {
5258
baseURL *url.URL
5359
httpClient *http.Client
5460
headers map[string]string
61+
headerFunc HTTPHeaderFunc
5562

5663
sessionID atomic.Value // string
5764

@@ -127,7 +134,6 @@ func (c *StreamableHTTP) Close() error {
127134
}
128135

129136
const (
130-
initializeMethod = "initialize"
131137
headerKeySessionID = "Mcp-Session-Id"
132138
)
133139

@@ -173,6 +179,11 @@ func (c *StreamableHTTP) SendRequest(
173179
for k, v := range c.headers {
174180
req.Header.Set(k, v)
175181
}
182+
if c.headerFunc != nil {
183+
for k, v := range c.headerFunc(ctx) {
184+
req.Header.Set(k, v)
185+
}
186+
}
176187

177188
// Send request
178189
resp, err := c.httpClient.Do(req)
@@ -198,7 +209,7 @@ func (c *StreamableHTTP) SendRequest(
198209
return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
199210
}
200211

201-
if request.Method == initializeMethod {
212+
if request.Method == string(mcp.MethodInitialize) {
202213
// saved the received session ID in the response
203214
// empty session ID is allowed
204215
if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
@@ -363,6 +374,11 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
363374
for k, v := range c.headers {
364375
req.Header.Set(k, v)
365376
}
377+
if c.headerFunc != nil {
378+
for k, v := range c.headerFunc(ctx) {
379+
req.Header.Set(k, v)
380+
}
381+
}
366382

367383
// Send request
368384
resp, err := c.httpClient.Do(req)

examples/custom_context/main.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,8 @@ func NewMCPServer() *MCPServer {
122122
}
123123
}
124124

125-
func (s *MCPServer) ServeSSE(addr string) *server.SSEServer {
126-
return server.NewSSEServer(s.server,
127-
server.WithBaseURL(fmt.Sprintf("http://%s", addr)),
125+
func (s *MCPServer) ServeHTTP() *server.StreamableHTTPServer {
126+
return server.NewStreamableHTTPServer(s.server,
128127
server.WithHTTPContextFunc(authFromRequest),
129128
)
130129
}
@@ -135,12 +134,12 @@ func (s *MCPServer) ServeStdio() error {
135134

136135
func main() {
137136
var transport string
138-
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
137+
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or http)")
139138
flag.StringVar(
140139
&transport,
141140
"transport",
142141
"stdio",
143-
"Transport type (stdio or sse)",
142+
"Transport type (stdio or http)",
144143
)
145144
flag.Parse()
146145

@@ -151,15 +150,15 @@ func main() {
151150
if err := s.ServeStdio(); err != nil {
152151
log.Fatalf("Server error: %v", err)
153152
}
154-
case "sse":
155-
sseServer := s.ServeSSE("localhost:8080")
156-
log.Printf("SSE server listening on :8080")
157-
if err := sseServer.Start(":8080"); err != nil {
153+
case "http":
154+
httpServer := s.ServeHTTP()
155+
log.Printf("HTTP server listening on :8080")
156+
if err := httpServer.Start(":8080"); err != nil {
158157
log.Fatalf("Server error: %v", err)
159158
}
160159
default:
161160
log.Fatalf(
162-
"Invalid transport type: %s. Must be 'stdio' or 'sse'",
161+
"Invalid transport type: %s. Must be 'stdio' or 'http'",
163162
transport,
164163
)
165164
}

examples/everything/main.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -475,17 +475,17 @@ func handleNotification(
475475

476476
func main() {
477477
var transport string
478-
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
479-
flag.StringVar(&transport, "transport", "stdio", "Transport type (stdio or sse)")
478+
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or http)")
479+
flag.StringVar(&transport, "transport", "stdio", "Transport type (stdio or http)")
480480
flag.Parse()
481481

482482
mcpServer := NewMCPServer()
483483

484-
// Only check for "sse" since stdio is the default
485-
if transport == "sse" {
486-
sseServer := server.NewSSEServer(mcpServer, server.WithBaseURL("http://localhost:8080"))
487-
log.Printf("SSE server listening on :8080")
488-
if err := sseServer.Start(":8080"); err != nil {
484+
// Only check for "http" since stdio is the default
485+
if transport == "http" {
486+
httpServer := server.NewStreamableHTTPServer(mcpServer)
487+
log.Printf("HTTP server listening on :8080/mcp")
488+
if err := httpServer.Start(":8080"); err != nil {
489489
log.Fatalf("Server error: %v", err)
490490
}
491491
} else {

0 commit comments

Comments
 (0)