Skip to content

Commit a79dab0

Browse files
J0hf
authored andcommitted
feat: add request timeouts
1 parent bd8b5c4 commit a79dab0

File tree

5 files changed

+102
-13
lines changed

5 files changed

+102
-13
lines changed

internal/api/api.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
9898
logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig)
9999

100100
r := newRouter()
101+
102+
r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration))
101103
r.Use(addRequestID(globalConfig))
102104

103105
// request tracing should be added only when tracing or metrics is enabled

internal/api/errorcodes.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,5 @@ const (
7777
ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry"
7878
ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit"
7979
ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size"
80+
ErrorCodeRequestTimeout ErrorCode = "request_timeout"
8081
)

internal/api/errors.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
207207
output.Message = e.Message
208208
output.Payload.Reasons = e.Reasons
209209

210-
if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil {
210+
if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded {
211211
HandleResponseError(jsonErr, w, r)
212212
}
213213

@@ -224,7 +224,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
224224
output.Message = e.Message
225225
output.Payload.Reasons = e.Reasons
226226

227-
if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil {
227+
if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded {
228228
HandleResponseError(jsonErr, w, r)
229229
}
230230
}
@@ -252,7 +252,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
252252
}
253253
}
254254

255-
if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil {
255+
if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded {
256256
HandleResponseError(jsonErr, w, r)
257257
}
258258
} else {
@@ -266,20 +266,20 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
266266

267267
// Provide better error messages for certain user-triggered Postgres errors.
268268
if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil {
269-
if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil {
269+
if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded {
270270
HandleResponseError(jsonErr, w, r)
271271
}
272272
return
273273
}
274274

275-
if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil {
275+
if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded {
276276
HandleResponseError(jsonErr, w, r)
277277
}
278278
}
279279

280280
case *OAuthError:
281281
log.WithError(e.Cause()).Info(e.Error())
282-
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil {
282+
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded {
283283
HandleResponseError(jsonErr, w, r)
284284
}
285285

@@ -295,7 +295,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
295295
Message: "Unexpected failure, please check server logs for more information",
296296
}
297297

298-
if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil {
298+
if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded {
299299
HandleResponseError(jsonErr, w, r)
300300
}
301301
} else {
@@ -305,7 +305,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
305305
Message: "Unexpected failure, please check server logs for more information",
306306
}
307307

308-
if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil {
308+
if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded {
309309
HandleResponseError(jsonErr, w, r)
310310
}
311311
}

internal/api/middleware.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"net/url"
99
"strings"
10+
"sync/atomic"
1011
"time"
1112

1213
"github.com/supabase/auth/internal/models"
@@ -260,3 +261,87 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
260261
})
261262
}
262263
}
264+
265+
// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent
266+
// writes after the context contained in it has exceeded the deadline. If a
267+
// partial write occurs before the deadline is exceeded, but the writing is not
268+
// complete it will allow further writes.
269+
type timeoutResponseWriter struct {
270+
ctx context.Context
271+
w http.ResponseWriter
272+
wrote int32
273+
}
274+
275+
func (t *timeoutResponseWriter) Header() http.Header {
276+
return t.w.Header()
277+
}
278+
279+
func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
280+
if t.ctx.Err() == context.DeadlineExceeded {
281+
if atomic.LoadInt32(&t.wrote) == 0 {
282+
return 0, context.DeadlineExceeded
283+
}
284+
285+
// writing started before the deadline exceeded, but the
286+
// deadline came in the middle, so letting the writes go
287+
// through
288+
}
289+
290+
atomic.AddInt32(&t.wrote, 1)
291+
292+
return t.w.Write(bytes)
293+
}
294+
295+
func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
296+
if t.ctx.Err() == context.DeadlineExceeded {
297+
if atomic.LoadInt32(&t.wrote) == 0 {
298+
return
299+
}
300+
301+
// writing started before the deadline exceeded, but the
302+
// deadline came in the middle, so letting the writes go
303+
// through
304+
}
305+
306+
atomic.AddInt32(&t.wrote, 1)
307+
308+
t.w.WriteHeader(statusCode)
309+
}
310+
311+
func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
312+
return func(next http.Handler) http.Handler {
313+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
314+
ctx, cancel := context.WithTimeout(r.Context(), timeout)
315+
defer cancel()
316+
317+
timeoutWriter := &timeoutResponseWriter{
318+
w: w,
319+
ctx: ctx,
320+
}
321+
322+
go func() {
323+
<-ctx.Done()
324+
325+
err := ctx.Err()
326+
327+
if err == context.DeadlineExceeded {
328+
if atomic.LoadInt32(&timeoutWriter.wrote) == 0 {
329+
// writer wasn't written to, so we're sending the error payload
330+
331+
httpError := &HTTPError{
332+
HTTPStatus: http.StatusGatewayTimeout,
333+
ErrorCode: ErrorCodeRequestTimeout,
334+
Message: "Processing this request timed out, please retry after a moment.",
335+
}
336+
337+
httpError = httpError.WithInternalError(err)
338+
339+
HandleResponseError(httpError, w, r)
340+
}
341+
}
342+
}()
343+
344+
next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
345+
})
346+
}
347+
}

internal/conf/configuration.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,12 @@ type MFAConfiguration struct {
112112
}
113113

114114
type APIConfiguration struct {
115-
Host string
116-
Port string `envconfig:"PORT" default:"8081"`
117-
Endpoint string
118-
RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"`
119-
ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"`
115+
Host string
116+
Port string `envconfig:"PORT" default:"8081"`
117+
Endpoint string
118+
RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"`
119+
ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"`
120+
MaxRequestDuration time.Duration `json:"max_request_duration" split_words:"true" default:"10s"`
120121
}
121122

122123
func (a *APIConfiguration) Validate() error {

0 commit comments

Comments
 (0)