Skip to content

Commit 7955cc7

Browse files
J0kangmingtay
andauthored
feat: add timeout middleware (supabase#1529)
A new middleware is introduced that enforces a strict timeout by using `context.WithTimeout()`. When the timeout is reached, a 504 JSON error with the `request_timeout` error code is sent. Anything that depends on the context is cancelled. --------- Co-authored-by: Kang Ming <[email protected]>
1 parent 43b52ff commit 7955cc7

File tree

6 files changed

+138
-13
lines changed

6 files changed

+138
-13
lines changed

internal/api/api.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
9898
logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig)
9999

100100
r := newRouter()
101+
102+
if globalConfig.API.MaxRequestDuration > 0 {
103+
r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration))
104+
}
105+
101106
r.Use(addRequestID(globalConfig))
102107

103108
// request tracing should be added only when tracing or metrics is enabled

internal/api/errorcodes.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,5 @@ const (
7777
ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry"
7878
ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit"
7979
ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size"
80+
ErrorCodeRequestTimeout ErrorCode = "request_timeout"
8081
)

internal/api/errors.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
207207
output.Message = e.Message
208208
output.Payload.Reasons = e.Reasons
209209

210-
if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil {
210+
if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded {
211211
HandleResponseError(jsonErr, w, r)
212212
}
213213

@@ -224,7 +224,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
224224
output.Message = e.Message
225225
output.Payload.Reasons = e.Reasons
226226

227-
if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil {
227+
if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded {
228228
HandleResponseError(jsonErr, w, r)
229229
}
230230
}
@@ -252,7 +252,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
252252
}
253253
}
254254

255-
if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil {
255+
if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded {
256256
HandleResponseError(jsonErr, w, r)
257257
}
258258
} else {
@@ -266,20 +266,20 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
266266

267267
// Provide better error messages for certain user-triggered Postgres errors.
268268
if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil {
269-
if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil {
269+
if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded {
270270
HandleResponseError(jsonErr, w, r)
271271
}
272272
return
273273
}
274274

275-
if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil {
275+
if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded {
276276
HandleResponseError(jsonErr, w, r)
277277
}
278278
}
279279

280280
case *OAuthError:
281281
log.WithError(e.Cause()).Info(e.Error())
282-
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil {
282+
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded {
283283
HandleResponseError(jsonErr, w, r)
284284
}
285285

@@ -295,7 +295,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
295295
Message: "Unexpected failure, please check server logs for more information",
296296
}
297297

298-
if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil {
298+
if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded {
299299
HandleResponseError(jsonErr, w, r)
300300
}
301301
} else {
@@ -305,7 +305,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
305305
Message: "Unexpected failure, please check server logs for more information",
306306
}
307307

308-
if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil {
308+
if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded {
309309
HandleResponseError(jsonErr, w, r)
310310
}
311311
}

internal/api/middleware.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"net/http"
88
"net/url"
99
"strings"
10+
"sync"
11+
"sync/atomic"
1012
"time"
1113

1214
"github.com/supabase/auth/internal/models"
@@ -260,3 +262,96 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
260262
})
261263
}
262264
}
265+
266+
// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent
267+
// writes after the context contained in it has exceeded the deadline. If a
268+
// partial write occurs before the deadline is exceeded, but the writing is not
269+
// complete it will allow further writes.
270+
type timeoutResponseWriter struct {
271+
ctx context.Context
272+
w http.ResponseWriter
273+
wrote int32
274+
mu sync.Mutex
275+
}
276+
277+
func (t *timeoutResponseWriter) Header() http.Header {
278+
t.mu.Lock()
279+
defer t.mu.Unlock()
280+
return t.w.Header()
281+
}
282+
283+
func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
284+
t.mu.Lock()
285+
defer t.mu.Unlock()
286+
if t.ctx.Err() == context.DeadlineExceeded {
287+
if atomic.LoadInt32(&t.wrote) == 0 {
288+
return 0, context.DeadlineExceeded
289+
}
290+
291+
// writing started before the deadline exceeded, but the
292+
// deadline came in the middle, so letting the writes go
293+
// through
294+
}
295+
296+
t.wrote = 1
297+
298+
return t.w.Write(bytes)
299+
}
300+
301+
func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
302+
t.mu.Lock()
303+
defer t.mu.Unlock()
304+
if t.ctx.Err() == context.DeadlineExceeded {
305+
if atomic.LoadInt32(&t.wrote) == 0 {
306+
return
307+
}
308+
309+
// writing started before the deadline exceeded, but the
310+
// deadline came in the middle, so letting the writes go
311+
// through
312+
}
313+
314+
t.wrote = 1
315+
316+
t.w.WriteHeader(statusCode)
317+
}
318+
319+
func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
320+
return func(next http.Handler) http.Handler {
321+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
322+
ctx, cancel := context.WithTimeout(r.Context(), timeout)
323+
defer cancel()
324+
325+
timeoutWriter := &timeoutResponseWriter{
326+
w: w,
327+
ctx: ctx,
328+
}
329+
330+
go func() {
331+
<-ctx.Done()
332+
333+
err := ctx.Err()
334+
335+
if err == context.DeadlineExceeded {
336+
timeoutWriter.mu.Lock()
337+
defer timeoutWriter.mu.Unlock()
338+
if timeoutWriter.wrote == 0 {
339+
// writer wasn't written to, so we're sending the error payload
340+
341+
httpError := &HTTPError{
342+
HTTPStatus: http.StatusGatewayTimeout,
343+
ErrorCode: ErrorCodeRequestTimeout,
344+
Message: "Processing this request timed out, please retry after a moment.",
345+
}
346+
347+
httpError = httpError.WithInternalError(err)
348+
349+
HandleResponseError(httpError, w, r)
350+
}
351+
}
352+
}()
353+
354+
next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
355+
})
356+
}
357+
}

internal/api/middleware_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http/httptest"
1010
"net/url"
1111
"testing"
12+
"time"
1213

1314
jwt "github.com/golang-jwt/jwt"
1415
"github.com/stretchr/testify/assert"
@@ -312,3 +313,25 @@ func TestFunctionHooksUnmarshalJSON(t *testing.T) {
312313
})
313314
}
314315
}
316+
317+
func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
318+
ts.Config.API.MaxRequestDuration = 5 * time.Microsecond
319+
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
320+
w := httptest.NewRecorder()
321+
322+
timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration)
323+
324+
slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
325+
// Sleep for 1 second to simulate a slow handler which should trigger the timeout
326+
time.Sleep(1 * time.Second)
327+
ts.API.handler.ServeHTTP(w, r)
328+
})
329+
timeoutHandler(slowHandler).ServeHTTP(w, req)
330+
assert.Equal(ts.T(), http.StatusGatewayTimeout, w.Code)
331+
332+
var data map[string]interface{}
333+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
334+
require.Equal(ts.T(), ErrorCodeRequestTimeout, data["error_code"])
335+
require.Equal(ts.T(), float64(504), data["code"])
336+
require.NotNil(ts.T(), data["msg"])
337+
}

internal/conf/configuration.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,12 @@ type MFAConfiguration struct {
112112
}
113113

114114
type APIConfiguration struct {
115-
Host string
116-
Port string `envconfig:"PORT" default:"8081"`
117-
Endpoint string
118-
RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"`
119-
ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"`
115+
Host string
116+
Port string `envconfig:"PORT" default:"8081"`
117+
Endpoint string
118+
RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"`
119+
ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"`
120+
MaxRequestDuration time.Duration `json:"max_request_duration" split_words:"true" default:"10s"`
120121
}
121122

122123
func (a *APIConfiguration) Validate() error {

0 commit comments

Comments
 (0)