|
1 | 1 | package api
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bytes" |
4 | 5 | "context"
|
5 | 6 | "encoding/json"
|
6 | 7 | "fmt"
|
7 | 8 | "net/http"
|
8 | 9 | "net/url"
|
9 | 10 | "strings"
|
10 | 11 | "sync"
|
11 |
| - "sync/atomic" |
12 | 12 | "time"
|
13 | 13 |
|
| 14 | + "github.com/sirupsen/logrus" |
14 | 15 | "github.com/supabase/auth/internal/models"
|
15 | 16 | "github.com/supabase/auth/internal/observability"
|
16 | 17 | "github.com/supabase/auth/internal/security"
|
@@ -263,95 +264,122 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
|
263 | 264 | }
|
264 | 265 | }
|
265 | 266 |
|
266 |
| -// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent |
267 |
| -// writes after the context contained in it has exceeded the deadline. If a |
268 |
| -// partial write occurs before the deadline is exceeded, but the writing is not |
269 |
| -// complete it will allow further writes. |
| 267 | +// timeoutResponseWriter is a http.ResponseWriter that queues up a response |
| 268 | +// body to be sent if the serving completes before the context has exceeded its |
| 269 | +// deadline. |
270 | 270 | type timeoutResponseWriter struct {
|
271 |
| - ctx context.Context |
272 |
| - w http.ResponseWriter |
273 |
| - wrote int32 |
274 |
| - mu sync.Mutex |
| 271 | + sync.Mutex |
| 272 | + |
| 273 | + header http.Header |
| 274 | + wroteHeader bool |
| 275 | + snapHeader http.Header // snapshot of the header at the time WriteHeader was called |
| 276 | + statusCode int |
| 277 | + buf bytes.Buffer |
275 | 278 | }
|
276 | 279 |
|
277 | 280 | func (t *timeoutResponseWriter) Header() http.Header {
|
278 |
| - t.mu.Lock() |
279 |
| - defer t.mu.Unlock() |
280 |
| - return t.w.Header() |
| 281 | + t.Lock() |
| 282 | + defer t.Unlock() |
| 283 | + |
| 284 | + return t.header |
281 | 285 | }
|
282 | 286 |
|
283 | 287 | func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
|
284 |
| - t.mu.Lock() |
285 |
| - defer t.mu.Unlock() |
286 |
| - if t.ctx.Err() == context.DeadlineExceeded { |
287 |
| - if atomic.LoadInt32(&t.wrote) == 0 { |
288 |
| - return 0, context.DeadlineExceeded |
289 |
| - } |
| 288 | + t.Lock() |
| 289 | + defer t.Unlock() |
290 | 290 |
|
291 |
| - // writing started before the deadline exceeded, but the |
292 |
| - // deadline came in the middle, so letting the writes go |
293 |
| - // through |
| 291 | + if !t.wroteHeader { |
| 292 | + t.WriteHeader(http.StatusOK) |
294 | 293 | }
|
295 | 294 |
|
296 |
| - t.wrote = 1 |
297 |
| - |
298 |
| - return t.w.Write(bytes) |
| 295 | + return t.buf.Write(bytes) |
299 | 296 | }
|
300 | 297 |
|
301 | 298 | func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
|
302 |
| - t.mu.Lock() |
303 |
| - defer t.mu.Unlock() |
304 |
| - if t.ctx.Err() == context.DeadlineExceeded { |
305 |
| - if atomic.LoadInt32(&t.wrote) == 0 { |
306 |
| - return |
307 |
| - } |
| 299 | + t.Lock() |
| 300 | + defer t.Unlock() |
| 301 | + |
| 302 | + if t.wroteHeader { |
| 303 | + // ignore multiple calls to WriteHeader |
| 304 | + // once WriteHeader has been called once, a snapshot of the header map is taken |
| 305 | + // and saved in snapHeader to be used in finallyWrite |
| 306 | + return |
| 307 | + } |
| 308 | + t.statusCode = statusCode |
| 309 | + t.wroteHeader = true |
| 310 | + t.snapHeader = t.header.Clone() |
| 311 | +} |
| 312 | + |
| 313 | +func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) { |
| 314 | + t.Lock() |
| 315 | + defer t.Unlock() |
308 | 316 |
|
309 |
| - // writing started before the deadline exceeded, but the |
310 |
| - // deadline came in the middle, so letting the writes go |
311 |
| - // through |
| 317 | + dst := w.Header() |
| 318 | + for k, vv := range t.snapHeader { |
| 319 | + dst[k] = vv |
312 | 320 | }
|
313 | 321 |
|
314 |
| - t.wrote = 1 |
| 322 | + if !t.wroteHeader { |
| 323 | + t.statusCode = http.StatusOK |
| 324 | + } |
315 | 325 |
|
316 |
| - t.w.WriteHeader(statusCode) |
| 326 | + w.WriteHeader(t.statusCode) |
| 327 | + if _, err := w.Write(t.buf.Bytes()); err != nil { |
| 328 | + logrus.WithError(err).Warn("Write failed") |
| 329 | + } |
317 | 330 | }
|
318 | 331 |
|
319 |
| -func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { |
| 332 | +func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { |
320 | 333 | return func(next http.Handler) http.Handler {
|
321 | 334 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
322 | 335 | ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
323 | 336 | defer cancel()
|
324 | 337 |
|
325 | 338 | timeoutWriter := &timeoutResponseWriter{
|
326 |
| - w: w, |
327 |
| - ctx: ctx, |
| 339 | + header: make(http.Header), |
328 | 340 | }
|
329 | 341 |
|
| 342 | + panicChan := make(chan any, 1) |
| 343 | + serverDone := make(chan struct{}) |
330 | 344 | go func() {
|
331 |
| - <-ctx.Done() |
| 345 | + defer func() { |
| 346 | + if p := recover(); p != nil { |
| 347 | + panicChan <- p |
| 348 | + } |
| 349 | + }() |
332 | 350 |
|
| 351 | + next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) |
| 352 | + close(serverDone) |
| 353 | + }() |
| 354 | + |
| 355 | + select { |
| 356 | + case p := <-panicChan: |
| 357 | + panic(p) |
| 358 | + |
| 359 | + case <-serverDone: |
| 360 | + timeoutWriter.finallyWrite(w) |
| 361 | + |
| 362 | + case <-ctx.Done(): |
333 | 363 | err := ctx.Err()
|
334 | 364 |
|
335 | 365 | if err == context.DeadlineExceeded {
|
336 |
| - timeoutWriter.mu.Lock() |
337 |
| - defer timeoutWriter.mu.Unlock() |
338 |
| - if timeoutWriter.wrote == 0 { |
339 |
| - // writer wasn't written to, so we're sending the error payload |
340 |
| - |
341 |
| - httpError := &HTTPError{ |
342 |
| - HTTPStatus: http.StatusGatewayTimeout, |
343 |
| - ErrorCode: ErrorCodeRequestTimeout, |
344 |
| - Message: "Processing this request timed out, please retry after a moment.", |
345 |
| - } |
| 366 | + httpError := &HTTPError{ |
| 367 | + HTTPStatus: http.StatusGatewayTimeout, |
| 368 | + ErrorCode: ErrorCodeRequestTimeout, |
| 369 | + Message: "Processing this request timed out, please retry after a moment.", |
| 370 | + } |
346 | 371 |
|
347 |
| - httpError = httpError.WithInternalError(err) |
| 372 | + httpError = httpError.WithInternalError(err) |
348 | 373 |
|
349 |
| - HandleResponseError(httpError, w, r) |
350 |
| - } |
351 |
| - } |
352 |
| - }() |
| 374 | + HandleResponseError(httpError, w, r) |
| 375 | + } else { |
| 376 | + // unrecognized context error, so we should wait for the server to finish |
| 377 | + // and write out the response |
| 378 | + <-serverDone |
353 | 379 |
|
354 |
| - next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) |
| 380 | + timeoutWriter.finallyWrite(w) |
| 381 | + } |
| 382 | + } |
355 | 383 | })
|
356 | 384 | }
|
357 | 385 | }
|
0 commit comments