Skip to content

Commit 72614a1

Browse files
hfkangmingtay
andauthored
feat: new timeout writer implementation (supabase#1584)
supabase#1529 introduced timeout middleware, but it appears from working in the wild it has some race conditions that are not particularly helpful. This PR rewrites the implementation to get rid of race conditions, at the expense of slightly higher RAM usage. It follows the implementation of `http.TimeoutHandler` closely. --------- Co-authored-by: Kang Ming <[email protected]>
1 parent 39ca026 commit 72614a1

File tree

5 files changed

+224
-170
lines changed

5 files changed

+224
-170
lines changed

internal/api/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
102102
r.UseBypass(recoverer)
103103

104104
if globalConfig.API.MaxRequestDuration > 0 {
105-
r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration))
105+
r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration))
106106
}
107107

108108
// request tracing should be added only when tracing or metrics is enabled

internal/api/middleware.go

Lines changed: 83 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
package api
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"fmt"
78
"net/http"
89
"net/url"
910
"strings"
1011
"sync"
11-
"sync/atomic"
1212
"time"
1313

14+
"github.com/sirupsen/logrus"
1415
"github.com/supabase/auth/internal/models"
1516
"github.com/supabase/auth/internal/observability"
1617
"github.com/supabase/auth/internal/security"
@@ -263,95 +264,122 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
263264
}
264265
}
265266

266-
// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent
267-
// writes after the context contained in it has exceeded the deadline. If a
268-
// partial write occurs before the deadline is exceeded, but the writing is not
269-
// complete it will allow further writes.
267+
// timeoutResponseWriter is a http.ResponseWriter that queues up a response
268+
// body to be sent if the serving completes before the context has exceeded its
269+
// deadline.
270270
type timeoutResponseWriter struct {
271-
ctx context.Context
272-
w http.ResponseWriter
273-
wrote int32
274-
mu sync.Mutex
271+
sync.Mutex
272+
273+
header http.Header
274+
wroteHeader bool
275+
snapHeader http.Header // snapshot of the header at the time WriteHeader was called
276+
statusCode int
277+
buf bytes.Buffer
275278
}
276279

277280
func (t *timeoutResponseWriter) Header() http.Header {
278-
t.mu.Lock()
279-
defer t.mu.Unlock()
280-
return t.w.Header()
281+
t.Lock()
282+
defer t.Unlock()
283+
284+
return t.header
281285
}
282286

283287
func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
284-
t.mu.Lock()
285-
defer t.mu.Unlock()
286-
if t.ctx.Err() == context.DeadlineExceeded {
287-
if atomic.LoadInt32(&t.wrote) == 0 {
288-
return 0, context.DeadlineExceeded
289-
}
288+
t.Lock()
289+
defer t.Unlock()
290290

291-
// writing started before the deadline exceeded, but the
292-
// deadline came in the middle, so letting the writes go
293-
// through
291+
if !t.wroteHeader {
292+
t.WriteHeader(http.StatusOK)
294293
}
295294

296-
t.wrote = 1
297-
298-
return t.w.Write(bytes)
295+
return t.buf.Write(bytes)
299296
}
300297

301298
func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
302-
t.mu.Lock()
303-
defer t.mu.Unlock()
304-
if t.ctx.Err() == context.DeadlineExceeded {
305-
if atomic.LoadInt32(&t.wrote) == 0 {
306-
return
307-
}
299+
t.Lock()
300+
defer t.Unlock()
301+
302+
if t.wroteHeader {
303+
// ignore multiple calls to WriteHeader
304+
// once WriteHeader has been called once, a snapshot of the header map is taken
305+
// and saved in snapHeader to be used in finallyWrite
306+
return
307+
}
308+
t.statusCode = statusCode
309+
t.wroteHeader = true
310+
t.snapHeader = t.header.Clone()
311+
}
312+
313+
func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) {
314+
t.Lock()
315+
defer t.Unlock()
308316

309-
// writing started before the deadline exceeded, but the
310-
// deadline came in the middle, so letting the writes go
311-
// through
317+
dst := w.Header()
318+
for k, vv := range t.snapHeader {
319+
dst[k] = vv
312320
}
313321

314-
t.wrote = 1
322+
if !t.wroteHeader {
323+
t.statusCode = http.StatusOK
324+
}
315325

316-
t.w.WriteHeader(statusCode)
326+
w.WriteHeader(t.statusCode)
327+
if _, err := w.Write(t.buf.Bytes()); err != nil {
328+
logrus.WithError(err).Warn("Write failed")
329+
}
317330
}
318331

319-
func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
332+
func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
320333
return func(next http.Handler) http.Handler {
321334
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
322335
ctx, cancel := context.WithTimeout(r.Context(), timeout)
323336
defer cancel()
324337

325338
timeoutWriter := &timeoutResponseWriter{
326-
w: w,
327-
ctx: ctx,
339+
header: make(http.Header),
328340
}
329341

342+
panicChan := make(chan any, 1)
343+
serverDone := make(chan struct{})
330344
go func() {
331-
<-ctx.Done()
345+
defer func() {
346+
if p := recover(); p != nil {
347+
panicChan <- p
348+
}
349+
}()
332350

351+
next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
352+
close(serverDone)
353+
}()
354+
355+
select {
356+
case p := <-panicChan:
357+
panic(p)
358+
359+
case <-serverDone:
360+
timeoutWriter.finallyWrite(w)
361+
362+
case <-ctx.Done():
333363
err := ctx.Err()
334364

335365
if err == context.DeadlineExceeded {
336-
timeoutWriter.mu.Lock()
337-
defer timeoutWriter.mu.Unlock()
338-
if timeoutWriter.wrote == 0 {
339-
// writer wasn't written to, so we're sending the error payload
340-
341-
httpError := &HTTPError{
342-
HTTPStatus: http.StatusGatewayTimeout,
343-
ErrorCode: ErrorCodeRequestTimeout,
344-
Message: "Processing this request timed out, please retry after a moment.",
345-
}
366+
httpError := &HTTPError{
367+
HTTPStatus: http.StatusGatewayTimeout,
368+
ErrorCode: ErrorCodeRequestTimeout,
369+
Message: "Processing this request timed out, please retry after a moment.",
370+
}
346371

347-
httpError = httpError.WithInternalError(err)
372+
httpError = httpError.WithInternalError(err)
348373

349-
HandleResponseError(httpError, w, r)
350-
}
351-
}
352-
}()
374+
HandleResponseError(httpError, w, r)
375+
} else {
376+
// unrecognized context error, so we should wait for the server to finish
377+
// and write out the response
378+
<-serverDone
353379

354-
next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
380+
timeoutWriter.finallyWrite(w)
381+
}
382+
}
355383
})
356384
}
357385
}

internal/api/middleware_test.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
319319
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
320320
w := httptest.NewRecorder()
321321

322-
timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration)
322+
timeoutHandler := timeoutMiddleware(ts.Config.API.MaxRequestDuration)
323323

324324
slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
325325
// Sleep for 1 second to simulate a slow handler which should trigger the timeout
@@ -335,3 +335,24 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
335335
require.Equal(ts.T(), float64(504), data["code"])
336336
require.NotNil(ts.T(), data["msg"])
337337
}
338+
339+
func TestTimeoutResponseWriter(t *testing.T) {
340+
// timeoutResponseWriter should exhitbit a similar behavior as http.ResponseWriter
341+
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
342+
w1 := httptest.NewRecorder()
343+
w2 := httptest.NewRecorder()
344+
345+
timeoutHandler := timeoutMiddleware(time.Second * 10)
346+
347+
redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
348+
// tries to redirect twice
349+
http.Redirect(w, r, "http://localhost:3001/#message=first_message", http.StatusSeeOther)
350+
351+
// overwrites the first
352+
http.Redirect(w, r, "http://localhost:3001/second", http.StatusSeeOther)
353+
})
354+
timeoutHandler(redirectHandler).ServeHTTP(w1, req)
355+
redirectHandler.ServeHTTP(w2, req)
356+
357+
require.Equal(t, w1.Result(), w2.Result())
358+
}

internal/api/verify.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
125125
err error
126126
token *AccessTokenResponse
127127
authCode string
128+
rurl string
128129
)
129130

130131
grantParams.FillGrantParams(r)
@@ -138,6 +139,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
138139
return err
139140
}
140141
}
142+
141143
err = db.Transaction(func(tx *storage.Connection) error {
142144
var terr error
143145
user, terr = a.verifyTokenHash(tx, params)
@@ -152,12 +154,11 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
152154
case mail.EmailChangeVerification:
153155
user, terr = a.emailChangeVerify(r, tx, params, user)
154156
if user == nil && terr == nil {
155-
// when double confirmation is required
156-
rurl, err := a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType)
157-
if err != nil {
158-
return err
157+
// only one OTP is confirmed at this point, so we return early and ask the user to confirm the second OTP
158+
rurl, terr = a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType)
159+
if terr != nil {
160+
return terr
159161
}
160-
http.Redirect(w, r, rurl, http.StatusSeeOther)
161162
return nil
162163
}
163164
default:
@@ -198,15 +199,17 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
198199
if err != nil {
199200
var herr *HTTPError
200201
if errors.As(err, &herr) {
201-
rurl, err := a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType)
202+
rurl, err = a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType)
202203
if err != nil {
203204
return err
204205
}
205-
http.Redirect(w, r, rurl, http.StatusSeeOther)
206-
return nil
207206
}
208207
}
209-
rurl := params.RedirectTo
208+
if rurl != "" {
209+
http.Redirect(w, r, rurl, http.StatusSeeOther)
210+
return nil
211+
}
212+
rurl = params.RedirectTo
210213
if isImplicitFlow(flowType) && token != nil {
211214
q := url.Values{}
212215
q.Set("type", params.Type)

0 commit comments

Comments
 (0)