Skip to content

feat: Tool Handler Middleware #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func main() {
"1.0.0",
server.WithResourceCapabilities(true, true),
server.WithLogging(),
server.WithRecovery(),
)

// Add a calculator tool
Expand Down Expand Up @@ -522,6 +523,12 @@ initialization.
Add the `Hooks` to the server at the time of creation using the
`server.WithHooks` option.

### Tool Handler Middleware

Add middleware to tool call handlers using the `server.WithToolHandlerMiddleware` option. Middlewares can be registered on server creation and are applied on every tool call.

A recovery middleware option is available to recover from panics in a tool call and can be added to the server with the `server.WithRecovery` option.

## Contributing

<details>
Expand Down
62 changes: 47 additions & 15 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (
// ToolHandlerFunc handles tool calls with given arguments.
type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)

// ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc.
type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc

// ServerTool combines a Tool with its ToolHandlerFunc.
type ServerTool struct {
Tool mcp.Tool
Expand Down Expand Up @@ -138,20 +141,21 @@ type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCN
// MCPServer implements a Model Control Protocol server that can handle various types of requests
// including resources, prompts, and tools.
type MCPServer struct {
mu sync.RWMutex // Add mutex for protecting shared resources
name string
version string
instructions string
resources map[string]resourceEntry
resourceTemplates map[string]resourceTemplateEntry
prompts map[string]mcp.Prompt
promptHandlers map[string]PromptHandlerFunc
tools map[string]ServerTool
notificationHandlers map[string]NotificationHandlerFunc
capabilities serverCapabilities
paginationLimit *int
sessions sync.Map
hooks *Hooks
mu sync.RWMutex // Add mutex for protecting shared resources
name string
version string
instructions string
resources map[string]resourceEntry
resourceTemplates map[string]resourceTemplateEntry
prompts map[string]mcp.Prompt
promptHandlers map[string]PromptHandlerFunc
tools map[string]ServerTool
toolHandlerMiddlewares []ToolHandlerMiddleware
notificationHandlers map[string]NotificationHandlerFunc
capabilities serverCapabilities
paginationLimit *int
sessions sync.Map
hooks *Hooks
}

// serverKey is the context key for storing the server instance
Expand Down Expand Up @@ -291,6 +295,30 @@ func WithResourceCapabilities(subscribe, listChanged bool) ServerOption {
}
}

// WithToolHandlerMiddleware allows adding a middleware for the
// tool handler call chain.
func WithToolHandlerMiddleware(
toolHandlerMiddleware ToolHandlerMiddleware,
) ServerOption {
return func(s *MCPServer) {
s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware)
}
}

// WithRecovery adds a middleware that recovers from panics in tool handlers.
func WithRecovery() ServerOption {
return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic recovered in %s tool handler: %v", request.Params.Name, r)
}
}()
return next(ctx, request)
}
})
}

// WithHooks allows adding hooks that will be called before or after
// either [all] requests or before / after specific request methods, or else
// prior to returning an error to the client.
Expand Down Expand Up @@ -801,7 +829,11 @@ func (s *MCPServer) handleToolCall(
}
}

result, err := tool.Handler(ctx, request)
finalHandler := tool.Handler
for i := len(s.toolHandlerMiddlewares) - 1; i >= 0; i-- {
finalHandler = s.toolHandlerMiddlewares[i](finalHandler)
}
result, err := finalHandler(ctx, request)
if err != nil {
return nil, &requestError{
id: id,
Expand Down
33 changes: 33 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1345,3 +1345,36 @@ func TestMCPServer_WithHooks(t *testing.T) {
assert.IsType(t, afterPingData[0].msg, onSuccessData[0].msg, "OnSuccess message should be same type as AfterPing message")
assert.IsType(t, afterPingData[0].res, onSuccessData[0].res, "OnSuccess result should be same type as AfterPing result")
}

func TestMCPServer_WithRecover(t *testing.T) {
panicToolHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
panic("test panic")
}

server := NewMCPServer(
"test-server",
"1.0.0",
WithRecovery(),
)

server.AddTool(
mcp.NewTool("panic-tool"),
panicToolHandler,
)

response := server.HandleMessage(context.Background(), []byte(`{
"jsonrpc": "2.0",
"id": 4,
"method": "tools/call",
"params": {
"name": "panic-tool"
}
}`))

errorResponse, ok := response.(mcp.JSONRPCError)

require.True(t, ok)
assert.Equal(t, mcp.INTERNAL_ERROR, errorResponse.Error.Code)
assert.Equal(t, "panic recovered in panic-tool tool handler: test panic", errorResponse.Error.Message)
assert.Nil(t, errorResponse.Error.Data)
}