diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md
index e5aaa6525..f850b7286 100644
--- a/examples/servers/simple-streamablehttp/README.md
+++ b/examples/servers/simple-streamablehttp/README.md
@@ -9,6 +9,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en
 - Task management with anyio task groups
 - Ability to send multiple notifications over time to the client
 - Proper resource cleanup and lifespan management
+- Resumability support via InMemoryEventStore
 
 ## Usage
 
@@ -32,6 +33,21 @@ The server exposes a tool named "start-notification-stream" that accepts three a
 - `count`: Number of notifications to send (e.g., 5)
 - `caller`: Identifier string for the caller
 
+## Resumability Support
+
+This server includes resumability support through the InMemoryEventStore. This enables clients to:
+
+- Reconnect to the server after a disconnection
+- Resume event streaming from where they left off using the Last-Event-ID header
+
+The server will:
+
+- Generate unique event IDs for each SSE message
+- Store events in memory for later replay
+- Replay missed events when a client reconnects with a Last-Event-ID header
+
+Note: The InMemoryEventStore is designed for demonstration purposes only. For production use, consider implementing a persistent storage solution.
+
 ## Client
 
-You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector]
\ No newline at end of file
+You can connect to this server using an HTTP client. For now, only the TypeScript SDK has Streamable HTTP client examples, but you can also use [Inspector](https://github.com/modelcontextprotocol/inspector).
\ No newline at end of file
diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py
new file mode 100644
index 000000000..28c58149f
--- /dev/null
+++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py
@@ -0,0 +1,105 @@
+"""
+In-memory event store for demonstrating resumability functionality.
+
+This is a simple implementation intended for examples and testing,
+not for production use where a persistent storage solution would be more appropriate.
+"""
+
+import logging
+from collections import deque
+from dataclasses import dataclass
+from uuid import uuid4
+
+from mcp.server.streamable_http import (
+    EventCallback,
+    EventId,
+    EventMessage,
+    EventStore,
+    StreamId,
+)
+from mcp.types import JSONRPCMessage
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EventEntry:
+    """
+    Represents an event entry in the event store.
+    """
+
+    event_id: EventId
+    stream_id: StreamId
+    message: JSONRPCMessage
+
+
+class InMemoryEventStore(EventStore):
+    """
+    Simple in-memory implementation of the EventStore interface for resumability.
+    This is primarily intended for examples and testing, not for production use
+    where a persistent storage solution would be more appropriate.
+
+    This implementation keeps only the last N events per stream for memory efficiency.
+    """
+
+    def __init__(self, max_events_per_stream: int = 100):
+        """Initialize the event store.
+
+        Args:
+            max_events_per_stream: Maximum number of events to keep per stream
+        """
+        self.max_events_per_stream = max_events_per_stream
+        # for maintaining last N events per stream
+        self.streams: dict[StreamId, deque[EventEntry]] = {}
+        # event_id -> EventEntry for quick lookup
+        self.event_index: dict[EventId, EventEntry] = {}
+
+    async def store_event(
+        self, stream_id: StreamId, message: JSONRPCMessage
+    ) -> EventId:
+        """Stores an event with a generated event ID."""
+        event_id = str(uuid4())
+        event_entry = EventEntry(
+            event_id=event_id, stream_id=stream_id, message=message
+        )
+
+        # Get or create deque for this stream
+        if stream_id not in self.streams:
+            self.streams[stream_id] = deque(maxlen=self.max_events_per_stream)
+
+        # If deque is full, the oldest event will be automatically removed
+        # We need to remove it from the event_index as well
+        if len(self.streams[stream_id]) == self.max_events_per_stream:
+            oldest_event = self.streams[stream_id][0]
+            self.event_index.pop(oldest_event.event_id, None)
+
+        # Add new event
+        self.streams[stream_id].append(event_entry)
+        self.event_index[event_id] = event_entry
+
+        return event_id
+
+    async def replay_events_after(
+        self,
+        last_event_id: EventId,
+        send_callback: EventCallback,
+    ) -> StreamId | None:
+        """Replays events that occurred after the specified event ID."""
+        if last_event_id not in self.event_index:
+            logger.warning(f"Event ID {last_event_id} not found in store")
+            return None
+
+        # Get the stream and find events after the last one
+        last_event = self.event_index[last_event_id]
+        stream_id = last_event.stream_id
+        stream_events = self.streams.get(last_event.stream_id, deque())
+
+        # Events in deque are already in chronological order
+        found_last = False
+        for event in stream_events:
+            if found_last:
+                await send_callback(EventMessage(event.message, event.event_id))
+            elif event.event_id == last_event_id:
+                found_last = True
+
+        return stream_id
diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
index 71d4e5a37..b2079bb27 100644
--- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
+++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
@@ -17,12 +17,27 @@
 from starlette.responses import Response
 from starlette.routing import Mount
 
+from .event_store import InMemoryEventStore
+
 # Configure logging
 logger = logging.getLogger(__name__)
 
 # Global task group that will be initialized in the lifespan
 task_group = None
 
+# Event store for resumability
+# The InMemoryEventStore enables resumability support for StreamableHTTP transport.
+# It stores SSE events with unique IDs, allowing clients to:
+# 1. Receive event IDs for each SSE message
+# 2. Resume streams by sending Last-Event-ID in GET requests
+# 3. Replay missed events after reconnection
+# Note: This in-memory implementation is for demonstration ONLY.
+# For production, use a persistent storage solution.
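+# The replay window is bounded by the constructor's max_events_per_stream
+# argument (default 100); a larger buffer can be configured, e.g.:
+#   event_store = InMemoryEventStore(max_events_per_stream=1000)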
+event_store = InMemoryEventStore()
+
 
 @contextlib.asynccontextmanager
 async def lifespan(app):
@@ -79,9 +91,14 @@ async def call_tool(
 
         # Send the specified number of notifications with the given interval
         for i in range(count):
+            # Include more detailed message for resumability demonstration
+            notification_msg = (
+                f"[{i+1}/{count}] Event from '{caller}' - "
+                f"Use Last-Event-ID to resume if disconnected"
+            )
             await ctx.session.send_log_message(
                 level="info",
-                data=f"Notification {i+1}/{count} from caller: {caller}",
+                data=notification_msg,
                 logger="notification_stream",
                 # Associates this notification with the original request
                 # Ensures notifications are sent to the correct response stream
@@ -90,6 +107,7 @@
                 # - nowhere (if GET request isn't supported)
                 related_request_id=ctx.request_id,
             )
+            logger.debug(f"Sent notification {i+1}/{count} for caller: {caller}")
             if i < count - 1:  # Don't wait after the last notification
                 await anyio.sleep(interval)
 
@@ -163,8 +181,10 @@ async def handle_streamable_http(scope, receive, send):
 
             http_transport = StreamableHTTPServerTransport(
                 mcp_session_id=new_session_id,
                 is_json_response_enabled=json_response,
+                event_store=event_store,  # Enable resumability
             )
             server_instances[http_transport.mcp_session_id] = http_transport
+            logger.info(f"Created new transport with session ID: {new_session_id}")
             async with http_transport.connect() as streams:
                 read_stream, write_stream = streams
diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py
index 8faff0162..ca707e27b 100644
--- a/src/mcp/server/streamable_http.py
+++ b/src/mcp/server/streamable_http.py
@@ -10,10 +10,11 @@
 import json
 import logging
 import re
-from collections.abc import AsyncGenerator
+from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, Awaitable, Callable
 from contextlib import asynccontextmanager
+from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Any
 
 import anyio
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -57,6 +58,63 @@
 # Pattern ensures entire string contains only valid characters by using ^ and $ anchors
 SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")
 
 
+# Type aliases
+StreamId = str
+EventId = str
+
+
+@dataclass
+class EventMessage:
+    """
+    A JSONRPCMessage with an optional event ID for stream resumability.
+    """
+
+    message: JSONRPCMessage
+    event_id: str | None = None
+
+
+EventCallback = Callable[[EventMessage], Awaitable[None]]
+
+
+class EventStore(ABC):
+    """
+    Interface for resumability support via event storage.
+    """
+
+    @abstractmethod
+    async def store_event(
+        self, stream_id: StreamId, message: JSONRPCMessage
+    ) -> EventId:
+        """
+        Stores an event for later retrieval.
+
+        Args:
+            stream_id: ID of the stream the event belongs to
+            message: The JSON-RPC message to store
+
+        Returns:
+            The generated event ID for the stored event
+        """
+        pass
+
+    @abstractmethod
+    async def replay_events_after(
+        self,
+        last_event_id: EventId,
+        send_callback: EventCallback,
+    ) -> StreamId | None:
+        """
+        Replays events that occurred after the specified event ID.
+
+        Args:
+            last_event_id: The ID of the last event the client received
+            send_callback: A callback function to send events to the client
+
+        Returns:
+            The stream ID of the replayed events
+        """
+        pass
+
 
 class StreamableHTTPServerTransport:
     """
@@ -76,6 +134,7 @@ def __init__(
         self,
         mcp_session_id: str | None,
         is_json_response_enabled: bool = False,
+        event_store: EventStore | None = None,
     ) -> None:
         """
         Initialize a new StreamableHTTP server transport.
@@ -85,6 +144,9 @@
             Must contain only visible ASCII characters (0x21-0x7E).
             is_json_response_enabled: If True, return JSON responses for requests
                                     instead of SSE streams. Default is False.
+            event_store: Event store for resumability support. If provided,
+                         resumability will be enabled, allowing clients to
+                         reconnect and resume missed messages.
 
         Raises:
             ValueError: If the session ID contains invalid characters.
@@ -98,8 +160,9 @@
 
         self.mcp_session_id = mcp_session_id
         self.is_json_response_enabled = is_json_response_enabled
+        self._event_store = event_store
         self._request_streams: dict[
-            RequestId, MemoryObjectSendStream[JSONRPCMessage]
+            RequestId, MemoryObjectSendStream[EventMessage]
        ] = {}
         self._terminated = False
 
@@ -160,6 +223,21 @@ def _get_session_id(self, request: Request) -> str | None:
         """Extract the session ID from request headers."""
         return request.headers.get(MCP_SESSION_ID_HEADER)
 
+    def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
+        """Create event data dictionary from an EventMessage."""
+        event_data = {
+            "event": "message",
+            "data": event_message.message.model_dump_json(
+                by_alias=True, exclude_none=True
+            ),
+        }
+
+        # If an event ID was provided, include it
+        if event_message.event_id:
+            event_data["id"] = event_message.event_id
+
+        return event_data
+
     async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
         """Application entry point that handles all HTTP requests"""
         request = Request(scope, receive)
@@ -308,7 +386,7 @@
             request_id = str(message.root.id)
             # Create promise stream for getting response
             request_stream_writer, request_stream_reader = (
-                anyio.create_memory_object_stream[JSONRPCMessage](0)
+                anyio.create_memory_object_stream[EventMessage](0)
             )
 
             # Register this stream for the request ID
@@ -323,16 +401,18 @@
                     response_message = None
 
                     # Use similar approach to SSE writer for consistency
-                    async for received_message in request_stream_reader:
+                    async for event_message in request_stream_reader:
                         # If it's a response, this is what we're waiting for
                         if isinstance(
-                            received_message.root, JSONRPCResponse | JSONRPCError
+                            event_message.message.root, JSONRPCResponse | JSONRPCError
                         ):
-                            response_message = received_message
+                            response_message = event_message.message
                             break
                         # For notifications and request, keep waiting
                         else:
-                            logger.debug(f"received: {received_message.root.method}")
+                            logger.debug(
+                                f"received: {event_message.message.root.method}"
+                            )
 
                     # At this point we should have a response
                     if response_message:
@@ -366,7 +446,7 @@
             else:
                 # Create SSE stream
                 sse_stream_writer, sse_stream_reader = (
-                    anyio.create_memory_object_stream[dict[str, Any]](0)
+                    anyio.create_memory_object_stream[dict[str, str]](0)
                 )
 
                 async def sse_writer():
@@ -374,20 +454,14 @@ async def sse_writer():
                 try:
                     async with sse_stream_writer, request_stream_reader:
                         # Process messages from the request-specific stream
-                        async for received_message in request_stream_reader:
+                        async for event_message in request_stream_reader:
                             # Build the event data
-                            event_data = {
-                                "event": "message",
-                                "data": received_message.model_dump_json(
-                                    by_alias=True, exclude_none=True
-                                ),
-                            }
-
+                            event_data = self._create_event_data(event_message)
                             await sse_stream_writer.send(event_data)
 
                             # If response, remove from pending streams and close
                             if isinstance(
-                                received_message.root,
+                                event_message.message.root,
                                 JSONRPCResponse | JSONRPCError,
                             ):
                                 if request_id:
@@ -472,6 +546,12 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
         if not await self._validate_session(request, send):
             return
 
+        # Handle resumability: check for Last-Event-ID header
+        if self._event_store and (
+            last_event_id := request.headers.get(LAST_EVENT_ID_HEADER)
+        ):
+            await self._replay_events(last_event_id, request, send)
+            return
 
         headers = {
             "Cache-Control": "no-cache, no-transform",
@@ -493,14 +573,14 @@
 
         # Create SSE stream
         sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
-            dict[str, Any]
+            dict[str, str]
         ](0)
 
         async def standalone_sse_writer():
             try:
                 # Create a standalone message stream for server-initiated messages
                 standalone_stream_writer, standalone_stream_reader = (
-                    anyio.create_memory_object_stream[JSONRPCMessage](0)
+                    anyio.create_memory_object_stream[EventMessage](0)
                 )
 
                 # Register this stream using the special key
@@ -508,20 +588,14 @@ async def standalone_sse_writer():
 
                 async with sse_stream_writer, standalone_stream_reader:
                     # Process messages from the standalone stream
-                    async for received_message in standalone_stream_reader:
+                    async for event_message in standalone_stream_reader:
                         # For the standalone stream, we handle:
                         # - JSONRPCNotification (server sends notifications to client)
                         # - JSONRPCRequest (server sends requests to client)
                         # We should NOT receive JSONRPCResponse
 
                         # Send the message via SSE
-                        event_data = {
-                            "event": "message",
-                            "data": received_message.model_dump_json(
-                                by_alias=True, exclude_none=True
-                            ),
-                        }
-
+                        event_data = self._create_event_data(event_message)
                         await sse_stream_writer.send(event_data)
         except Exception as e:
             logger.exception(f"Error in standalone SSE writer: {e}")
@@ -639,6 +713,82 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
 
         return True
 
+    async def _replay_events(
+        self, last_event_id: str, request: Request, send: Send
+    ) -> None:
+        """
+        Replays events that would have been sent after the specified event ID.
+        Only used when resumability is enabled.
+        """
+        event_store = self._event_store
+        if not event_store:
+            return
+
+        try:
+            headers = {
+                "Cache-Control": "no-cache, no-transform",
+                "Connection": "keep-alive",
+                "Content-Type": CONTENT_TYPE_SSE,
+            }
+
+            if self.mcp_session_id:
+                headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
+
+            # Create SSE stream for replay
+            sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
+                dict[str, str]
+            ](0)
+
+            async def replay_sender():
+                try:
+                    async with sse_stream_writer:
+                        # Define an async callback for sending events
+                        async def send_event(event_message: EventMessage) -> None:
+                            event_data = self._create_event_data(event_message)
+                            await sse_stream_writer.send(event_data)
+
+                        # Replay past events and get the stream ID
+                        stream_id = await event_store.replay_events_after(
+                            last_event_id, send_event
+                        )
+
+                        # If stream ID not in mapping, create it
+                        if stream_id and stream_id not in self._request_streams:
+                            msg_writer, msg_reader = anyio.create_memory_object_stream[
+                                EventMessage
+                            ](0)
+                            self._request_streams[stream_id] = msg_writer
+
+                            # Forward messages to SSE
+                            async with msg_reader:
+                                async for event_message in msg_reader:
+                                    event_data = self._create_event_data(event_message)
+
+                                    await sse_stream_writer.send(event_data)
+                except Exception as e:
+                    logger.exception(f"Error in replay sender: {e}")
+
+            # Create and start EventSourceResponse
+            response = EventSourceResponse(
+                content=sse_stream_reader,
+                data_sender_callable=replay_sender,
+                headers=headers,
+            )
+
+            try:
+                await response(request.scope, request.receive, send)
+            except Exception as e:
+                logger.exception(f"Error in replay response: {e}")
+
+        except Exception as e:
+            logger.exception(f"Error replaying events: {e}")
+            response = self._create_error_response(
+                f"Error replaying events: {str(e)}",
+                HTTPStatus.INTERNAL_SERVER_ERROR,
+                INTERNAL_ERROR,
+            )
+            await response(request.scope, request.receive, send)
+
 
     @asynccontextmanager
     async def connect(
         self,
@@ -691,10 +841,22 @@ async def message_router():
                         target_request_id = str(message.root.id)
 
                     request_stream_id = target_request_id or GET_STREAM_KEY
+
+                    # Store the event if we have an event store,
+                    # regardless of whether a client is connected;
+                    # missed messages will be replayed on reconnect
+                    event_id = None
+                    if self._event_store:
+                        event_id = await self._event_store.store_event(
+                            request_stream_id, message
+                        )
+                        logger.debug(f"Stored event {event_id} for stream {request_stream_id}")
+
                     if request_stream_id in self._request_streams:
                         try:
+                            # Send both the message and the event ID
                             await self._request_streams[request_stream_id].send(
-                                message
+                                EventMessage(message, event_id)
                             )
                         except (
                             anyio.BrokenResourceError,
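A note on the client side of resumability: a disconnected client reissues the GET with the `Last-Event-ID` header, and the transport above replays the stored events before bridging back onto the live stream. The sketch below is illustrative only; it is not part of this patch, and it assumes the example server is reachable at `http://localhost:3000/mcp` and that the wire header names are `mcp-session-id` and `Last-Event-ID` (the values behind the `MCP_SESSION_ID_HEADER` and `LAST_EVENT_ID_HEADER` constants referenced above).

```python
# Hypothetical client-side resume sketch (not part of this diff).
import anyio
import httpx


async def resume(session_id: str, last_event_id: str) -> None:
    headers = {
        "Accept": "text/event-stream",
        "mcp-session-id": session_id,  # assumed MCP_SESSION_ID_HEADER value
        "Last-Event-ID": last_event_id,  # assumed LAST_EVENT_ID_HEADER value
    }
    async with httpx.AsyncClient(timeout=None) as client:
        # The transport first replays every stored event after last_event_id,
        # then keeps the SSE stream open for new server-initiated messages.
        async with client.stream(
            "GET", "http://localhost:3000/mcp", headers=headers
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("id:"):
                    # Persist the newest ID so the next reconnect resumes here.
                    last_event_id = line.removeprefix("id:").strip()
                elif line.startswith("data:"):
                    print(line.removeprefix("data:").strip())


if __name__ == "__main__":
    anyio.run(resume, "your-session-id", "your-last-event-id")
```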
+ """ + event_store = self._event_store + if not event_store: + return + + try: + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Create SSE stream for replay + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, str] + ](0) + + async def replay_sender(): + try: + async with sse_stream_writer: + # Define an async callback for sending events + async def send_event(event_message: EventMessage) -> None: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + # Replay past events and get the stream ID + stream_id = await event_store.replay_events_after( + last_event_id, send_event + ) + + # If stream ID not in mapping, create it + if stream_id and stream_id not in self._request_streams: + msg_writer, msg_reader = anyio.create_memory_object_stream[ + EventMessage + ](0) + self._request_streams[stream_id] = msg_writer + + # Forward messages to SSE + async with msg_reader: + async for event_message in msg_reader: + event_data = self._create_event_data(event_message) + + await sse_stream_writer.send(event_data) + except Exception as e: + logger.exception(f"Error in replay sender: {e}") + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=replay_sender, + headers=headers, + ) + + try: + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in replay response: {e}") + + except Exception as e: + logger.exception(f"Error replaying events: {e}") + response = self._create_error_response( + f"Error replaying events: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(request.scope, request.receive, send) + @asynccontextmanager async def connect( self, @@ -691,10 +839,22 @@ async def message_router(): target_request_id = str(message.root.id) request_stream_id = target_request_id or GET_STREAM_KEY + + # Store the event if we have an event store, + # regardless of whether a client is connected + # messages will be replayed on the re-connect + event_id = None + if self._event_store: + event_id = await self._event_store.store_event( + request_stream_id, message + ) + logger.debug(f"Stored {event_id} from {request_stream_id}") + if request_stream_id in self._request_streams: try: + # Send both the message and the event ID await self._request_streams[request_stream_id].send( - message + EventMessage(message, event_id) ) except ( anyio.BrokenResourceError, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 48af09536..7331b392b 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -113,15 +113,12 @@ async def lifespan(app): async with anyio.create_task_group() as tg: task_group = tg - print("Application started, task group initialized!") try: yield finally: - print("Application shutting down, cleaning up resources...") if task_group: tg.cancel_scope.cancel() task_group = None - print("Resources cleaned up successfully.") async def handle_streamable_http(scope, receive, send): request = Request(scope, receive) @@ -148,14 +145,11 @@ async def handle_streamable_http(scope, receive, send): read_stream, write_stream = streams async def run_server(): - try: - await server.run( - read_stream, - write_stream, - 
server.create_initialization_options(), - ) - except Exception as e: - print(f"Server exception: {e}") + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) if task_group is None: response = Response( @@ -196,10 +190,6 @@ def run_server(port: int, is_json_response_enabled=False) -> None: port: Port to listen on. is_json_response_enabled: If True, use JSON responses instead of SSE streams. """ - print( - f"Starting test server on port {port} with " - f"json_enabled={is_json_response_enabled}" - ) app = create_app(is_json_response_enabled) # Configure server @@ -218,16 +208,12 @@ def run_server(port: int, is_json_response_enabled=False) -> None: # This is important to catch exceptions and prevent test hangs try: - print("Server starting...") server.run() - except Exception as e: - print(f"ERROR: Server failed to run: {e}") + except Exception: import traceback traceback.print_exc() - print("Server shutdown") - # Test fixtures - using same approach as SSE tests @pytest.fixture @@ -273,8 +259,6 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]: # Clean up proc.kill() proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") @pytest.fixture @@ -306,8 +290,6 @@ def json_response_server(json_server_port: int) -> Generator[None, None, None]: # Clean up proc.kill() proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") @pytest.fixture
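The notes above repeatedly defer durability to a "persistent storage solution." As a rough sketch of what that could look like against the `EventStore` interface introduced in this patch, the following hypothetical implementation keeps events in SQLite. The class name, schema, and file path are invented for illustration, and the blocking `sqlite3` calls would need to be moved off the event loop (e.g. with `anyio.to_thread.run_sync`) in a real server.

```python
# Hypothetical persistent EventStore sketch; not part of this change.
import sqlite3
from uuid import uuid4

from mcp.server.streamable_http import (
    EventCallback,
    EventId,
    EventMessage,
    EventStore,
    StreamId,
)
from mcp.types import JSONRPCMessage


class SQLiteEventStore(EventStore):
    """Illustrative durable event store backed by SQLite."""

    def __init__(self, path: str = "events.db") -> None:
        self._conn = sqlite3.connect(path)
        self._conn.execute(
            "CREATE TABLE IF NOT EXISTS events ("
            "seq INTEGER PRIMARY KEY AUTOINCREMENT, "
            "event_id TEXT UNIQUE, stream_id TEXT, message TEXT)"
        )
        self._conn.commit()

    async def store_event(
        self, stream_id: StreamId, message: JSONRPCMessage
    ) -> EventId:
        # NOTE: blocking I/O; run via anyio.to_thread.run_sync in production.
        event_id = str(uuid4())
        with self._conn:
            self._conn.execute(
                "INSERT INTO events (event_id, stream_id, message) "
                "VALUES (?, ?, ?)",
                (event_id, stream_id, message.model_dump_json(by_alias=True)),
            )
        return event_id

    async def replay_events_after(
        self, last_event_id: EventId, send_callback: EventCallback
    ) -> StreamId | None:
        row = self._conn.execute(
            "SELECT seq, stream_id FROM events WHERE event_id = ?",
            (last_event_id,),
        ).fetchone()
        if row is None:
            return None  # Unknown event ID; nothing to replay
        seq, stream_id = row
        for event_id, raw in self._conn.execute(
            "SELECT event_id, message FROM events "
            "WHERE stream_id = ? AND seq > ? ORDER BY seq",
            (stream_id, seq),
        ):
            await send_callback(
                EventMessage(JSONRPCMessage.model_validate_json(raw), event_id)
            )
        return stream_id
```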