diff --git a/docs/modules/fxmcpserver.md b/docs/modules/fxmcpserver.md index 31c93a61..db771acd 100644 --- a/docs/modules/fxmcpserver.md +++ b/docs/modules/fxmcpserver.md @@ -23,12 +23,12 @@ It comes with: - automatic requests logging and tracing (method, target, duration, ...) - automatic requests metrics (count and duration) - possibility to register MCP resources, resource templates, prompts and tools -- possibility to register MCP SSE server context hooks -- possibility to expose the MCP server via Stdio (local) and/or HTTP SSE (remote) +- possibility to register MCP Streamable HTTP and SSE server context hooks +- possibility to expose the MCP server via Streamable HTTP (remote), HTTP SSE (remote) and Stdio (local) ## Installation -First install the module: +First, install the module: ```shell go get github.com/ankorstore/yokai/fxmcpserver @@ -60,10 +60,17 @@ modules: name: "MCP Server" # server name ("MCP server" by default) version: 1.0.0 # server version (1.0.0 by default) capabilities: - resources: true # to expose MCP resources & resource templates (disabled by default) + resources: true # to expose MCP resources and resource templates (disabled by default) prompts: true # to expose MCP prompts (disabled by default) tools: true # to expose MCP tools (disabled by default) transport: + stream: + expose: true # to remotely expose the MCP server via Streamable HTTP (disabled by default) + address: ":8083" # exposition address (":8083" by default) + stateless: false # stateless server mode (disabled by default) + base_path: "/mcp" # base path ("/mcp" by default) + keep_alive: true # to keep the connections alive + keep_alive_interval: 10 # keep alive interval in seconds (10 by default) sse: expose: true # to remotely expose the MCP server via SSE (disabled by default) address: ":8082" # exposition address (":8082" by default) @@ -74,7 +81,7 @@ modules: keep_alive: true # to keep connection alive keep_alive_interval: 10 # keep alive interval in seconds (10 by default) stdio: - expose: false # to locally expose the MCP server via Stdio (disabled by default) + expose: true # to locally expose the MCP server via Stdio (disabled by default) log: request: true # to log MCP requests contents (disabled by default) response: true # to log MCP responses contents (disabled by default) @@ -270,7 +277,7 @@ import ( func Register() fx.Option { return fx.Options( - // registers UserProfileResource as MCP resource + // registers UserProfileResource as MCP resource template fxmcpserver.AsMCPServerResourceTemplate(resource.NewUserProfileResource), // ... ) @@ -519,6 +526,70 @@ modules: ## Hooks +This module provides hooking mechanisms for the `StreamableHTTP` and `SSE` servers requests handling. + +### StreamableHTTP server hooks + +This module offers the possibility to provide context hooks with [MCPStreamableHTTPServerContextHook](https://github.com/ankorstore/yokai/blob/main/fxmcpserver/server/stream/context.go) implementations, that will be applied on each MCP StreamableHTTP request. + +For example, an MCP StreamableHTTP server context hook that adds a config value to the context: + +```go title="internal/mcp/resource/readme.go" +package hook + +import ( + "context" + "net/http" + + "github.com/ankorstore/yokai/config" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +type ExampleHook struct { + config *config.Config +} + +func NewExampleHook(config *config.Config) *ExampleHook { + return &ExampleHook{ + config: config, + } +} + +func (h *ExampleHook) Handle() server.HTTPContextFunc { + return func(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, "foo", h.config.GetString("foo")) + } +} +``` + +You can register your MCP StreamableHTTP server context hook: + +- with `AsMCPStreamableHTTPServerContextHook()` to register a single MCP StreamableHTTP server context hook +- with `AsMCPStreamableHTTPServerContextHooks()` to register several MCP StreamableHTTP server context hooks at once + +```go title="internal/register.go" +package internal + +import ( + "github.com/ankorstore/yokai/fxmcpserver" + "github.com/foo/bar/internal/mcp/hook" + "go.uber.org/fx" +) + +func Register() fx.Option { + return fx.Options( + // registers ExampleHook as MCP StreamableHTTP server context hook + fxmcpserver.AsMCPStreamableHTTPServerContextHook(hook.NewExampleHook), + // ... + ) +} +``` + +The dependencies of your MCP StreamableHTTP server context hooks will be autowired. + +### SSE server hooks + This module offers the possibility to provide context hooks with [MCPSSEServerContextHook](https://github.com/ankorstore/yokai/blob/main/fxmcpserver/server/sse/context.go) implementations, that will be applied on each MCP SSE request. For example, an MCP SSE server context hook that adds a config value to the context: @@ -568,7 +639,7 @@ import ( func Register() fx.Option { return fx.Options( - // registers ReadmeResource as MCP resource + // registers ExampleHook as MCP SSE server context hook fxmcpserver.AsMCPSSEServerContextHook(hook.NewExampleHook), // ... ) @@ -684,7 +755,98 @@ mcp_server_requests_total{method="tools/call",status="success",target="calculato ## Testing -This module provides a [MCPSSETestServer](https://github.com/ankorstore/yokai/blob/main/fxmcpserver/fxmcpservertest/server.go) to enable you to easily test your exposed MCP registrations. +This module provide `StreamableHTTP` and `SSE` test servers, to functionally test your applications. + +### StreamableHTTP test server + +This module provides a [MCPStreamableHTTPTestServer](https://github.com/ankorstore/yokai/blob/main/fxmcpserver/fxmcpservertest/stream.go) to enable you to easily test your exposed MCP registrations. + +From this server, you can create a ready to use client via `StartClient()` to perform MCP requests, to functionally test your MCP server. + +You can easily assert on: + +- MCP responses +- logs +- traces +- metrics + +For example, to test an `MCP ping`: + +```go title="internal/mcp/ping_test.go" +package handler_test + +import ( + "testing" + + "github.com/ankorstore/yokai/log/logtest" + "github.com/ankorstore/yokai/trace/tracetest" + "github.com/foo/bar/internal" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "go.uber.org/fx" +) + +func TestMCPPing(t *testing.T) { + var testServer *fxmcpservertest.MCPStreamableHTTPTestServer + var logBuffer logtest.TestLogBuffer + var traceExporter tracetest.TestTraceExporter + var metricsRegistry *prometheus.Registry + + internal.RunTest(t, fx.Populate(&testServer, &logBuffer, &traceExporter, &metricsRegistry)) + + // close the test server once done + defer testServer.Close() + + // start test client + testClient, err := testServer.StartClient(context.Background()) + assert.NoError(t, err) + + // close the test client once done + defer testClient.Close() + + // send MCP ping request + err = testClient.Ping(context.Background()) + assert.NoError(t, err) + + // assertion on the logs buffer + logtest.AssertHasLogRecord(t, logBuffer, map[string]interface{}{ + "level": "info", + "mcpMethod": "ping", + "mcpTransport": "streamable-http", + "message": "MCP request success", + }) + + // assertion on the traces exporter + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP ping", + attribute.String("mcp.method", "ping"), + attribute.String("mcp.transport", "streamable-http"), + ) + + // assertion on the metrics registry + expectedMetric := ` + # HELP mcp_server_requests_total Number of processed MCP requests + # TYPE mcp_server_requests_total counter + mcp_server_requests_total{method="ping",status="success",target=""} 1 + ` + + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "mcp_server_requests_total", + ) + assert.NoError(t, err) +} +``` + +### SSE test server + +This module provides a [MCPSSETestServer](https://github.com/ankorstore/yokai/blob/main/fxmcpserver/fxmcpservertest/sse.go) to enable you to easily test your exposed MCP registrations. From this server, you can create a ready to use client via `StartClient()` to perform MCP requests, to functionally test your MCP server. diff --git a/fxmcpserver/README.md b/fxmcpserver/README.md index 182a2e7d..34a12c1a 100644 --- a/fxmcpserver/README.md +++ b/fxmcpserver/README.md @@ -12,16 +12,20 @@ * [Installation](#installation) * [Features](#features) * [Documentation](#documentation) - * [Dependencies](#dependencies) - * [Loading](#loading) - * [Configuration](#configuration) - * [Registration](#registration) - * [Resources](#resources) - * [Resource templates](#resource-templates) - * [Prompts](#prompts) - * [Tools](#tools) - * [Hooks](#hooks) - * [Testing](#testing) + * [Dependencies](#dependencies) + * [Loading](#loading) + * [Configuration](#configuration) + * [Registration](#registration) + * [Resources](#resources) + * [Resource templates](#resource-templates) + * [Prompts](#prompts) + * [Tools](#tools) + * [Hooks](#hooks) + * [StreamableHTTP server hooks](#streamablehttp-server-hooks) + * [SSE server hooks](#sse-server-hooks) + * [Testing](#testing) + * [StreamableHTTP test server](#streamablehttp-test-server) + * [SSE test server](#sse-test-server) ## Installation @@ -38,8 +42,8 @@ This module provides an [MCP server](https://modelcontextprotocol.io/introductio - automatic requests logging and tracing (method, target, duration, ...) - automatic requests metrics (count and duration) - possibility to register MCP resources, resource templates, prompts and tools -- possibility to register MCP SSE server context hooks -- possibility to expose the MCP server via Stdio (local) and/or HTTP SSE (remote) +- possibility to register MCP Streamable HTTP and SSE server context hooks +- possibility to expose the MCP server via Streamable HTTP (remote), HTTP SSE (remote) and Stdio (local) ## Documentation @@ -109,6 +113,13 @@ modules: prompts: true # to expose MCP prompts (disabled by default) tools: true # to expose MCP tools (disabled by default) transport: + stream: + expose: true # to remotely expose the MCP server via streamable HTTP (disabled by default) + address: ":8083" # exposition address (":8083" by default) + stateless: false # stateless server mode (disabled by default) + base_path: "/mcp" # base path ("/mcp" by default) + keep_alive: true # to keep the connections alive + keep_alive_interval: 10 # keep alive interval in seconds (10 by default) sse: expose: true # to remotely expose the MCP server via SSE (disabled by default) address: ":8082" # exposition address (":8082" by default) @@ -116,7 +127,7 @@ modules: base_path: "" # base path ("" by default) sse_endpoint: "/sse" # SSE endpoint ("/sse" by default) message_endpoint: "/message" # message endpoint ("/message" by default) - keep_alive: true # to keep connection alive + keep_alive: true # to keep the connections alive keep_alive_interval: 10 # keep alive interval in seconds (10 by default) stdio: expose: false # to locally expose the MCP server via Stdio (disabled by default) @@ -553,6 +564,68 @@ modules: ``` ### Hooks +This module provides hooking mechanisms for the `StreamableHTTP` and `SSE` servers requests handling. + +#### StreamableHTTP server hooks + +This module offers the possibility to provide context hooks with [MCPStreamableHTTPServerContextHook](server/stream/context.go) implementations, that will be applied on each MCP StreamableHTTP request. + +You can use the `AsMCPStreamableHTTPServerContextHook()` function to register an MCP StreamableHTTP server context hook, or `AsMCPStreamableHTTPServerContextHooks()` to register several MCP StreamableHTTP server context hooks at once. + +The dependencies of your MCP StreamableHTTP server context hooks will be autowired. + +```go +package main + +import ( + "context" + "net/http" + + "github.com/ankorstore/yokai/config" + "github.com/ankorstore/yokai/fxconfig" + "github.com/ankorstore/yokai/fxgenerate" + "github.com/ankorstore/yokai/fxlog" + "github.com/ankorstore/yokai/fxmcpserver" + "github.com/ankorstore/yokai/fxmetrics" + "github.com/ankorstore/yokai/fxtrace" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "go.uber.org/fx" +) + +type ExampleHook struct { + config *config.Config +} + +func NewExampleHook(config *config.Config) *ExampleHook { + return &ExampleHook{ + config: config, + } +} + +func (h *ExampleHook) Handle() server.HTTPContextFunc { + return func(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, "foo", h.config.GetString("foo")) + } +} + +func main() { + fx.New( + fxconfig.FxConfigModule, + fxlog.FxLogModule, + fxtrace.FxTraceModule, + fxmetrics.FxMetricsModule, + fxgenerate.FxGenerateModule, + fxmcpserver.FxMCPServerModule, + fx.Options( + fxmcpserver.AsMCPStreamableHTTPServerContextHook(NewExampleHook), // registers the NewExampleHook as MCP StreamableHTTP server context hook + ), + ).Run() +} +``` + +#### SSE server hooks + This module offers the possibility to provide context hooks with [MCPSSEServerContextHook](server/sse/context.go) implementations, that will be applied on each MCP SSE request. You can use the `AsMCPSSEServerContextHook()` function to register an MCP SSE server context hook, or `AsMCPSSEServerContextHooks()` to register several MCP SSE server context hooks at once. @@ -611,7 +684,112 @@ func main() { ### Testing -This module provides a [MCPSSETestServer](fxmcpservertest/server.go) to enable you to easily test your exposed MCP capabilities. +This module provide `StreamableHTTP` and `SSE` test servers, to functionally test your applications. + +#### StreamableHTTP test server + +This module provides a [MCPStreamableHTTPTestServer](fxmcpservertest/stream.go) to enable you to easily test your exposed MCP capabilities. + +From this server, you can create a ready to use client via `StartClient()` to perform MCP requests, to functionally test your MCP server. + +You can then test it, considering `logs`, `traces` and `metrics` are enabled: + +```go +package internal_test + +import ( + "context" + "strings" + "testing" + + "github.com/ankorstore/yokai/fxconfig" + "github.com/ankorstore/yokai/fxgenerate" + "github.com/ankorstore/yokai/fxhttpserver" + "github.com/ankorstore/yokai/fxlog" + "github.com/ankorstore/yokai/fxmcpserver" + "github.com/ankorstore/yokai/fxmcpserver/fxmcpservertest" + "github.com/ankorstore/yokai/fxmetrics" + "github.com/ankorstore/yokai/fxtrace" + "github.com/ankorstore/yokai/log/logtest" + "github.com/ankorstore/yokai/trace/tracetest" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/fx" + "go.uber.org/fx/fxtest" +) + +func TestExample(t *testing.T) { + var testServer *fxmcpservertest.MCPStreamableHTTPTestServer + var logBuffer logtest.TestLogBuffer + var traceExporter tracetest.TestTraceExporter + var metricsRegistry *prometheus.Registry + + fxtest.New( + t, + fx.NopLogger, + fxconfig.FxConfigModule, + fxlog.FxLogModule, + fxtrace.FxTraceModule, + fxgenerate.FxGenerateModule, + fxmetrics.FxMetricsModule, + fxmcpserver.FxMCPServerModule, + fx.Populate(&testServer, &logBuffer, &traceExporter, &metricsRegistry), + ).RequireStart().RequireStop() + + // close the test server once done + defer testServer.Close() + + // start test client + testClient, err := testServer.StartClient(context.Background()) + assert.NoError(t, err) + + // close the test client once done + defer testClient.Close() + + // send MCP ping request + err = testClient.Ping(context.Background()) + assert.NoError(t, err) + + // assertion on the logs buffer + logtest.AssertHasLogRecord(t, logBuffer, map[string]interface{}{ + "level": "info", + "mcpMethod": "ping", + "mcpTransport": "streamable-http", + "message": "MCP request success", + }) + + // assertion on the traces exporter + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP ping", + attribute.String("mcp.method", "ping"), + attribute.String("mcp.transport", "streamable-http"), + ) + + // assertion on the metrics registry + expectedMetric := ` + # HELP mcp_server_requests_total Number of processed HTTP requests + # TYPE mcp_server_requests_total counter + mcp_server_requests_total{method="ping",status="success",target=""} 1 + ` + + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "mcp_server_requests_total", + ) + assert.NoError(t, err) +} +``` + +You can find more tests examples in this module own [tests](module_test.go). + +#### SSE test server + +This module provides a [MCPSSETestServer](fxmcpservertest/sse.go) to enable you to easily test your exposed MCP capabilities. From this server, you can create a ready to use client via `StartClient()` to perform MCP requests, to functionally test your MCP server. diff --git a/fxmcpserver/fxmcpservertest/server.go b/fxmcpserver/fxmcpservertest/sse.go similarity index 100% rename from fxmcpserver/fxmcpservertest/server.go rename to fxmcpserver/fxmcpservertest/sse.go diff --git a/fxmcpserver/fxmcpservertest/server_test.go b/fxmcpserver/fxmcpservertest/sse_test.go similarity index 100% rename from fxmcpserver/fxmcpservertest/server_test.go rename to fxmcpserver/fxmcpservertest/sse_test.go diff --git a/fxmcpserver/fxmcpservertest/stream.go b/fxmcpserver/fxmcpservertest/stream.go new file mode 100644 index 00000000..e6789292 --- /dev/null +++ b/fxmcpserver/fxmcpservertest/stream.go @@ -0,0 +1,73 @@ +package fxmcpservertest + +import ( + "context" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" + "net/http/httptest" + + "github.com/ankorstore/yokai/config" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +type MCPStreamableHTTPTestServer struct { + config *config.Config + testServer *httptest.Server +} + +func NewMCPStreamableHTTPTestServer(cfg *config.Config, srv *server.MCPServer, hdl stream.MCPStreamableHTTPServerContextHandler) *MCPStreamableHTTPTestServer { + basePath := cfg.GetString("modules.mcp.server.transport.stream.base_path") + if basePath == "" { + basePath = stream.DefaultBasePath + } + + testSrv := server.NewTestStreamableHTTPServer( + srv, + server.WithHTTPContextFunc(hdl.Handle()), + server.WithEndpointPath(basePath), + ) + + return &MCPStreamableHTTPTestServer{ + config: cfg, + testServer: testSrv, + } +} + +func (s *MCPStreamableHTTPTestServer) Close() { + s.testServer.Close() +} + +func (s *MCPStreamableHTTPTestServer) StartClient(ctx context.Context, options ...transport.StreamableHTTPCOption) (*client.Client, error) { + basePath := s.config.GetString("modules.mcp.server.transport.stream.base_path") + if basePath == "" { + basePath = stream.DefaultBasePath + } + + baseURL := s.testServer.URL + basePath + + cli, err := client.NewStreamableHttpClient(baseURL, options...) + if err != nil { + return nil, err + } + + err = cli.Start(ctx) + if err != nil { + return nil, err + } + + initReq := mcp.InitializeRequest{} + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initReq.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = cli.Initialize(ctx, initReq) + if err != nil { + return nil, err + } + + return cli, nil +} diff --git a/fxmcpserver/fxmcpservertest/stream_test.go b/fxmcpserver/fxmcpservertest/stream_test.go new file mode 100644 index 00000000..0d1b7bf8 --- /dev/null +++ b/fxmcpserver/fxmcpservertest/stream_test.go @@ -0,0 +1,52 @@ +package fxmcpservertest_test + +import ( + "context" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" + "testing" + + "github.com/ankorstore/yokai/config" + "github.com/ankorstore/yokai/fxmcpserver/fxmcpservertest" + "github.com/ankorstore/yokai/generate/uuid" + "github.com/ankorstore/yokai/log" + "github.com/ankorstore/yokai/log/logtest" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/trace" +) + +func TestMCPStreamableHTTPTestServer(t *testing.T) { + t.Parallel() + + cfg, err := config.NewDefaultConfigFactory().Create( + config.WithFilePaths("../testdata/config"), + ) + assert.NoError(t, err) + + gm := uuid.NewDefaultUuidGenerator() + + tp := trace.NewTracerProvider() + + tmp := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + + lb := logtest.NewDefaultTestLogBuffer() + lg, err := log.NewDefaultLoggerFactory().Create(log.WithOutputWriter(lb)) + assert.NoError(t, err) + + hdl := stream.NewDefaultMCPStreamableHTTPServerContextHandler(gm, tp, tmp, lg) + + mcpSrv := server.NewMCPServer("test-server", "1.0.0") + + srv := fxmcpservertest.NewMCPStreamableHTTPTestServer(cfg, mcpSrv, hdl) + defer srv.Close() + + cli, err := srv.StartClient(context.Background()) + assert.NoError(t, err) + + err = cli.Ping(context.Background()) + assert.NoError(t, err) + + err = cli.Close() + assert.NoError(t, err) +} diff --git a/fxmcpserver/go.mod b/fxmcpserver/go.mod index 86c657c9..4ad1f06d 100644 --- a/fxmcpserver/go.mod +++ b/fxmcpserver/go.mod @@ -14,14 +14,14 @@ require ( github.com/ankorstore/yokai/healthcheck v1.1.0 github.com/ankorstore/yokai/log v1.2.0 github.com/ankorstore/yokai/trace v1.4.0 - github.com/mark3labs/mcp-go v0.25.0 + github.com/mark3labs/mcp-go v0.31.0 github.com/prometheus/client_golang v1.22.0 github.com/stretchr/testify v1.10.0 go.opencensus.io v0.24.0 go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel/sdk v1.24.0 go.opentelemetry.io/otel/trace v1.24.0 - go.uber.org/fx v1.23.0 + go.uber.org/fx v1.24.0 ) require ( @@ -65,7 +65,7 @@ require ( go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.24.0 // indirect go.opentelemetry.io/otel/metric v1.24.0 // indirect go.opentelemetry.io/proto/otlp v1.1.0 // indirect - go.uber.org/dig v1.18.0 // indirect + go.uber.org/dig v1.19.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect diff --git a/fxmcpserver/go.sum b/fxmcpserver/go.sum index 77e9fe9e..69eb2543 100644 --- a/fxmcpserver/go.sum +++ b/fxmcpserver/go.sum @@ -89,8 +89,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mark3labs/mcp-go v0.25.0 h1:UUpcMT3L5hIhuDy7aifj4Bphw4Pfx1Rf8mzMXDe8RQw= -github.com/mark3labs/mcp-go v0.25.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.31.0 h1:4UxSV8aM770OPmTvaVe/b1rA2oZAjBMhGBfUgOGut+4= +github.com/mark3labs/mcp-go v0.31.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -169,10 +169,10 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= go.opentelemetry.io/proto/otlp v1.1.0 h1:2Di21piLrCqJ3U3eXGCTPHE9R8Nh+0uglSnOyxikMeI= go.opentelemetry.io/proto/otlp v1.1.0/go.mod h1:GpBHCBWiqvVLDqmHZsoMM3C5ySeKTC7ej/RNTae6MdY= -go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw= -go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= -go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg= -go.uber.org/fx v1.23.0/go.mod h1:o/D9n+2mLP6v1EG+qsdT1O8wKopYAsqZasju97SDFCU= +go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4= +go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= +go.uber.org/fx v1.24.0 h1:wE8mruvpg2kiiL1Vqd0CC+tr0/24XIB10Iwp2lLWzkg= +go.uber.org/fx v1.24.0/go.mod h1:AmDeGyS+ZARGKM4tlH4FY2Jr63VjbEDJHtqXTGP5hbo= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= diff --git a/fxmcpserver/info.go b/fxmcpserver/info.go index b8104d50..07ac4298 100644 --- a/fxmcpserver/info.go +++ b/fxmcpserver/info.go @@ -5,46 +5,52 @@ import ( "github.com/ankorstore/yokai/fxmcpserver/server" "github.com/ankorstore/yokai/fxmcpserver/server/sse" "github.com/ankorstore/yokai/fxmcpserver/server/stdio" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" ) // MCPServerModuleInfo is the MCP server module info. type MCPServerModuleInfo struct { - config *config.Config - registry *server.MCPServerRegistry - sseServer *sse.MCPSSEServer - stdioServer *stdio.MCPStdioServer + config *config.Config + registry *server.MCPServerRegistry + steamableHTTPServer *stream.MCPStreamableHTTPServer + sseServer *sse.MCPSSEServer + stdioServer *stdio.MCPStdioServer } // NewMCPServerModuleInfo returns a new MCPServerModuleInfo instance. func NewMCPServerModuleInfo( config *config.Config, registry *server.MCPServerRegistry, + steamableHTTPServer *stream.MCPStreamableHTTPServer, sseServer *sse.MCPSSEServer, stdioServer *stdio.MCPStdioServer, ) *MCPServerModuleInfo { return &MCPServerModuleInfo{ - config: config, - registry: registry, - sseServer: sseServer, - stdioServer: stdioServer, + config: config, + registry: registry, + steamableHTTPServer: steamableHTTPServer, + sseServer: sseServer, + stdioServer: stdioServer, } } -// Name return the name of the module info. +// Name returns the name of the module info. func (i *MCPServerModuleInfo) Name() string { return ModuleName } // Data return the data of the module info. func (i *MCPServerModuleInfo) Data() map[string]any { + streamableHTTPServerInfo := i.steamableHTTPServer.Info() sseServerInfo := i.sseServer.Info() stdioServerInfo := i.stdioServer.Info() mcpRegistryInfo := i.registry.Info() return map[string]any{ "transports": map[string]any{ - "sse": sseServerInfo, - "stdio": stdioServerInfo, + "stream": streamableHTTPServerInfo, + "sse": sseServerInfo, + "stdio": stdioServerInfo, }, "capabilities": map[string]any{ "tools": mcpRegistryInfo.Capabilities.Tools, diff --git a/fxmcpserver/info_test.go b/fxmcpserver/info_test.go index 634bcc68..8732bc7c 100644 --- a/fxmcpserver/info_test.go +++ b/fxmcpserver/info_test.go @@ -1,6 +1,7 @@ package fxmcpserver_test import ( + "github.com/ankorstore/yokai/fxmcpserver/server/stream" "testing" "github.com/ankorstore/yokai/config" @@ -42,10 +43,11 @@ func TestMCPServerModuleInfo(t *testing.T) { mcpSrv := server.NewMCPServer("test-server", "1.0.0") + streamSrv := stream.NewDefaultMCPStreamableHTTPServerFactory(cfg).Create(mcpSrv) sseSrv := sse.NewDefaultMCPSSEServerFactory(cfg).Create(mcpSrv) stdioSrv := stdio.NewDefaultMCPStdioServerFactory().Create(mcpSrv) - info := fxmcpserver.NewMCPServerModuleInfo(cfg, reg, sseSrv, stdioSrv) + info := fxmcpserver.NewMCPServerModuleInfo(cfg, reg, streamSrv, sseSrv, stdioSrv) assert.Equal(t, info.Name(), fxmcpserver.ModuleName) @@ -70,6 +72,18 @@ func TestMCPServerModuleInfo(t *testing.T) { "running": false, }, }, + "stream": map[string]any{ + "config": map[string]any{ + "address": ":0", + "stateless": true, + "base_path": stream.DefaultBasePath, + "keep_alive": true, + "keep_alive_interval": stream.DefaultKeepAliveInterval.Seconds(), + }, + "status": map[string]any{ + "running": false, + }, + }, }, "capabilities": map[string]any{ "tools": true, diff --git a/fxmcpserver/module.go b/fxmcpserver/module.go index ecb5818d..2d70f946 100644 --- a/fxmcpserver/module.go +++ b/fxmcpserver/module.go @@ -2,6 +2,7 @@ package fxmcpserver import ( "context" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" "github.com/ankorstore/yokai/config" "github.com/ankorstore/yokai/fxmcpserver/fxmcpservertest" @@ -26,6 +27,8 @@ var FxMCPServerModule = fx.Module( // module fixed dependencies ProvideMCPServerRegistry, ProvideMCPServer, + ProvideMCPStreamableHTTPServer, + ProvideMCPStreamableHTTPTestServer, ProvideMCPSSEServer, ProvideMCPSSETestServer, ProvideMCPStdioServer, @@ -38,6 +41,14 @@ var FxMCPServerModule = fx.Module( ProvideDefaultMCPServerFactory, fx.As(new(fs.MCPServerFactory)), ), + fx.Annotate( + ProvideDefaultMCPStreamableHTTPServerContextHandler, + fx.As(new(stream.MCPStreamableHTTPServerContextHandler)), + ), + fx.Annotate( + ProvideDefaultMCPStreamableHTTPServerFactory, + fx.As(new(stream.MCPStreamableHTTPServerFactory)), + ), fx.Annotate( ProvideDefaultMCPSSEServerContextHandler, fx.As(new(sse.MCPSSEServerContextHandler)), @@ -125,6 +136,98 @@ func ProvideMCPServer(p ProvideMCPServerParam) *server.MCPServer { return srv } +// ProvideDefaultMCPStreamableHTTPContextHandlerParam allows injection of the required dependencies in ProvideDefaultMCPStreamableHTTPServerContextHandler. +type ProvideDefaultMCPStreamableHTTPContextHandlerParam struct { + fx.In + Generator uuid.UuidGenerator + TracerProvider trace.TracerProvider + Logger *log.Logger + MCPStreamableHTTPServerContextHooks []stream.MCPStreamableHTTPServerContextHook `group:"mcp-streamable-http-server-context-hooks"` +} + +// ProvideDefaultMCPStreamableHTTPServerContextHandler provides the default sse.MCPStreamableHTTPServerContextHandler instance. +func ProvideDefaultMCPStreamableHTTPServerContextHandler(p ProvideDefaultMCPStreamableHTTPContextHandlerParam) *stream.DefaultMCPStreamableHTTPServerContextHandler { + textMapPropagator := propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + ) + + return stream.NewDefaultMCPStreamableHTTPServerContextHandler( + p.Generator, + p.TracerProvider, + textMapPropagator, + p.Logger, + p.MCPStreamableHTTPServerContextHooks..., + ) +} + +// ProvideDefaultMCPStreamableHTTPServerFactoryParams allows injection of the required dependencies in ProvideDefaultMCPSSEServerFactory. +type ProvideDefaultMCPStreamableHTTPServerFactoryParams struct { + fx.In + Config *config.Config +} + +// ProvideDefaultMCPStreamableHTTPServerFactory provides the default sse.MCPStreamableHTTPServerFactory instance. +func ProvideDefaultMCPStreamableHTTPServerFactory(p ProvideDefaultMCPStreamableHTTPServerFactoryParams) *stream.DefaultMCPStreamableHTTPServerFactory { + return stream.NewDefaultMCPStreamableHTTPServerFactory(p.Config) +} + +// ProvideMCPStreamableHTTPServerParam allows injection of the required dependencies in ProvideMCPStreamableHTTPServer. +type ProvideMCPStreamableHTTPServerParam struct { + fx.In + LifeCycle fx.Lifecycle + Logger *log.Logger + Config *config.Config + MCPServer *server.MCPServer + MCPStreamableHTTPServerFactory stream.MCPStreamableHTTPServerFactory + MCPStreamableHTTPServerContextHandler stream.MCPStreamableHTTPServerContextHandler +} + +// ProvideMCPStreamableHTTPServer provides the stream.MCPStreamableHTTPServer. +func ProvideMCPStreamableHTTPServer(p ProvideMCPStreamableHTTPServerParam) *stream.MCPStreamableHTTPServer { + streamableHTTPServer := p.MCPStreamableHTTPServerFactory.Create( + p.MCPServer, + server.WithHTTPContextFunc(p.MCPStreamableHTTPServerContextHandler.Handle()), + ) + + streamableHTTPServerCtx := p.Logger.WithContext(context.Background()) + + if p.Config.GetBool("modules.mcp.server.transport.stream.expose") { + p.LifeCycle.Append(fx.Hook{ + OnStart: func(context.Context) error { + if !p.Config.IsTestEnv() { + //nolint:errcheck + go streamableHTTPServer.Start(streamableHTTPServerCtx) + } + + return nil + }, + OnStop: func(ctx context.Context) error { + if !p.Config.IsTestEnv() { + return streamableHTTPServer.Stop(ctx) + } + + return nil + }, + }) + } + + return streamableHTTPServer +} + +// ProvideMCPStreamableHTTPTestServerParam allows injection of the required dependencies in ProvideMCPStreamableHTTPTestServer. +type ProvideMCPStreamableHTTPTestServerParam struct { + fx.In + Config *config.Config + MCPServer *server.MCPServer + MCPStreamableHTTPServerContextHandler stream.MCPStreamableHTTPServerContextHandler +} + +// ProvideMCPStreamableHTTPTestServer provides the fxmcpservertest.MCPStreamableHTTPTestServer. +func ProvideMCPStreamableHTTPTestServer(p ProvideMCPStreamableHTTPTestServerParam) *fxmcpservertest.MCPStreamableHTTPTestServer { + return fxmcpservertest.NewMCPStreamableHTTPTestServer(p.Config, p.MCPServer, p.MCPStreamableHTTPServerContextHandler) +} + // ProvideDefaultMCPSSEContextHandlerParam allows injection of the required dependencies in ProvideDefaultMCPSSEServerContextHandler. type ProvideDefaultMCPSSEContextHandlerParam struct { fx.In diff --git a/fxmcpserver/module_test.go b/fxmcpserver/module_test.go index 1223ceec..6e1b09ca 100644 --- a/fxmcpserver/module_test.go +++ b/fxmcpserver/module_test.go @@ -2,6 +2,7 @@ package fxmcpserver_test import ( "context" + "github.com/mark3labs/mcp-go/client" "strings" "testing" @@ -36,7 +37,9 @@ func TestMCPServerModule(t *testing.T) { t.Setenv("APP_ENV", "test") t.Setenv("APP_CONFIG_PATH", "testdata/config") - var testServer *fxmcpservertest.MCPSSETestServer + var testMCPStreamableHTTPServer *fxmcpservertest.MCPStreamableHTTPTestServer + var testMCPSSEServer *fxmcpservertest.MCPSSETestServer + var provider fs.MCPServerHooksProvider var checker *healthcheck.Checker var logBuffer logtest.TestLogBuffer var traceExporter tracetest.TestTraceExporter @@ -58,310 +61,353 @@ func TestMCPServerModule(t *testing.T) { fxmcpserver.AsMCPServerResources(resource.NewSimpleTestResource), fxmcpserver.AsMCPServerResourceTemplates(resourcetemplate.NewSimpleTestResourceTemplate), fxmcpserver.AsMCPSSEServerContextHooks(hook.NewSimpleMCPSSEServerContextHook), + fxmcpserver.AsMCPStreamableHTTPServerContextHooks(hook.NewSimpleMCPStreamableHTTPServerContextHook), fxhealthcheck.AsCheckerProbe(fs.NewMCPServerProbe), ), fx.Supply(fx.Annotate(context.Background(), fx.As(new(context.Context)))), - fx.Populate(&testServer, &checker, &logBuffer, &traceExporter, &metricsRegistry), + fx.Populate( + &testMCPStreamableHTTPServer, + &testMCPSSEServer, + &provider, + &checker, + &logBuffer, + &traceExporter, + &metricsRegistry, + ), ).RequireStart().RequireStop() - defer testServer.Close() + // ensure test servers closure + defer func() { + testMCPStreamableHTTPServer.Close() + testMCPSSEServer.Close() + }() ctx := context.Background() // health check checkResult := checker.Check(context.Background(), healthcheck.Readiness) assert.False(t, checkResult.Success) - assert.Equal(t, "MCP SSE server is not running", checkResult.ProbesResults["mcpserver"].Message) - - // start test client - testClient, err := testServer.StartClient(ctx) - assert.NoError(t, err) - - defer testClient.Close() - - // send success tools/call request - expectedRequest := `{"method":"tools/call","params":{"name":"advanced-test-tool","arguments":{"shouldFail":"false"}}}` - expectedResponse := `{"content":[{"type":"text","text":"test"}]}` - - callToolRequest := mcp.CallToolRequest{} - callToolRequest.Params.Name = "advanced-test-tool" - callToolRequest.Params.Arguments = map[string]interface{}{ - "shouldFail": "false", - } - - callToolResult, err := testClient.CallTool(ctx, callToolRequest) - assert.NoError(t, err) - assert.False(t, callToolResult.IsError) - - logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ - "level": "info", - "mcpMethod": "tools/call", - "mcpTool": "advanced-test-tool", - "mcpRequest": expectedRequest, - "mcpResponse": expectedResponse, - "mcpTransport": "sse", - "message": "MCP request success", - }) - - tracetest.AssertHasTraceSpan( - t, - traceExporter, - "MCP tools/call advanced-test-tool", - attribute.String("mcp.method", "tools/call"), - attribute.String("mcp.tool", "advanced-test-tool"), - attribute.String("mcp.request", expectedRequest), - attribute.String("mcp.response", expectedResponse), - attribute.String("mcp.transport", "sse"), - ) - - expectedMetric := ` - # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests - # TYPE foo_bar_mcp_server_requests_total counter - foo_bar_mcp_server_requests_total{method="initialize",status="success",target=""} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 - ` - err = testutil.GatherAndCompare( - metricsRegistry, - strings.NewReader(expectedMetric), - "foo_bar_mcp_server_requests_total", - ) - assert.NoError(t, err) - - // send error tools/call request - expectedRequest = `{"method":"tools/call","params":{"name":"advanced-test-tool","arguments":{"shouldFail":"true"}}}` - - callToolRequest = mcp.CallToolRequest{} - callToolRequest.Params.Name = "advanced-test-tool" - callToolRequest.Params.Arguments = map[string]interface{}{ - "shouldFail": "true", - } - - _, err = testClient.CallTool(ctx, callToolRequest) - assert.Error(t, err) - assert.Equal(t, "advanced tool test failure", err.Error()) - - logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ - "level": "error", - "mcpError": "request error: advanced tool test failure", - "mcpMethod": "tools/call", - "mcpTool": "advanced-test-tool", - "mcpRequest": expectedRequest, - "mcpTransport": "sse", - "message": "MCP request error", - }) - - tracetest.AssertHasTraceSpan( - t, - traceExporter, - "MCP tools/call advanced-test-tool", - attribute.String("mcp.method", "tools/call"), - attribute.String("mcp.tool", "advanced-test-tool"), - attribute.String("mcp.request", expectedRequest), - attribute.String("mcp.transport", "sse"), - ) - - expectedMetric = ` - # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests - # TYPE foo_bar_mcp_server_requests_total counter - foo_bar_mcp_server_requests_total{method="initialize",status="success",target=""} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 - ` - err = testutil.GatherAndCompare( - metricsRegistry, - strings.NewReader(expectedMetric), - "foo_bar_mcp_server_requests_total", - ) - assert.NoError(t, err) - - // send success prompts/get request - expectedRequest = `{"method":"prompts/get","params":{"name":"simple-test-prompt"}}` - expectedResponse = `{"description":"ok","messages":[{"role":"assistant","content":{"type":"text","text":"context hook value: bar"}}]}` - - getPromptRequest := mcp.GetPromptRequest{} - getPromptRequest.Params.Name = "simple-test-prompt" - - getPromptResult, err := testClient.GetPrompt(ctx, getPromptRequest) - assert.NoError(t, err) - assert.Equal(t, mcp.RoleAssistant, getPromptResult.Messages[0].Role) - assert.Equal(t, "context hook value: bar", getPromptResult.Messages[0].Content.(mcp.TextContent).Text) - - logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ - "level": "info", - "mcpMethod": "prompts/get", - "mcpPrompt": "simple-test-prompt", - "mcpRequest": expectedRequest, - "mcpResponse": expectedResponse, - "mcpTransport": "sse", - "message": "MCP request success", - }) - - tracetest.AssertHasTraceSpan( - t, - traceExporter, - "MCP prompts/get simple-test-prompt", - attribute.String("mcp.method", "prompts/get"), - attribute.String("mcp.prompt", "simple-test-prompt"), - attribute.String("mcp.request", expectedRequest), - attribute.String("mcp.response", expectedResponse), - attribute.String("mcp.transport", "sse"), - ) - - expectedMetric = ` - # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests - # TYPE foo_bar_mcp_server_requests_total counter - foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 - foo_bar_mcp_server_requests_total{method="initialize",status="success",target=""} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 - ` - err = testutil.GatherAndCompare( - metricsRegistry, - strings.NewReader(expectedMetric), - "foo_bar_mcp_server_requests_total", - ) - assert.NoError(t, err) - - // send error prompts/get request - expectedRequest = `{"method":"prompts/get","params":{"name":"invalid-test-prompt"}}` - - getPromptRequest = mcp.GetPromptRequest{} - getPromptRequest.Params.Name = "invalid-test-prompt" - - _, err = testClient.GetPrompt(ctx, getPromptRequest) - assert.Error(t, err) - assert.Equal(t, "prompt 'invalid-test-prompt' not found: prompt not found", err.Error()) - - logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ - "level": "error", - "mcpError": "request error: prompt 'invalid-test-prompt' not found: prompt not found", - "mcpMethod": "prompts/get", - "mcpPrompt": "invalid-test-prompt", - "mcpRequest": expectedRequest, - "mcpTransport": "sse", - "message": "MCP request error", - }) - - tracetest.AssertHasTraceSpan( + assert.Equal( t, - traceExporter, - "MCP prompts/get invalid-test-prompt", - attribute.String("mcp.method", "prompts/get"), - attribute.String("mcp.prompt", "invalid-test-prompt"), - attribute.String("mcp.request", expectedRequest), - attribute.String("mcp.transport", "sse"), - ) - - expectedMetric = ` - # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests - # TYPE foo_bar_mcp_server_requests_total counter - foo_bar_mcp_server_requests_total{method="prompts/get",status="error",target="invalid-test-prompt"} 1 - foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 - foo_bar_mcp_server_requests_total{method="initialize",status="success",target=""} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 - ` - err = testutil.GatherAndCompare( - metricsRegistry, - strings.NewReader(expectedMetric), - "foo_bar_mcp_server_requests_total", + "MCP StreamableHTTP server is not running, MCP SSE server is not running", + checkResult.ProbesResults["mcpserver"].Message, ) - assert.NoError(t, err) - - // send success resources/get request - expectedRequest = `{"method":"resources/read","params":{"uri":"simple-test://resources"}}` - expectedResponse = `{"contents":[{"uri":"simple-test://resources","mimeType":"text/plain","text":"simple test resource"}]}` - - readResourceRequest := mcp.ReadResourceRequest{} - readResourceRequest.Params.URI = "simple-test://resources" - readResourceResult, err := testClient.ReadResource(ctx, readResourceRequest) + // start test clients + testMCPStreamableHTTPClient, err := testMCPStreamableHTTPServer.StartClient(ctx) assert.NoError(t, err) - assert.Equal(t, "simple test resource", readResourceResult.Contents[0].(mcp.TextResourceContents).Text) - - logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ - "level": "info", - "mcpMethod": "resources/read", - "mcpResourceURI": "simple-test://resources", - "mcpRequest": expectedRequest, - "mcpResponse": expectedResponse, - "mcpTransport": "sse", - "message": "MCP request success", - }) - - tracetest.AssertHasTraceSpan( - t, - traceExporter, - "MCP resources/read simple-test://resources", - attribute.String("mcp.method", "resources/read"), - attribute.String("mcp.resourceURI", "simple-test://resources"), - attribute.String("mcp.request", expectedRequest), - attribute.String("mcp.response", expectedResponse), - attribute.String("mcp.transport", "sse"), - ) - expectedMetric = ` - # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests - # TYPE foo_bar_mcp_server_requests_total counter - foo_bar_mcp_server_requests_total{method="prompts/get",status="error",target="invalid-test-prompt"} 1 - foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 - foo_bar_mcp_server_requests_total{method="initialize",status="success",target=""} 1 - foo_bar_mcp_server_requests_total{method="resources/read",status="success",target="simple-test://resources"} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 - ` - err = testutil.GatherAndCompare( - metricsRegistry, - strings.NewReader(expectedMetric), - "foo_bar_mcp_server_requests_total", - ) + testMCPSSEClient, err := testMCPSSEServer.StartClient(ctx) assert.NoError(t, err) + defer testMCPSSEClient.Close() + + // hooks provider + defaultProvider, ok := provider.(*fs.DefaultMCPServerHooksProvider) + assert.True(t, ok) + + tests := []struct { + name string + client *client.Client + transport string + }{ + { + "with StreamableHTTP transport", + testMCPStreamableHTTPClient, + "streamable-http", + }, + { + "with SSE transport", + testMCPSSEClient, + "sse", + }, + } - // send error resources/get request - expectedRequest = `{"method":"resources/read","params":{"uri":"simple-test://invalid"}}` - - readResourceRequest = mcp.ReadResourceRequest{} - readResourceRequest.Params.URI = "simple-test://invalid" - - _, err = testClient.ReadResource(ctx, readResourceRequest) - assert.Error(t, err) - assert.Equal(t, "handler not found for resource URI 'simple-test://invalid': resource not found", err.Error()) - - logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ - "level": "error", - "mcpError": "request error: handler not found for resource URI 'simple-test://invalid': resource not found", - "mcpMethod": "resources/read", - "mcpResourceURI": "simple-test://invalid", - "mcpRequest": expectedRequest, - "mcpTransport": "sse", - "message": "MCP request error", - }) - - tracetest.AssertHasTraceSpan( - t, - traceExporter, - "MCP resources/read simple-test://invalid", - attribute.String("mcp.method", "resources/read"), - attribute.String("mcp.resourceURI", "simple-test://invalid"), - attribute.String("mcp.request", expectedRequest), - attribute.String("mcp.transport", "sse"), - ) - - expectedMetric = ` - # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests - # TYPE foo_bar_mcp_server_requests_total counter - foo_bar_mcp_server_requests_total{method="prompts/get",status="error",target="invalid-test-prompt"} 1 - foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 - foo_bar_mcp_server_requests_total{method="initialize",status="success",target=""} 1 - foo_bar_mcp_server_requests_total{method="resources/read",status="error",target="simple-test://invalid"} 1 - foo_bar_mcp_server_requests_total{method="resources/read",status="success",target="simple-test://resources"} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 - foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 - ` - err = testutil.GatherAndCompare( - metricsRegistry, - strings.NewReader(expectedMetric), - "foo_bar_mcp_server_requests_total", - ) - assert.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // reset o11y + logBuffer.Reset() + traceExporter.Reset() + defaultProvider.Reset() + + // send success tools/call request + expectedRequest := `{"method":"tools/call","params":{"name":"advanced-test-tool","arguments":{"shouldFail":"false"}}}` + expectedResponse := `{"content":[{"type":"text","text":"test"}]}` + + callToolRequest := mcp.CallToolRequest{} + callToolRequest.Params.Name = "advanced-test-tool" + callToolRequest.Params.Arguments = map[string]interface{}{ + "shouldFail": "false", + } + + callToolResult, err := tt.client.CallTool(ctx, callToolRequest) + assert.NoError(t, err) + assert.False(t, callToolResult.IsError) + + logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ + "level": "info", + "mcpMethod": "tools/call", + "mcpTool": "advanced-test-tool", + "mcpRequest": expectedRequest, + "mcpResponse": expectedResponse, + "mcpTransport": tt.transport, + "message": "MCP request success", + }) + + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP tools/call advanced-test-tool", + attribute.String("mcp.method", "tools/call"), + attribute.String("mcp.tool", "advanced-test-tool"), + attribute.String("mcp.request", expectedRequest), + attribute.String("mcp.response", expectedResponse), + attribute.String("mcp.transport", tt.transport), + ) + + expectedMetric := ` + # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests + # TYPE foo_bar_mcp_server_requests_total counter + foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 + ` + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "foo_bar_mcp_server_requests_total", + ) + assert.NoError(t, err) + + // send error tools/call request + expectedRequest = `{"method":"tools/call","params":{"name":"advanced-test-tool","arguments":{"shouldFail":"true"}}}` + + callToolRequest = mcp.CallToolRequest{} + callToolRequest.Params.Name = "advanced-test-tool" + callToolRequest.Params.Arguments = map[string]interface{}{ + "shouldFail": "true", + } + + _, err = tt.client.CallTool(ctx, callToolRequest) + assert.Error(t, err) + assert.Equal(t, "advanced tool test failure", err.Error()) + + logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ + "level": "error", + "mcpError": "request error: advanced tool test failure", + "mcpMethod": "tools/call", + "mcpTool": "advanced-test-tool", + "mcpRequest": expectedRequest, + "mcpTransport": tt.transport, + "message": "MCP request error", + }) + + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP tools/call advanced-test-tool", + attribute.String("mcp.method", "tools/call"), + attribute.String("mcp.tool", "advanced-test-tool"), + attribute.String("mcp.request", expectedRequest), + attribute.String("mcp.transport", tt.transport), + ) + + expectedMetric = ` + # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests + # TYPE foo_bar_mcp_server_requests_total counter + foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 + ` + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "foo_bar_mcp_server_requests_total", + ) + assert.NoError(t, err) + + // send success prompts/get request + expectedRequest = `{"method":"prompts/get","params":{"name":"simple-test-prompt"}}` + expectedResponse = `{"description":"ok","messages":[{"role":"assistant","content":{"type":"text","text":"context hook value: bar"}}]}` + + getPromptRequest := mcp.GetPromptRequest{} + getPromptRequest.Params.Name = "simple-test-prompt" + + getPromptResult, err := tt.client.GetPrompt(ctx, getPromptRequest) + assert.NoError(t, err) + assert.Equal(t, mcp.RoleAssistant, getPromptResult.Messages[0].Role) + assert.Equal(t, "context hook value: bar", getPromptResult.Messages[0].Content.(mcp.TextContent).Text) + + logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ + "level": "info", + "mcpMethod": "prompts/get", + "mcpPrompt": "simple-test-prompt", + "mcpRequest": expectedRequest, + "mcpResponse": expectedResponse, + "mcpTransport": tt.transport, + "message": "MCP request success", + }) + + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP prompts/get simple-test-prompt", + attribute.String("mcp.method", "prompts/get"), + attribute.String("mcp.prompt", "simple-test-prompt"), + attribute.String("mcp.request", expectedRequest), + attribute.String("mcp.response", expectedResponse), + attribute.String("mcp.transport", tt.transport), + ) + + expectedMetric = ` + # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests + # TYPE foo_bar_mcp_server_requests_total counter + foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 + ` + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "foo_bar_mcp_server_requests_total", + ) + assert.NoError(t, err) + + // send error prompts/get request + expectedRequest = `{"method":"prompts/get","params":{"name":"invalid-test-prompt"}}` + + getPromptRequest = mcp.GetPromptRequest{} + getPromptRequest.Params.Name = "invalid-test-prompt" + + _, err = tt.client.GetPrompt(ctx, getPromptRequest) + assert.Error(t, err) + assert.Equal(t, "prompt 'invalid-test-prompt' not found: prompt not found", err.Error()) + + logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ + "level": "error", + "mcpError": "request error: prompt 'invalid-test-prompt' not found: prompt not found", + "mcpMethod": "prompts/get", + "mcpPrompt": "invalid-test-prompt", + "mcpRequest": expectedRequest, + "mcpTransport": tt.transport, + "message": "MCP request error", + }) + + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP prompts/get invalid-test-prompt", + attribute.String("mcp.method", "prompts/get"), + attribute.String("mcp.prompt", "invalid-test-prompt"), + attribute.String("mcp.request", expectedRequest), + attribute.String("mcp.transport", tt.transport), + ) + + expectedMetric = ` + # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests + # TYPE foo_bar_mcp_server_requests_total counter + foo_bar_mcp_server_requests_total{method="prompts/get",status="error",target="invalid-test-prompt"} 1 + foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 + ` + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "foo_bar_mcp_server_requests_total", + ) + assert.NoError(t, err) + + // send success resources/get request + expectedRequest = `{"method":"resources/read","params":{"uri":"simple-test://resources"}}` + expectedResponse = `{"contents":[{"uri":"simple-test://resources","mimeType":"text/plain","text":"simple test resource"}]}` + + readResourceRequest := mcp.ReadResourceRequest{} + readResourceRequest.Params.URI = "simple-test://resources" + + readResourceResult, err := tt.client.ReadResource(ctx, readResourceRequest) + assert.NoError(t, err) + assert.Equal(t, "simple test resource", readResourceResult.Contents[0].(mcp.TextResourceContents).Text) + + logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ + "level": "info", + "mcpMethod": "resources/read", + "mcpResourceURI": "simple-test://resources", + "mcpRequest": expectedRequest, + "mcpResponse": expectedResponse, + "mcpTransport": tt.transport, + "message": "MCP request success", + }) + + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP resources/read simple-test://resources", + attribute.String("mcp.method", "resources/read"), + attribute.String("mcp.resourceURI", "simple-test://resources"), + attribute.String("mcp.request", expectedRequest), + attribute.String("mcp.response", expectedResponse), + attribute.String("mcp.transport", tt.transport), + ) + + expectedMetric = ` + # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests + # TYPE foo_bar_mcp_server_requests_total counter + foo_bar_mcp_server_requests_total{method="prompts/get",status="error",target="invalid-test-prompt"} 1 + foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 + foo_bar_mcp_server_requests_total{method="resources/read",status="success",target="simple-test://resources"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 + ` + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "foo_bar_mcp_server_requests_total", + ) + assert.NoError(t, err) + + // send error resources/get request + expectedRequest = `{"method":"resources/read","params":{"uri":"simple-test://invalid"}}` + + readResourceRequest = mcp.ReadResourceRequest{} + readResourceRequest.Params.URI = "simple-test://invalid" + + _, err = tt.client.ReadResource(ctx, readResourceRequest) + assert.Error(t, err) + assert.Equal(t, "handler not found for resource URI 'simple-test://invalid': resource not found", err.Error()) + + logtest.AssertHasLogRecord(t, logBuffer, map[string]any{ + "level": "error", + "mcpError": "request error: handler not found for resource URI 'simple-test://invalid': resource not found", + "mcpMethod": "resources/read", + "mcpResourceURI": "simple-test://invalid", + "mcpRequest": expectedRequest, + "mcpTransport": tt.transport, + "message": "MCP request error", + }) + + tracetest.AssertHasTraceSpan( + t, + traceExporter, + "MCP resources/read simple-test://invalid", + attribute.String("mcp.method", "resources/read"), + attribute.String("mcp.resourceURI", "simple-test://invalid"), + attribute.String("mcp.request", expectedRequest), + attribute.String("mcp.transport", tt.transport), + ) + + expectedMetric = ` + # HELP foo_bar_mcp_server_requests_total Number of processed MCP requests + # TYPE foo_bar_mcp_server_requests_total counter + foo_bar_mcp_server_requests_total{method="prompts/get",status="error",target="invalid-test-prompt"} 1 + foo_bar_mcp_server_requests_total{method="prompts/get",status="success",target="simple-test-prompt"} 1 + foo_bar_mcp_server_requests_total{method="resources/read",status="error",target="simple-test://invalid"} 1 + foo_bar_mcp_server_requests_total{method="resources/read",status="success",target="simple-test://resources"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="success",target="advanced-test-tool"} 1 + foo_bar_mcp_server_requests_total{method="tools/call",status="error",target="advanced-test-tool"} 1 + ` + err = testutil.GatherAndCompare( + metricsRegistry, + strings.NewReader(expectedMetric), + "foo_bar_mcp_server_requests_total", + ) + assert.NoError(t, err) + }) + } } diff --git a/fxmcpserver/register.go b/fxmcpserver/register.go index d7def683..c508b1bb 100644 --- a/fxmcpserver/register.go +++ b/fxmcpserver/register.go @@ -3,6 +3,7 @@ package fxmcpserver import ( "github.com/ankorstore/yokai/fxmcpserver/server" "github.com/ankorstore/yokai/fxmcpserver/server/sse" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" "go.uber.org/fx" ) @@ -105,7 +106,7 @@ func AsMCPSSEServerContextHook(constructor any) fx.Option { ) } -// AsMCPSSEServerContextHooks registers several MCP SSE server context hook. +// AsMCPSSEServerContextHooks registers several MCP SSE server context hooks. func AsMCPSSEServerContextHooks(constructors ...any) fx.Option { options := []fx.Option{} @@ -115,3 +116,25 @@ func AsMCPSSEServerContextHooks(constructors ...any) fx.Option { return fx.Options(options...) } + +// AsMCPStreamableHTTPServerContextHook registers an MCP StreamableHTTP server context hook. +func AsMCPStreamableHTTPServerContextHook(constructor any) fx.Option { + return fx.Provide( + fx.Annotate( + constructor, + fx.As(new(stream.MCPStreamableHTTPServerContextHook)), + fx.ResultTags(`group:"mcp-streamable-http-server-context-hooks"`), + ), + ) +} + +// AsMCPStreamableHTTPServerContextHooks registers several MCP StreamableHTTP server context hooks. +func AsMCPStreamableHTTPServerContextHooks(constructors ...any) fx.Option { + options := []fx.Option{} + + for _, constructor := range constructors { + options = append(options, AsMCPStreamableHTTPServerContextHook(constructor)) + } + + return fx.Options(options...) +} diff --git a/fxmcpserver/register_test.go b/fxmcpserver/register_test.go index eae65450..1062d270 100644 --- a/fxmcpserver/register_test.go +++ b/fxmcpserver/register_test.go @@ -103,3 +103,21 @@ func TestAsMCPSSEServerContextHooks(t *testing.T) { assert.Equal(t, "fx.optionGroup", fmt.Sprintf("%T", reg)) assert.Implements(t, (*fx.Option)(nil), reg) } + +func TestAsMCPStreamableHTTPServerContextHook(t *testing.T) { + t.Parallel() + + reg := fxmcpserver.AsMCPStreamableHTTPServerContextHook(hook.NewSimpleMCPStreamableHTTPServerContextHook) + + assert.Equal(t, "fx.provideOption", fmt.Sprintf("%T", reg)) + assert.Implements(t, (*fx.Option)(nil), reg) +} + +func TestAsMCPStreamableHTTPServerContextHooks(t *testing.T) { + t.Parallel() + + reg := fxmcpserver.AsMCPStreamableHTTPServerContextHooks(hook.NewSimpleMCPStreamableHTTPServerContextHook) + + assert.Equal(t, "fx.optionGroup", fmt.Sprintf("%T", reg)) + assert.Implements(t, (*fx.Option)(nil), reg) +} diff --git a/fxmcpserver/server/healthcheck.go b/fxmcpserver/server/healthcheck.go index 8bb49f7b..1386ecd7 100644 --- a/fxmcpserver/server/healthcheck.go +++ b/fxmcpserver/server/healthcheck.go @@ -2,6 +2,7 @@ package server import ( "context" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" "strings" "github.com/ankorstore/yokai/config" @@ -12,21 +13,24 @@ import ( // MCPServerProbe is a probe compatible with the healthcheck module. type MCPServerProbe struct { - config *config.Config - sseServer *sse.MCPSSEServer - stdioServer *stdio.MCPStdioServer + config *config.Config + steamableHTTPServer *stream.MCPStreamableHTTPServer + sseServer *sse.MCPSSEServer + stdioServer *stdio.MCPStdioServer } // NewMCPServerProbe returns a new MCPServerProbe. func NewMCPServerProbe( config *config.Config, + steamableHTTPServer *stream.MCPStreamableHTTPServer, sseServer *sse.MCPSSEServer, stdioServer *stdio.MCPStdioServer, ) *MCPServerProbe { return &MCPServerProbe{ - config: config, - sseServer: sseServer, - stdioServer: stdioServer, + config: config, + steamableHTTPServer: steamableHTTPServer, + sseServer: sseServer, + stdioServer: stdioServer, } } @@ -40,6 +44,15 @@ func (p *MCPServerProbe) Check(context.Context) *healthcheck.CheckerProbeResult success := true var messages []string + if p.config.GetBool("modules.mcp.server.transport.stream.expose") { + if p.steamableHTTPServer.Running() { + messages = append(messages, "MCP StreamableHTTP server is running") + } else { + success = false + messages = append(messages, "MCP StreamableHTTP server is not running") + } + } + if p.config.GetBool("modules.mcp.server.transport.sse.expose") { if p.sseServer.Running() { messages = append(messages, "MCP SSE server is running") diff --git a/fxmcpserver/server/healthcheck_test.go b/fxmcpserver/server/healthcheck_test.go index ed32c31f..e34965e3 100644 --- a/fxmcpserver/server/healthcheck_test.go +++ b/fxmcpserver/server/healthcheck_test.go @@ -2,6 +2,7 @@ package server_test import ( "context" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" "testing" "github.com/ankorstore/yokai/config" @@ -22,13 +23,14 @@ func TestMCPServerProbe(t *testing.T) { mcpSrv := server.NewMCPServer("test-server", "1.0.0") + streamSrv := stream.NewDefaultMCPStreamableHTTPServerFactory(cfg).Create(mcpSrv) sseSrv := sse.NewDefaultMCPSSEServerFactory(cfg).Create(mcpSrv) stdioSrv := stdio.NewDefaultMCPStdioServerFactory().Create(mcpSrv) - probe := fs.NewMCPServerProbe(cfg, sseSrv, stdioSrv) + probe := fs.NewMCPServerProbe(cfg, streamSrv, sseSrv, stdioSrv) res := probe.Check(context.Background()) assert.False(t, res.Success) - assert.Equal(t, "MCP SSE server is not running", res.Message) + assert.Equal(t, "MCP StreamableHTTP server is not running, MCP SSE server is not running", res.Message) } diff --git a/fxmcpserver/server/provider.go b/fxmcpserver/server/provider.go index 289feb3b..e9f8e23e 100644 --- a/fxmcpserver/server/provider.go +++ b/fxmcpserver/server/provider.go @@ -262,3 +262,9 @@ func (p *DefaultMCPServerHooksProvider) Provide() *server.Hooks { return hooks } + +// Reset resets the MCP requests metrics. +func (p *DefaultMCPServerHooksProvider) Reset() { + p.requestsCounter.Reset() + p.requestsDuration.Reset() +} diff --git a/fxmcpserver/server/sse/context_test.go b/fxmcpserver/server/sse/context_test.go index 34bada89..6cbf7ed2 100644 --- a/fxmcpserver/server/sse/context_test.go +++ b/fxmcpserver/server/sse/context_test.go @@ -86,7 +86,7 @@ func TestDefaultMCPSSEServerContextHandler_Handle(t *testing.T) { gm.AssertExpectations(t) }) - t.Run("with provided session id and request id and middleware", func(t *testing.T) { + t.Run("with provided session id and request id and hook", func(t *testing.T) { t.Parallel() gm := new(generatorMock) diff --git a/fxmcpserver/server/sse/factory.go b/fxmcpserver/server/sse/factory.go index beccbac2..98fd015c 100644 --- a/fxmcpserver/server/sse/factory.go +++ b/fxmcpserver/server/sse/factory.go @@ -82,7 +82,7 @@ func (f *DefaultMCPSSEServerFactory) Create(mcpServer *server.MCPServer, options srvOptions := []server.SSEOption{ server.WithBaseURL(srvConfig.BaseURL), - server.WithBasePath(srvConfig.BasePath), + server.WithStaticBasePath(srvConfig.BasePath), server.WithSSEEndpoint(srvConfig.SSEEndpoint), server.WithMessageEndpoint(srvConfig.MessageEndpoint), } diff --git a/fxmcpserver/server/sse/server.go b/fxmcpserver/server/sse/server.go index c419d419..61e033ce 100644 --- a/fxmcpserver/server/sse/server.go +++ b/fxmcpserver/server/sse/server.go @@ -60,7 +60,9 @@ func (s *MCPSSEServer) Start(ctx context.Context) error { if err != nil { logger.Error().Err(err).Msgf("failed to start MCP SSE server") + s.mutex.Lock() s.running = false + s.mutex.Unlock() } return err @@ -105,7 +107,7 @@ func (s *MCPSSEServer) Info() map[string]any { "keep_alive_interval": s.config.KeepAliveInterval.Seconds(), }, "status": map[string]any{ - "running": s.running, + "running": s.Running(), }, } } diff --git a/fxmcpserver/server/sse/server_test.go b/fxmcpserver/server/sse/server_test.go index 66473afd..64ffcb34 100644 --- a/fxmcpserver/server/sse/server_test.go +++ b/fxmcpserver/server/sse/server_test.go @@ -52,10 +52,8 @@ func TestMCPSSEServer(t *testing.T) { ctx := lg.WithContext(context.Background()) - go func(fCtx context.Context) { - fErr := srv.Start(fCtx) - assert.NoError(t, fErr) - }(ctx) + //nolint:errcheck + go srv.Start(ctx) time.Sleep(1 * time.Millisecond) @@ -66,10 +64,8 @@ func TestMCPSSEServer(t *testing.T) { "message": "starting MCP SSE server on :0", }) - go func(fCtx context.Context) { - fErr := srv.Stop(fCtx) - assert.NoError(t, fErr) - }(ctx) + err = srv.Stop(ctx) + assert.NoError(t, err) time.Sleep(1 * time.Millisecond) diff --git a/fxmcpserver/server/stream/context.go b/fxmcpserver/server/stream/context.go new file mode 100644 index 00000000..9b0c0730 --- /dev/null +++ b/fxmcpserver/server/stream/context.go @@ -0,0 +1,110 @@ +package stream + +import ( + "context" + "net/http" + "time" + + fsc "github.com/ankorstore/yokai/fxmcpserver/server/context" + "github.com/ankorstore/yokai/generate/uuid" + "github.com/ankorstore/yokai/log" + "github.com/ankorstore/yokai/trace" + "github.com/mark3labs/mcp-go/server" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/propagation" + ot "go.opentelemetry.io/otel/trace" +) + +var _ MCPStreamableHTTPServerContextHandler = (*DefaultMCPStreamableHTTPServerContextHandler)(nil) + +// MCPStreamableHTTPServerContextHook is the interface for MCP StreamableHTTP server context hooks. +type MCPStreamableHTTPServerContextHook interface { + Handle() server.HTTPContextFunc +} + +// MCPStreamableHTTPServerContextHandler is the interface for MCP StreamableHTTP server context handlers. +type MCPStreamableHTTPServerContextHandler interface { + Handle() server.HTTPContextFunc +} + +// DefaultMCPStreamableHTTPServerContextHandler is the default MCPStreamableHTTPServerContextHandler implementation. +type DefaultMCPStreamableHTTPServerContextHandler struct { + generator uuid.UuidGenerator + tracerProvider ot.TracerProvider + textMapPropagator propagation.TextMapPropagator + logger *log.Logger + contextHooks []MCPStreamableHTTPServerContextHook +} + +// NewDefaultMCPStreamableHTTPServerContextHandler returns a new DefaultMCPStreamableHTTPServerContextHandler instance. +func NewDefaultMCPStreamableHTTPServerContextHandler( + generator uuid.UuidGenerator, + tracerProvider ot.TracerProvider, + textMapPropagator propagation.TextMapPropagator, + logger *log.Logger, + contextHooks ...MCPStreamableHTTPServerContextHook, +) *DefaultMCPStreamableHTTPServerContextHandler { + return &DefaultMCPStreamableHTTPServerContextHandler{ + generator: generator, + tracerProvider: tracerProvider, + textMapPropagator: textMapPropagator, + logger: logger, + contextHooks: contextHooks, + } +} + +// Handle returns the handler func. +func (h *DefaultMCPStreamableHTTPServerContextHandler) Handle() server.HTTPContextFunc { + return func(ctx context.Context, req *http.Request) context.Context { + // start time propagation + ctx = fsc.WithStartTime(ctx, time.Now()) + + // requestId propagation + rID := req.Header.Get("X-Request-Id") + + if rID == "" { + rID = h.generator.Generate() + req.Header.Set("X-Request-Id", rID) + } + + ctx = fsc.WithRequestID(ctx, rID) + + // tracer propagation + ctx = h.textMapPropagator.Extract(ctx, propagation.HeaderCarrier(req.Header)) + + ctx = trace.WithContext(ctx, h.tracerProvider) + + ctx, span := trace.CtxTracer(ctx).Start( + ctx, + "MCP", + ot.WithSpanKind(ot.SpanKindServer), + ot.WithAttributes( + attribute.String("system", "mcpserver"), + attribute.String("mcp.transport", "streamable-http"), + attribute.String("mcp.requestID", rID), + ), + ) + + ctx = fsc.WithRootSpan(ctx, span) + + // logger propagation + logger := h.logger. + With(). + Str("system", "mcpserver"). + Str("mcpTransport", "streamable-http"). + Str("mcpRequestID", rID). + Logger() + + ctx = logger.WithContext(ctx) + + // cancellation removal propagation + ctx = context.WithoutCancel(ctx) + + // hooks propagation + for _, hook := range h.contextHooks { + ctx = hook.Handle()(ctx, req) + } + + return ctx + } +} diff --git a/fxmcpserver/server/stream/context_test.go b/fxmcpserver/server/stream/context_test.go new file mode 100644 index 00000000..9f4dcee2 --- /dev/null +++ b/fxmcpserver/server/stream/context_test.go @@ -0,0 +1,141 @@ +package stream_test + +import ( + "context" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" + "net/http" + "net/http/httptest" + "testing" + + servercontext "github.com/ankorstore/yokai/fxmcpserver/server/context" + "github.com/ankorstore/yokai/fxmcpserver/testdata/hook" + "github.com/ankorstore/yokai/log" + "github.com/ankorstore/yokai/log/logtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/trace" +) + +type generatorMock struct { + mock.Mock +} + +func (m *generatorMock) Generate() string { + return m.Called().String(0) +} + +func TestDefaultMCPStreamableHTTPServerContextHandler_Handle(t *testing.T) { + t.Parallel() + + t.Run("with defaults", func(t *testing.T) { + t.Parallel() + + gm := new(generatorMock) + gm.On("Generate").Return("test-request-id") + + tp := trace.NewTracerProvider() + + tmp := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + + lb := logtest.NewDefaultTestLogBuffer() + lg, err := log.NewDefaultLoggerFactory().Create(log.WithOutputWriter(lb)) + assert.NoError(t, err) + + handler := stream.NewDefaultMCPStreamableHTTPServerContextHandler(gm, tp, tmp, lg) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + + ctx := handler.Handle()(context.Background(), req) + + assert.Equal(t, "", servercontext.CtxSessionID(ctx)) + assert.Equal(t, "test-request-id", servercontext.CtxRequestId(ctx)) + + span, ok := servercontext.CtxRootSpan(ctx).(trace.ReadWriteSpan) + assert.True(t, ok) + + assert.Equal(t, "MCP", span.Name()) + + for _, attr := range span.Attributes() { + if attr.Key == "system" { + assert.Equal(t, "mcpserver", attr.Value.AsString()) + } + if attr.Key == "mcp.transport" { + assert.Equal(t, "streamable-http", attr.Value.AsString()) + } + if attr.Key == "mcp.requestID" { + assert.Equal(t, "test-request-id", attr.Value.AsString()) + } + } + + log.CtxLogger(ctx).Info().Msg("test log") + + logtest.AssertHasLogRecord(t, lb, map[string]any{ + "level": "info", + "system": "mcpserver", + "mcpTransport": "streamable-http", + "mcpRequestID": "test-request-id", + "message": "test log", + }) + + gm.AssertExpectations(t) + }) + + t.Run("with provided request id and hook", func(t *testing.T) { + t.Parallel() + + gm := new(generatorMock) + gm.AssertNotCalled(t, "Generate") + + tp := trace.NewTracerProvider() + + tmp := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + + lb := logtest.NewDefaultTestLogBuffer() + lg, err := log.NewDefaultLoggerFactory().Create(log.WithOutputWriter(lb)) + assert.NoError(t, err) + + hk := hook.NewSimpleMCPStreamableHTTPServerContextHook() + + handler := stream.NewDefaultMCPStreamableHTTPServerContextHandler(gm, tp, tmp, lg, hk) + + req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req.Header.Set("X-Request-Id", "test-request-id") + + ctx := handler.Handle()(context.Background(), req) + + assert.Equal(t, "test-request-id", servercontext.CtxRequestId(ctx)) + + span, ok := servercontext.CtxRootSpan(ctx).(trace.ReadWriteSpan) + assert.True(t, ok) + + assert.Equal(t, "MCP", span.Name()) + + for _, attr := range span.Attributes() { + if attr.Key == "system" { + assert.Equal(t, "mcpserver", attr.Value.AsString()) + } + if attr.Key == "mcp.transport" { + assert.Equal(t, "streamable-http", attr.Value.AsString()) + } + if attr.Key == "mcp.requestID" { + assert.Equal(t, "test-request-id", attr.Value.AsString()) + } + } + + log.CtxLogger(ctx).Info().Msg("test log") + + logtest.AssertHasLogRecord(t, lb, map[string]any{ + "level": "info", + "system": "mcpserver", + "mcpTransport": "streamable-http", + "mcpRequestID": "test-request-id", + "message": "test log", + }) + + //nolint:forcetypeassert + assert.Equal(t, "bar", ctx.Value("foo").(string)) + + gm.AssertExpectations(t) + }) +} diff --git a/fxmcpserver/server/stream/factory.go b/fxmcpserver/server/stream/factory.go new file mode 100644 index 00000000..72acd662 --- /dev/null +++ b/fxmcpserver/server/stream/factory.go @@ -0,0 +1,76 @@ +package stream + +import ( + "github.com/ankorstore/yokai/config" + "github.com/mark3labs/mcp-go/server" + "time" +) + +const ( + DefaultAddr = ":8083" + DefaultBasePath = "/mcp" + DefaultKeepAliveInterval = 10 * time.Second +) + +var _ MCPStreamableHTTPServerFactory = (*DefaultMCPStreamableHTTPServerFactory)(nil) + +// MCPStreamableHTTPServerFactory is the interface for MCP StreamableHTTP server factories. +type MCPStreamableHTTPServerFactory interface { + Create(mcpServer *server.MCPServer, options ...server.StreamableHTTPOption) *MCPStreamableHTTPServer +} + +// DefaultMCPStreamableHTTPServerFactory is the default MCPStreamableHTTPServerFactory implementation. +type DefaultMCPStreamableHTTPServerFactory struct { + config *config.Config +} + +// NewDefaultMCPStreamableHTTPServerFactory returns a new DefaultMCPStreamableHTTPServerFactory instance. +func NewDefaultMCPStreamableHTTPServerFactory(config *config.Config) *DefaultMCPStreamableHTTPServerFactory { + return &DefaultMCPStreamableHTTPServerFactory{ + config: config, + } +} + +// Create returns a new MCPStreamableHTTPServer instance. +func (f *DefaultMCPStreamableHTTPServerFactory) Create(mcpServer *server.MCPServer, options ...server.StreamableHTTPOption) *MCPStreamableHTTPServer { + addr := f.config.GetString("modules.mcp.server.transport.stream.address") + if addr == "" { + addr = DefaultAddr + } + + stateless := f.config.GetBool("modules.mcp.server.transport.stream.stateless") + + basePath := f.config.GetString("modules.mcp.server.transport.stream.base_path") + if basePath == "" { + basePath = DefaultBasePath + } + + keepAlive := f.config.GetBool("modules.mcp.server.transport.stream.keep_alive") + + keepAliveInterval := DefaultKeepAliveInterval + keepAliveIntervalConfig := f.config.GetInt("modules.mcp.server.transport.stream.keep_alive_interval") + if keepAliveIntervalConfig != 0 { + keepAliveInterval = time.Duration(keepAliveIntervalConfig) * time.Second + } + + srvConfig := MCPStreamableHTTPServerConfig{ + Address: addr, + Stateless: stateless, + BasePath: basePath, + KeepAlive: keepAlive, + KeepAliveInterval: keepAliveInterval, + } + + srvOptions := []server.StreamableHTTPOption{ + server.WithStateLess(srvConfig.Stateless), + server.WithEndpointPath(srvConfig.BasePath), + } + + if srvConfig.KeepAlive { + srvOptions = append(srvOptions, server.WithHeartbeatInterval(srvConfig.KeepAliveInterval)) + } + + srvOptions = append(srvOptions, options...) + + return NewMCPStreamableHTTPServer(mcpServer, srvConfig, srvOptions...) +} diff --git a/fxmcpserver/server/stream/factory_test.go b/fxmcpserver/server/stream/factory_test.go new file mode 100644 index 00000000..1fe6de07 --- /dev/null +++ b/fxmcpserver/server/stream/factory_test.go @@ -0,0 +1,35 @@ +package stream_test + +import ( + "github.com/ankorstore/yokai/fxmcpserver/server/stream" + "testing" + + "github.com/ankorstore/yokai/config" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" +) + +func TestDefaultMCPStreamableHTTPServerFactory_Create(t *testing.T) { + t.Parallel() + + cfg, err := config.NewDefaultConfigFactory().Create( + config.WithFilePaths("../../testdata/config"), + ) + assert.NoError(t, err) + + mcpSrv := &server.MCPServer{} + + fac := stream.NewDefaultMCPStreamableHTTPServerFactory(cfg) + + srv := fac.Create(mcpSrv) + + assert.IsType(t, (*server.StreamableHTTPServer)(nil), srv.Server()) + + assert.Equal(t, ":0", srv.Config().Address) + assert.True(t, srv.Config().Stateless) + assert.Equal(t, stream.DefaultBasePath, srv.Config().BasePath) + assert.True(t, srv.Config().KeepAlive) + assert.Equal(t, stream.DefaultKeepAliveInterval, srv.Config().KeepAliveInterval) + + assert.False(t, srv.Running()) +} diff --git a/fxmcpserver/server/stream/server.go b/fxmcpserver/server/stream/server.go new file mode 100644 index 00000000..c6d426f5 --- /dev/null +++ b/fxmcpserver/server/stream/server.go @@ -0,0 +1,110 @@ +package stream + +import ( + "context" + "github.com/ankorstore/yokai/log" + "github.com/mark3labs/mcp-go/server" + "sync" + "time" +) + +// MCPStreamableHTTPServerConfig is the MCP StreamableHTTP server configuration. +type MCPStreamableHTTPServerConfig struct { + Address string + Stateless bool + BasePath string + KeepAlive bool + KeepAliveInterval time.Duration +} + +// MCPStreamableHTTPServer is the MCP StreamableHTTP server. +type MCPStreamableHTTPServer struct { + server *server.StreamableHTTPServer + config MCPStreamableHTTPServerConfig + mutex sync.RWMutex + running bool +} + +// NewMCPStreamableHTTPServer returns a new MCPStreamableHTTPServer instance. +func NewMCPStreamableHTTPServer(mcpServer *server.MCPServer, config MCPStreamableHTTPServerConfig, opts ...server.StreamableHTTPOption) *MCPStreamableHTTPServer { + streamableHTTPServer := server.NewStreamableHTTPServer(mcpServer, opts...) + + return &MCPStreamableHTTPServer{ + server: streamableHTTPServer, + config: config, + } +} + +// Server returns the MCPStreamableHTTPServer underlying server. +func (s *MCPStreamableHTTPServer) Server() *server.StreamableHTTPServer { + return s.server +} + +// Config returns the MCPStreamableHTTPServer config. +func (s *MCPStreamableHTTPServer) Config() MCPStreamableHTTPServerConfig { + return s.config +} + +// Start starts the MCPStreamableHTTPServer. +func (s *MCPStreamableHTTPServer) Start(ctx context.Context) error { + logger := log.CtxLogger(ctx) + + logger.Info().Msgf("starting MCP StreamableHTTP server on %s", s.config.Address) + + s.mutex.Lock() + s.running = true + s.mutex.Unlock() + + err := s.server.Start(s.config.Address) + if err != nil { + logger.Error().Err(err).Msgf("failed to start MCP StreamableHTTP server") + + s.mutex.Lock() + s.running = false + s.mutex.Unlock() + } + + return err +} + +// Stop stops the MCPStreamableHTTPServer. +func (s *MCPStreamableHTTPServer) Stop(ctx context.Context) error { + logger := log.CtxLogger(ctx) + + logger.Info().Msg("stopping MCP StreamableHTTP server") + + s.mutex.Lock() + s.running = false + s.mutex.Unlock() + + err := s.server.Shutdown(ctx) + if err != nil { + logger.Error().Err(err).Msgf("failed to stop MCP StreamableHTTP server") + } + + return err +} + +// Running returns true if the MCPStreamableHTTPServer is running. +func (s *MCPStreamableHTTPServer) Running() bool { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.running +} + +// Info returns the MCPStreamableHTTPServer information. +func (s *MCPStreamableHTTPServer) Info() map[string]any { + return map[string]any{ + "config": map[string]any{ + "address": s.config.Address, + "stateless": s.config.Stateless, + "base_path": s.config.BasePath, + "keep_alive": s.config.KeepAlive, + "keep_alive_interval": s.config.KeepAliveInterval.Seconds(), + }, + "status": map[string]any{ + "running": s.Running(), + }, + } +} diff --git a/fxmcpserver/server/stream/server_test.go b/fxmcpserver/server/stream/server_test.go new file mode 100644 index 00000000..a9ed57d0 --- /dev/null +++ b/fxmcpserver/server/stream/server_test.go @@ -0,0 +1,76 @@ +package stream_test + +import ( + "context" + "github.com/ankorstore/yokai/fxmcpserver/server/stream" + "testing" + "time" + + "github.com/ankorstore/yokai/config" + "github.com/ankorstore/yokai/log" + "github.com/ankorstore/yokai/log/logtest" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" +) + +func TestMCPStreamableHTTPServer(t *testing.T) { + t.Parallel() + + cfg, err := config.NewDefaultConfigFactory().Create( + config.WithFilePaths("../../testdata/config"), + ) + assert.NoError(t, err) + + lb := logtest.NewDefaultTestLogBuffer() + lg, err := log.NewDefaultLoggerFactory().Create(log.WithOutputWriter(lb)) + assert.NoError(t, err) + + mcpSrv := server.NewMCPServer("test-server", "1.0.0") + + srv := stream.NewDefaultMCPStreamableHTTPServerFactory(cfg).Create(mcpSrv) + + assert.False(t, srv.Running()) + + assert.Equal( + t, + map[string]any{ + "config": map[string]any{ + "address": ":0", + "stateless": true, + "base_path": stream.DefaultBasePath, + "keep_alive": true, + "keep_alive_interval": stream.DefaultKeepAliveInterval.Seconds(), + }, + "status": map[string]any{ + "running": false, + }, + }, + srv.Info(), + ) + + ctx := lg.WithContext(context.Background()) + + //nolint:errcheck + go srv.Start(ctx) + + time.Sleep(1 * time.Millisecond) + + assert.True(t, srv.Running()) + + logtest.AssertHasLogRecord(t, lb, map[string]any{ + "level": "info", + "message": "starting MCP StreamableHTTP server on :0", + }) + + err = srv.Stop(ctx) + assert.NoError(t, err) + + time.Sleep(1 * time.Millisecond) + + assert.False(t, srv.Running()) + + logtest.AssertHasLogRecord(t, lb, map[string]any{ + "level": "info", + "message": "stopping MCP StreamableHTTP server", + }) +} diff --git a/fxmcpserver/testdata/config/config.yaml b/fxmcpserver/testdata/config/config.yaml index d820dd76..74585533 100644 --- a/fxmcpserver/testdata/config/config.yaml +++ b/fxmcpserver/testdata/config/config.yaml @@ -11,6 +11,13 @@ modules: prompts: true tools: true transport: + stream: + expose: true + address: ":0" + stateless: true + base_path: "/mcp" + keep_alive: true + keep_alive_interval: 10 sse: expose: true address: ":0" diff --git a/fxmcpserver/testdata/hook/simple.go b/fxmcpserver/testdata/hook/sse.go similarity index 100% rename from fxmcpserver/testdata/hook/simple.go rename to fxmcpserver/testdata/hook/sse.go diff --git a/fxmcpserver/testdata/hook/stream.go b/fxmcpserver/testdata/hook/stream.go new file mode 100644 index 00000000..9a9a0e69 --- /dev/null +++ b/fxmcpserver/testdata/hook/stream.go @@ -0,0 +1,20 @@ +package hook + +import ( + "context" + "net/http" + + "github.com/mark3labs/mcp-go/server" +) + +type SimpleMCPStreamableHTTPServerContextHook struct{} + +func NewSimpleMCPStreamableHTTPServerContextHook() *SimpleMCPStreamableHTTPServerContextHook { + return &SimpleMCPStreamableHTTPServerContextHook{} +} + +func (p *SimpleMCPStreamableHTTPServerContextHook) Handle() server.HTTPContextFunc { + return func(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, "foo", "bar") + } +} diff --git a/fxmcpserver/testdata/tool/advanced.go b/fxmcpserver/testdata/tool/advanced.go index 467785b5..57020774 100644 --- a/fxmcpserver/testdata/tool/advanced.go +++ b/fxmcpserver/testdata/tool/advanced.go @@ -42,7 +42,7 @@ func (t *AdvancedTestTool) Handle() server.ToolHandlerFunc { log.CtxLogger(ctx).Info().Msg("AdvancedTestTool.Handle") - shouldFail := request.Params.Arguments["shouldFail"].(string) + shouldFail := request.GetArguments()["shouldFail"].(string) if shouldFail == "true" { return nil, fmt.Errorf("advanced tool test failure") }