diff --git a/server/auth.go b/server/auth.go index cb42ca17..822f3dd3 100644 --- a/server/auth.go +++ b/server/auth.go @@ -109,34 +109,33 @@ func createJwt(claims KaraberusClaims) (*jwt.Token, string, error) { } func authMiddleware(ctx huma.Context, next func(huma.Context)) { - var ( - token string - user *User - scopes *Scopes - err error - ) + var user *User = nil + var scopes *Scopes = nil + isOIDC := false // Check for a token in the request. - token, err = getRequestToken(ctx) + token, err := getRequestToken(ctx) // If we have a token, try to get the user. if err == nil { user, scopes, err = getUserScopesFromApiToken(ctx.Context(), token) if err != nil { user, scopes, err = getUserScopesFromJwt(ctx.Context(), token) + isOIDC = true } } + // If we have a user, add it to the context. if err == nil { ctx = huma.WithValue(ctx, currentUserCtxKey, user) - } - if ok := checkOperationSecurity(ctx, user, scopes); ok { - next(ctx) - return + if checkOperationSecurity(ctx, user, scopes, isOIDC) { + next(ctx) + return + } } if err != nil { - getLogger().Print(err) + getLogger().Println(err) ctx.SetStatus(http.StatusUnauthorized) } else { ctx.SetStatus(http.StatusForbidden) @@ -161,8 +160,8 @@ func getRequestToken(ctx huma.Context) (string, error) { func getUserScopesFromApiToken(ctx context.Context, token string) (*User, *Scopes, error) { db := GetDB(ctx) - apiToken := TokenV2{Token: token} - if err := db.Preload(clause.Associations).First(&apiToken).Error; err != nil { + apiToken := TokenV2{} + if err := db.Preload(clause.Associations).Where(&TokenV2{Token: token}).First(&apiToken).Error; err != nil { return nil, nil, err } return &apiToken.User, &apiToken.Scopes, nil @@ -209,28 +208,32 @@ func getUserScopesFromJwt(ctx context.Context, token string) (*User, *Scopes, er return &user, &scopes, nil } -func checkOperationSecurity(ctx huma.Context, user *User, scopes *Scopes) bool { - var authRequired bool - var opScopes []string = []string{} +func checkOperationSecurity(ctx huma.Context, user *User, scopes *Scopes, isOIDC bool) bool { + oidcSecurity := false + opScopes := []string{} for _, opScheme := range ctx.Operation().Security { - var ok bool - if _, ok = opScheme["oidc"]; ok { - authRequired = true - } - if opScopes, ok = opScheme["scopes"]; ok { - break - } + _, oidcSecurityFound := opScheme["oidc"] + oidcSecurity = oidcSecurity || oidcSecurityFound + opScopes = append(opScopes, opScheme["scopes"]...) + } + + if !oidcSecurity && opScopes == nil { + return true } - if authRequired && user == nil { + if user == nil { return false } + if oidcSecurity && isOIDC { + return true + } + for _, v := range opScopes { if !scopes.HasScope(v) { return false } } - return true + return false }