Skip to content

Commit c5904c0

Browse files
cstocktonChris Stockton
andauthored
feat: refactor hooks out of api package (supabase#1976)
These changes refactor hooks out of the api package into `internal/hooks/v0hooks`. To do this we needed to move errors from `internal/api` to `internal/api/apierrors`. This required updating the `New*Error` functions and updating all existing references to errors in the `internal/api` package. --------- Co-authored-by: Chris Stockton <[email protected]>
1 parent 38eb4bb commit c5904c0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1179
-774
lines changed

internal/api/admin.go

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
5454

5555
userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
5656
if err != nil {
57-
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID")
57+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "user_id must be an UUID")
5858
}
5959

6060
observability.LogEntrySetField(r, "user_id", userID)
6161

6262
u, err := models.FindUserByID(db, userID)
6363
if err != nil {
6464
if models.IsNotFoundError(err) {
65-
return nil, notFoundError(apierrors.ErrorCodeUserNotFound, "User not found")
65+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "User not found")
6666
}
67-
return nil, internalServerError("Database error loading user").WithInternalError(err)
67+
return nil, apierrors.NewInternalServerError("Database error loading user").WithInternalError(err)
6868
}
6969

7070
return withUser(ctx, u), nil
@@ -77,17 +77,17 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
7777
user := getUser(ctx)
7878
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
7979
if err != nil {
80-
return nil, notFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID")
80+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeValidationFailed, "factor_id must be an UUID")
8181
}
8282

8383
observability.LogEntrySetField(r, "factor_id", factorID)
8484

8585
factor, err := user.FindOwnedFactorByID(db, factorID)
8686
if err != nil {
8787
if models.IsNotFoundError(err) {
88-
return nil, notFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found")
88+
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeMFAFactorNotFound, "Factor not found")
8989
}
90-
return nil, internalServerError("Database error loading factor").WithInternalError(err)
90+
return nil, apierrors.NewInternalServerError("Database error loading factor").WithInternalError(err)
9191
}
9292
return withFactor(ctx, factor), nil
9393
}
@@ -109,19 +109,19 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {
109109

110110
pageParams, err := paginate(r)
111111
if err != nil {
112-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
112+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
113113
}
114114

115115
sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
116116
if err != nil {
117-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
117+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
118118
}
119119

120120
filter := r.URL.Query().Get("filter")
121121

122122
users, err := models.FindUsersInAudience(db, aud, pageParams, sortParams, filter)
123123
if err != nil {
124-
return internalServerError("Database error finding users").WithInternalError(err)
124+
return apierrors.NewInternalServerError("Database error finding users").WithInternalError(err)
125125
}
126126
addPaginationHeaders(w, r, pageParams)
127127

@@ -170,7 +170,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
170170
if params.BanDuration != "none" {
171171
duration, err = time.ParseDuration(params.BanDuration)
172172
if err != nil {
173-
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
173+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
174174
}
175175
}
176176
banDuration = &duration
@@ -315,7 +315,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
315315
})
316316

317317
if err != nil {
318-
return internalServerError("Error updating user").WithInternalError(err)
318+
return apierrors.NewInternalServerError("Error updating user").WithInternalError(err)
319319
}
320320

321321
return sendJSON(w, http.StatusOK, user)
@@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
339339
}
340340

341341
if params.Email == "" && params.Phone == "" {
342-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
342+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
343343
}
344344

345345
var providers []string
@@ -349,9 +349,9 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
349349
return err
350350
}
351351
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
352-
return internalServerError("Database error checking email").WithInternalError(err)
352+
return apierrors.NewInternalServerError("Database error checking email").WithInternalError(err)
353353
} else if user != nil {
354-
return unprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg)
354+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg)
355355
}
356356
providers = append(providers, "email")
357357
}
@@ -362,21 +362,21 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
362362
return err
363363
}
364364
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
365-
return internalServerError("Database error checking phone").WithInternalError(err)
365+
return apierrors.NewInternalServerError("Database error checking phone").WithInternalError(err)
366366
} else if exists {
367-
return unprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user")
367+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodePhoneExists, "Phone number already registered by another user")
368368
}
369369
providers = append(providers, "phone")
370370
}
371371

372372
if params.Password != nil && params.PasswordHash != "" {
373-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
373+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only a password or a password hash should be provided")
374374
}
375375

376376
if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" {
377377
password, err := password.Generate(64, 10, 0, false, true)
378378
if err != nil {
379-
return internalServerError("Error generating password").WithInternalError(err)
379+
return apierrors.NewInternalServerError("Error generating password").WithInternalError(err)
380380
}
381381
params.Password = &password
382382
}
@@ -390,18 +390,18 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
390390

391391
if err != nil {
392392
if errors.Is(err, bcrypt.ErrPasswordTooLong) {
393-
return badRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
393+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error())
394394
}
395-
return internalServerError("Error creating user").WithInternalError(err)
395+
return apierrors.NewInternalServerError("Error creating user").WithInternalError(err)
396396
}
397397

398398
if params.Id != "" {
399399
customId, err := uuid.FromString(params.Id)
400400
if err != nil {
401-
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
401+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID must conform to the uuid v4 format")
402402
}
403403
if customId == uuid.Nil {
404-
return badRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid")
404+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "ID cannot be a nil uuid")
405405
}
406406
user.ID = customId
407407
}
@@ -419,7 +419,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
419419
if params.BanDuration != "none" {
420420
duration, err = time.ParseDuration(params.BanDuration)
421421
if err != nil {
422-
return badRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
422+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
423423
}
424424
}
425425
banDuration = &duration
@@ -501,7 +501,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
501501
})
502502

503503
if err != nil {
504-
return internalServerError("Database error creating new user").WithInternalError(err)
504+
return apierrors.NewInternalServerError("Database error creating new user").WithInternalError(err)
505505
}
506506

507507
return sendJSON(w, http.StatusOK, user)
@@ -529,7 +529,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
529529
"user_email": user.Email,
530530
"user_phone": user.Phone,
531531
}); terr != nil {
532-
return internalServerError("Error recording audit log entry").WithInternalError(terr)
532+
return apierrors.NewInternalServerError("Error recording audit log entry").WithInternalError(terr)
533533
}
534534

535535
if params.ShouldSoftDelete {
@@ -538,24 +538,24 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
538538
return nil
539539
}
540540
if terr := user.SoftDeleteUser(tx); terr != nil {
541-
return internalServerError("Error soft deleting user").WithInternalError(terr)
541+
return apierrors.NewInternalServerError("Error soft deleting user").WithInternalError(terr)
542542
}
543543

544544
if terr := user.SoftDeleteUserIdentities(tx); terr != nil {
545-
return internalServerError("Error soft deleting user identities").WithInternalError(terr)
545+
return apierrors.NewInternalServerError("Error soft deleting user identities").WithInternalError(terr)
546546
}
547547

548548
// hard delete all associated factors
549549
if terr := models.DeleteFactorsByUserId(tx, user.ID); terr != nil {
550-
return internalServerError("Error deleting user's factors").WithInternalError(terr)
550+
return apierrors.NewInternalServerError("Error deleting user's factors").WithInternalError(terr)
551551
}
552552
// hard delete all associated sessions
553553
if terr := models.Logout(tx, user.ID); terr != nil {
554-
return internalServerError("Error deleting user's sessions").WithInternalError(terr)
554+
return apierrors.NewInternalServerError("Error deleting user's sessions").WithInternalError(terr)
555555
}
556556
} else {
557557
if terr := tx.Destroy(user); terr != nil {
558-
return internalServerError("Database error deleting user").WithInternalError(terr)
558+
return apierrors.NewInternalServerError("Database error deleting user").WithInternalError(terr)
559559
}
560560
}
561561

@@ -581,7 +581,7 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
581581
return terr
582582
}
583583
if terr := tx.Destroy(factor); terr != nil {
584-
return internalServerError("Database error deleting factor").WithInternalError(terr)
584+
return apierrors.NewInternalServerError("Database error deleting factor").WithInternalError(terr)
585585
}
586586
return nil
587587
})
@@ -619,7 +619,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
619619
if params.Phone != "" && factor.IsPhoneFactor() {
620620
phone, err := validatePhone(params.Phone)
621621
if err != nil {
622-
return badRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
622+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)")
623623
}
624624
if terr := factor.UpdatePhone(tx, phone); terr != nil {
625625
return terr

internal/api/anonymous.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
1616
aud := a.requestAud(ctx, r)
1717

1818
if config.DisableSignup {
19-
return unprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
19+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeSignupDisabled, "Signups not allowed for this instance")
2020
}
2121

2222
params := &SignupParams{}
@@ -48,7 +48,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
4848
return nil
4949
})
5050
if err != nil {
51-
return internalServerError("Database error creating anonymous user").WithInternalError(err)
51+
return apierrors.NewInternalServerError("Database error creating anonymous user").WithInternalError(err)
5252
}
5353

5454
metering.RecordLogin("anonymous", newUser.ID)

internal/api/api.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/sirupsen/logrus"
1111
"github.com/supabase/auth/internal/api/apierrors"
1212
"github.com/supabase/auth/internal/conf"
13+
"github.com/supabase/auth/internal/hooks"
1314
"github.com/supabase/auth/internal/mailer"
1415
"github.com/supabase/auth/internal/models"
1516
"github.com/supabase/auth/internal/observability"
@@ -32,6 +33,7 @@ type API struct {
3233
config *conf.GlobalConfiguration
3334
version string
3435

36+
hooksMgr *hooks.Manager
3537
hibpClient *hibp.PwnedClient
3638

3739
// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
@@ -40,6 +42,9 @@ type API struct {
4042
limiterOpts *LimiterOptions
4143
}
4244

45+
func (a *API) GetConfig() *conf.GlobalConfiguration { return a.config }
46+
func (a *API) GetDB() *storage.Connection { return a.db }
47+
4348
func (a *API) Version() string {
4449
return a.version
4550
}
@@ -81,6 +86,9 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
8186
if api.limiterOpts == nil {
8287
api.limiterOpts = NewLimiterOptions(globalConfig)
8388
}
89+
if api.hooksMgr == nil {
90+
api.hooksMgr = hooks.NewManager(db, globalConfig)
91+
}
8492
if api.config.Password.HIBP.Enabled {
8593
httpClient := &http.Client{
8694
// all HIBP API requests should finish quickly to avoid
@@ -157,7 +165,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
157165
}
158166
if params.Email == "" && params.Phone == "" {
159167
if !api.config.External.AnonymousUsers.Enabled {
160-
return unprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
168+
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
161169
}
162170
if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil {
163171
return err

internal/api/apierrors/apierrors.go

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package apierrors
22

33
import (
44
"fmt"
5+
"net/http"
56
)
67

78
// OAuthError is the JSON handler for OAuth2 error responses
@@ -30,7 +31,7 @@ func (e *OAuthError) WithInternalError(err error) *OAuthError {
3031
}
3132

3233
// WithInternalMessage adds internal message information to the error
33-
func (e *OAuthError) WithInternalMessage(fmtString string, args ...interface{}) *OAuthError {
34+
func (e *OAuthError) WithInternalMessage(fmtString string, args ...any) *OAuthError {
3435
e.InternalMessage = fmt.Sprintf(fmtString, args...)
3536
return e
3637
}
@@ -53,14 +54,42 @@ type HTTPError struct {
5354
ErrorID string `json:"error_id,omitempty"`
5455
}
5556

56-
func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError {
57+
func NewHTTPError(httpStatus int, errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
5758
return &HTTPError{
5859
HTTPStatus: httpStatus,
5960
ErrorCode: errorCode,
6061
Message: fmt.Sprintf(fmtString, args...),
6162
}
6263
}
6364

65+
func NewBadRequestError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
66+
return NewHTTPError(http.StatusBadRequest, errorCode, fmtString, args...)
67+
}
68+
69+
func NewNotFoundError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
70+
return NewHTTPError(http.StatusNotFound, errorCode, fmtString, args...)
71+
}
72+
73+
func NewForbiddenError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
74+
return NewHTTPError(http.StatusForbidden, errorCode, fmtString, args...)
75+
}
76+
77+
func NewUnprocessableEntityError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
78+
return NewHTTPError(http.StatusUnprocessableEntity, errorCode, fmtString, args...)
79+
}
80+
81+
func NewTooManyRequestsError(errorCode ErrorCode, fmtString string, args ...any) *HTTPError {
82+
return NewHTTPError(http.StatusTooManyRequests, errorCode, fmtString, args...)
83+
}
84+
85+
func NewInternalServerError(fmtString string, args ...any) *HTTPError {
86+
return NewHTTPError(http.StatusInternalServerError, ErrorCodeUnexpectedFailure, fmtString, args...)
87+
}
88+
89+
func NewConflictError(fmtString string, args ...any) *HTTPError {
90+
return NewHTTPError(http.StatusConflict, ErrorCodeConflict, fmtString, args...)
91+
}
92+
6493
func (e *HTTPError) Error() string {
6594
if e.InternalMessage != "" {
6695
return e.InternalMessage
@@ -87,7 +116,7 @@ func (e *HTTPError) WithInternalError(err error) *HTTPError {
87116
}
88117

89118
// WithInternalMessage adds internal message information to the error
90-
func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError {
119+
func (e *HTTPError) WithInternalMessage(fmtString string, args ...any) *HTTPError {
91120
e.InternalMessage = fmt.Sprintf(fmtString, args...)
92121
return e
93122
}

0 commit comments

Comments
 (0)