Skip to content

Add client-side Streamable HTTP transport support #356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ internal sealed class StreamableHttpHandler(
ILoggerFactory loggerFactory,
IServiceProvider applicationServices)
{
private static JsonTypeInfo<JsonRpcError> s_errorTypeInfo = GetRequiredJsonTypeInfo<JsonRpcError>();
private static MediaTypeHeaderValue ApplicationJsonMediaType = new("application/json");
private static MediaTypeHeaderValue TextEventStreamMediaType = new("text/event-stream");
private static readonly JsonTypeInfo<JsonRpcError> s_errorTypeInfo = GetRequiredJsonTypeInfo<JsonRpcError>();
private static readonly MediaTypeHeaderValue s_applicationJsonMediaType = new("application/json");
private static readonly MediaTypeHeaderValue s_textEventStreamMediaType = new("text/event-stream");

public ConcurrentDictionary<string, HttpMcpSession<StreamableHttpServerTransport>> Sessions { get; } = new(StringComparer.Ordinal);

Expand All @@ -36,7 +36,7 @@ public async Task HandlePostRequestAsync(HttpContext context)
// so we have to do this manually. The spec doesn't mandate that servers MUST reject these requests,
// but it's probably good to at least start out trying to be strict.
var acceptHeaders = context.Request.GetTypedHeaders().Accept;
if (!acceptHeaders.Contains(ApplicationJsonMediaType) || !acceptHeaders.Contains(TextEventStreamMediaType))
if (!acceptHeaders.Contains(s_applicationJsonMediaType) || !acceptHeaders.Contains(s_textEventStreamMediaType))
{
await WriteJsonRpcErrorAsync(context,
"Not Acceptable: Client must accept both application/json and text/event-stream",
Expand Down Expand Up @@ -64,7 +64,7 @@ await WriteJsonRpcErrorAsync(context,
public async Task HandleGetRequestAsync(HttpContext context)
{
var acceptHeaders = context.Request.GetTypedHeaders().Accept;
if (!acceptHeaders.Contains(TextEventStreamMediaType))
if (!acceptHeaders.Contains(s_textEventStreamMediaType))
{
await WriteJsonRpcErrorAsync(context,
"Not Acceptable: Client must accept text/event-stream",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,40 +87,17 @@ public override async Task SendMessageAsync(
messageId = messageWithId.Id.ToString();
}

var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint)
using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint)
{
Content = content,
};
CopyAdditionalHeaders(httpRequestMessage.Headers);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders);
var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);

response.EnsureSuccessStatusCode();

var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);

// Check if the message was an initialize request
if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize)
{
// If the response is not a JSON-RPC response, it is an SSE message
if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase))
{
LogAcceptedPost(Name, messageId);
// The response will arrive as an SSE message
}
else
{
JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ??
throw new InvalidOperationException("Failed to initialize client");

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}

return;
}

// Otherwise, check if the response was accepted (the response will come as an SSE message)
if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase))
{
LogAcceptedPost(Name, messageId);
Expand Down Expand Up @@ -177,17 +154,13 @@ public override async ValueTask DisposeAsync()
}
}

internal Uri? MessageEndpoint => _messageEndpoint;

internal SseClientTransportOptions Options => _options;

private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
{
try
{
using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint);
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
CopyAdditionalHeaders(request.Headers);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders);

using var response = await _httpClient.SendAsync(
request,
Expand Down Expand Up @@ -251,15 +224,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation
return;
}

string messageId = "(no id)";
if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}
catch (JsonException ex)
{
Expand Down Expand Up @@ -290,20 +255,6 @@ private void HandleEndpointEvent(string data)
_connectionEstablished.TrySetResult(true);
}

private void CopyAdditionalHeaders(HttpRequestHeaders headers)
{
if (_options.AdditionalHeaders is not null)
{
foreach (var header in _options.AdditionalHeaders)
{
if (!headers.TryAddWithoutValidation(header.Key, header.Value))
{
throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}.");
}
}
}
}

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} accepted SSE transport POST for message ID '{MessageId}'.")]
private partial void LogAcceptedPost(string endpointName, string messageId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient
/// <inheritdoc />
public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default)
{
if (_options.UseStreamableHttp)
{
return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name);
}

var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name);

try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,20 @@ public required Uri Endpoint
}
}

/// <summary>
/// Gets or sets a value indicating whether to use "Streamable HTTP" for the transport rather than "HTTP with SSE". Defaults to false.
/// <see href="https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http">Streamable HTTP transport specification</see>.
/// <see href="https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse">HTTP with SSE transport specification</see>.
/// </summary>
public bool UseStreamableHttp { get; init; }

/// <summary>
/// Gets a transport identifier used for logging purposes.
/// </summary>
public string? Name { get; init; }

/// <summary>
/// Gets or sets a timeout used to establish the initial connection to the SSE server.
/// Gets or sets a timeout used to establish the initial connection to the SSE server. Defaults to 30 seconds.
/// </summary>
/// <remarks>
/// This timeout controls how long the client waits for:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,7 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati
var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
if (message != null)
{
string messageId = "(no id)";
if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,7 @@ private async Task ReadMessagesAsync()
{
if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message)
{
string messageId = "(no id)";
if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}
else
{
Expand Down
Loading