diff --git a/internal/api/v1/auth/block_page.go b/internal/api/v1/auth/block_page.go new file mode 100644 index 0000000..9d5c0b5 --- /dev/null +++ b/internal/api/v1/auth/block_page.go @@ -0,0 +1,22 @@ +package auth + +import ( + "html/template" + "net/http" + + _ "embed" +) + +//go:embed block_page.html +var blockPageHTML string + +var blockPageTemplate = template.Must(template.New("block_page").Parse(blockPageHTML)) + +func WriteBlockPage(w http.ResponseWriter, status int, error string, logoutURL string) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + blockPageTemplate.Execute(w, map[string]string{ + "StatusText": http.StatusText(status), + "Error": error, + "LogoutURL": logoutURL, + }) +} diff --git a/internal/api/v1/auth/block_page.html b/internal/api/v1/auth/block_page.html new file mode 100644 index 0000000..195cc13 --- /dev/null +++ b/internal/api/v1/auth/block_page.html @@ -0,0 +1,14 @@ + + + + + + + Access Denied + + +

{{.StatusText}}

+

{{.Error}}

+ Logout + + diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index d56cb50..a5aaa0a 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -132,7 +132,7 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error { allowedUser := slices.Contains(auth.allowedUsers, claims.Username) allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0 if !allowedUser && !allowedGroup { - return ErrUserNotAllowed.Subject(claims.Username) + return ErrUserNotAllowed } return nil } diff --git a/internal/net/gphttp/error.go b/internal/net/gphttp/error.go index 8d5c488..f269e3f 100644 --- a/internal/net/gphttp/error.go +++ b/internal/net/gphttp/error.go @@ -48,7 +48,6 @@ func ClientError(w http.ResponseWriter, err error, code ...int) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(err) } else { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") http.Error(w, err.Error(), code[0]) } } @@ -65,7 +64,8 @@ func BadRequest(w http.ResponseWriter, err string, code ...int) { if len(code) == 0 { code = []int{http.StatusBadRequest} } - http.Error(w, err, code[0]) + w.WriteHeader(code[0]) + w.Write([]byte(err)) } // Unauthorized returns an Unauthorized response with the given error message. @@ -73,6 +73,11 @@ func Unauthorized(w http.ResponseWriter, err string) { BadRequest(w, err, http.StatusUnauthorized) } +// Forbidden returns a Forbidden response with the given error message. +func Forbidden(w http.ResponseWriter, err string) { + BadRequest(w, err, http.StatusForbidden) +} + // NotFound returns a Not Found response with the given error message. func NotFound(w http.ResponseWriter, err string) { BadRequest(w, err, http.StatusNotFound) diff --git a/internal/net/gphttp/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go index 128f4a0..417469f 100644 --- a/internal/net/gphttp/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "sync" "sync/atomic" @@ -80,11 +81,13 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce } if err := amw.auth.CheckToken(r); err != nil { - amw.authMux.ServeHTTP(w, r) - return false - } - if r.URL.Path == auth.OIDCLogoutPath { - amw.auth.LogoutCallbackHandler(w, r) + if errors.Is(err, auth.ErrMissingToken) { + amw.authMux.ServeHTTP(w, r) + } else if r.URL.Path == auth.OIDCLogoutPath { + amw.auth.LogoutCallbackHandler(w, r) + } else { + auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) + } return false } return true