1
1
use bytes:: { Bytes , BytesMut } ;
2
+ use reqwest:: header:: { HeaderMap , HeaderValue , ACCEPT } ;
2
3
use reqwest:: Client ;
3
4
use std:: time:: Duration ;
4
5
use tokio:: sync:: { mpsc, oneshot} ;
@@ -39,11 +40,15 @@ impl SseStream {
39
40
& self ,
40
41
mut endpoint_event_tx : Option < oneshot:: Sender < Option < String > > > ,
41
42
cancellation_token : CancellationToken ,
43
+ custom_headers : & Option < HeaderMap > ,
42
44
) {
43
45
let mut retry_count = 0 ;
44
46
let mut buffer = BytesMut :: with_capacity ( BUFFER_CAPACITY ) ;
45
47
let mut endpoint_event_received = false ;
46
48
49
+ let mut request_headers: HeaderMap = custom_headers. to_owned ( ) . unwrap_or_default ( ) ;
50
+ request_headers. insert ( ACCEPT , HeaderValue :: from_static ( "text/event-stream" ) ) ;
51
+
47
52
// Main loop for reconnection attempts
48
53
loop {
49
54
// Check for cancellation before attempting connection
@@ -56,7 +61,7 @@ impl SseStream {
56
61
let response = match self
57
62
. sse_client
58
63
. get ( & self . sse_url )
59
- . header ( "Accept" , "text/event-stream" )
64
+ . headers ( request_headers . clone ( ) )
60
65
. send ( )
61
66
. await
62
67
{
@@ -86,7 +91,18 @@ impl SseStream {
86
91
chunk = stream. next( ) => {
87
92
match chunk {
88
93
Some ( chunk) => chunk,
89
- None => break , // Stream ended, break from inner loop to reconnect
94
+ None => {
95
+ if retry_count >= self . max_retries {
96
+ tracing:: error!( "Max retries ({}) reached, giving up" , self . max_retries) ;
97
+ if let Some ( tx) = endpoint_event_tx. take( ) {
98
+ let _ = tx. send( None ) ;
99
+ }
100
+ return ;
101
+ }
102
+ retry_count += 1 ;
103
+ time:: sleep( self . retry_delay) . await ;
104
+ break ; // Stream ended, break from inner loop to reconnect
105
+ }
90
106
}
91
107
}
92
108
// Wait for cancellation
@@ -177,4 +193,81 @@ impl SseStream {
177
193
}
178
194
179
195
#[ cfg( test) ]
180
- mod tests { }
196
+ mod tests {
197
+ use super :: * ;
198
+ use crate :: utils:: CancellationTokenSource ;
199
+ use reqwest:: header:: { HeaderMap , HeaderValue } ;
200
+ use tokio:: time:: Duration ;
201
+ use wiremock:: matchers:: { header, method, path} ;
202
+ use wiremock:: { Mock , MockServer , ResponseTemplate } ;
203
+
204
+ #[ tokio:: test]
205
+ async fn test_sse_client_sends_custom_headers_on_connection ( ) {
206
+ // Start WireMock server
207
+ let mock_server = MockServer :: builder ( ) . start ( ) . await ;
208
+
209
+ // Create WireMock stub with connection close
210
+ Mock :: given ( method ( "GET" ) )
211
+ . and ( path ( "/sse" ) )
212
+ . and ( header ( "Accept" , "text/event-stream" ) )
213
+ . and ( header ( "X-Custom-Header" , "CustomValue" ) )
214
+ . respond_with (
215
+ ResponseTemplate :: new ( 200 )
216
+ . set_body_string ( "event: endpoint\n data: mock-endpoint\n \n " )
217
+ . append_header ( "Content-Type" , "text/event-stream" )
218
+ . append_header ( "Connection" , "close" ) , // Ensure connection closes
219
+ )
220
+ . expect ( 1 ) // Expect exactly one request
221
+ . mount ( & mock_server)
222
+ . await ;
223
+
224
+ // Create custom headers
225
+ let mut custom_headers = HeaderMap :: new ( ) ;
226
+ custom_headers. insert ( "X-Custom-Header" , HeaderValue :: from_static ( "CustomValue" ) ) ;
227
+
228
+ // Create channel and SseStream
229
+ let ( read_tx, _read_rx) = mpsc:: channel :: < Bytes > ( 64 ) ;
230
+ let sse = SseStream {
231
+ sse_client : reqwest:: Client :: new ( ) ,
232
+ sse_url : format ! ( "{}/sse" , mock_server. uri( ) ) ,
233
+ max_retries : 0 , // to receive one request only
234
+ retry_delay : Duration :: from_millis ( 100 ) ,
235
+ read_tx,
236
+ } ;
237
+
238
+ // Create cancellation token and endpoint channel
239
+ let ( cancellation_source, cancellation_token) = CancellationTokenSource :: new ( ) ;
240
+ let ( endpoint_event_tx, endpoint_event_rx) = oneshot:: channel :: < Option < String > > ( ) ;
241
+
242
+ // Spawn the run method
243
+ let sse_task = tokio:: spawn ( {
244
+ async move {
245
+ sse. run (
246
+ Some ( endpoint_event_tx) ,
247
+ cancellation_token,
248
+ & Some ( custom_headers) ,
249
+ )
250
+ . await ;
251
+ }
252
+ } ) ;
253
+
254
+ // Wait for the endpoint event or timeout
255
+ let event_result =
256
+ tokio:: time:: timeout ( Duration :: from_millis ( 500 ) , endpoint_event_rx) . await ;
257
+
258
+ // Cancel the task to ensure loop exits
259
+ let _ = cancellation_source. cancel ( ) ;
260
+
261
+ // Wait for the task to complete with a timeout
262
+ match tokio:: time:: timeout ( Duration :: from_secs ( 1 ) , sse_task) . await {
263
+ Ok ( result) => result. unwrap ( ) ,
264
+ Err ( _) => panic ! ( "Test timed out after 1 second" ) ,
265
+ }
266
+
267
+ // Verify the endpoint event was received
268
+ match event_result {
269
+ Ok ( Ok ( Some ( event) ) ) => assert_eq ! ( event, "mock-endpoint" , "Expected endpoint event" ) ,
270
+ _ => panic ! ( "Did not receive expected endpoint event" ) ,
271
+ }
272
+ }
273
+ }
0 commit comments