Skip to content

Commit 166939e

Browse files
authored
fix: ensure custom headers are included in initial SSE connection to remote MCP Server (#46)
* fix: sse client custom headers * chore: cleanup * chore: improve retry logic
1 parent 5711808 commit 166939e

File tree

2 files changed

+99
-4
lines changed

2 files changed

+99
-4
lines changed

crates/rust-mcp-transport/src/client_sse.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ where
206206
let max_retries = self.max_retries;
207207
let retry_delay = self.retry_delay;
208208

209+
let custom_headers = self.custom_headers.clone();
210+
209211
let read_stream = SseStream {
210212
sse_client,
211213
sse_url,
@@ -218,7 +220,7 @@ where
218220
let cancellation_token_sse = cancellation_token.clone();
219221
let sse_task_handle = tokio::spawn(async move {
220222
read_stream
221-
.run(endpoint_event_tx, cancellation_token_sse)
223+
.run(endpoint_event_tx, cancellation_token_sse, &custom_headers)
222224
.await;
223225
});
224226
let mut sse_task_lock = self.sse_task.write().await;

crates/rust-mcp-transport/src/utils/sse_stream.rs

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use bytes::{Bytes, BytesMut};
2+
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT};
23
use reqwest::Client;
34
use std::time::Duration;
45
use tokio::sync::{mpsc, oneshot};
@@ -39,11 +40,15 @@ impl SseStream {
3940
&self,
4041
mut endpoint_event_tx: Option<oneshot::Sender<Option<String>>>,
4142
cancellation_token: CancellationToken,
43+
custom_headers: &Option<HeaderMap>,
4244
) {
4345
let mut retry_count = 0;
4446
let mut buffer = BytesMut::with_capacity(BUFFER_CAPACITY);
4547
let mut endpoint_event_received = false;
4648

49+
let mut request_headers: HeaderMap = custom_headers.to_owned().unwrap_or_default();
50+
request_headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
51+
4752
// Main loop for reconnection attempts
4853
loop {
4954
// Check for cancellation before attempting connection
@@ -56,7 +61,7 @@ impl SseStream {
5661
let response = match self
5762
.sse_client
5863
.get(&self.sse_url)
59-
.header("Accept", "text/event-stream")
64+
.headers(request_headers.clone())
6065
.send()
6166
.await
6267
{
@@ -86,7 +91,18 @@ impl SseStream {
8691
chunk = stream.next() => {
8792
match chunk {
8893
Some(chunk) => chunk,
89-
None => break, // Stream ended, break from inner loop to reconnect
94+
None => {
95+
if retry_count >= self.max_retries {
96+
tracing::error!("Max retries ({}) reached, giving up",self.max_retries);
97+
if let Some(tx) = endpoint_event_tx.take() {
98+
let _ = tx.send(None);
99+
}
100+
return;
101+
}
102+
retry_count += 1;
103+
time::sleep(self.retry_delay).await;
104+
break; // Stream ended, break from inner loop to reconnect
105+
}
90106
}
91107
}
92108
// Wait for cancellation
@@ -177,4 +193,81 @@ impl SseStream {
177193
}
178194

179195
#[cfg(test)]
180-
mod tests {}
196+
mod tests {
197+
use super::*;
198+
use crate::utils::CancellationTokenSource;
199+
use reqwest::header::{HeaderMap, HeaderValue};
200+
use tokio::time::Duration;
201+
use wiremock::matchers::{header, method, path};
202+
use wiremock::{Mock, MockServer, ResponseTemplate};
203+
204+
#[tokio::test]
205+
async fn test_sse_client_sends_custom_headers_on_connection() {
206+
// Start WireMock server
207+
let mock_server = MockServer::builder().start().await;
208+
209+
// Create WireMock stub with connection close
210+
Mock::given(method("GET"))
211+
.and(path("/sse"))
212+
.and(header("Accept", "text/event-stream"))
213+
.and(header("X-Custom-Header", "CustomValue"))
214+
.respond_with(
215+
ResponseTemplate::new(200)
216+
.set_body_string("event: endpoint\ndata: mock-endpoint\n\n")
217+
.append_header("Content-Type", "text/event-stream")
218+
.append_header("Connection", "close"), // Ensure connection closes
219+
)
220+
.expect(1) // Expect exactly one request
221+
.mount(&mock_server)
222+
.await;
223+
224+
// Create custom headers
225+
let mut custom_headers = HeaderMap::new();
226+
custom_headers.insert("X-Custom-Header", HeaderValue::from_static("CustomValue"));
227+
228+
// Create channel and SseStream
229+
let (read_tx, _read_rx) = mpsc::channel::<Bytes>(64);
230+
let sse = SseStream {
231+
sse_client: reqwest::Client::new(),
232+
sse_url: format!("{}/sse", mock_server.uri()),
233+
max_retries: 0, // to receive one request only
234+
retry_delay: Duration::from_millis(100),
235+
read_tx,
236+
};
237+
238+
// Create cancellation token and endpoint channel
239+
let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
240+
let (endpoint_event_tx, endpoint_event_rx) = oneshot::channel::<Option<String>>();
241+
242+
// Spawn the run method
243+
let sse_task = tokio::spawn({
244+
async move {
245+
sse.run(
246+
Some(endpoint_event_tx),
247+
cancellation_token,
248+
&Some(custom_headers),
249+
)
250+
.await;
251+
}
252+
});
253+
254+
// Wait for the endpoint event or timeout
255+
let event_result =
256+
tokio::time::timeout(Duration::from_millis(500), endpoint_event_rx).await;
257+
258+
// Cancel the task to ensure loop exits
259+
let _ = cancellation_source.cancel();
260+
261+
// Wait for the task to complete with a timeout
262+
match tokio::time::timeout(Duration::from_secs(1), sse_task).await {
263+
Ok(result) => result.unwrap(),
264+
Err(_) => panic!("Test timed out after 1 second"),
265+
}
266+
267+
// Verify the endpoint event was received
268+
match event_result {
269+
Ok(Ok(Some(event))) => assert_eq!(event, "mock-endpoint", "Expected endpoint event"),
270+
_ => panic!("Did not receive expected endpoint event"),
271+
}
272+
}
273+
}

0 commit comments

Comments
 (0)