Skip to content

Commit

Permalink
cleanup code, redirect to auth page when need
Browse files Browse the repository at this point in the history
  • Loading branch information
yusing committed Jan 12, 2025
1 parent ef277ef commit 76fe534
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 109 deletions.
14 changes: 7 additions & 7 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,16 @@ func main() {
return
}

if common.APIJWTSecret == nil {
logging.Warn().Msg("API JWT secret is empty, authentication is disabled")
}

cfg.Start()
config.WatchChanges()

// Initialize authentication providers
if err := auth.Initialize(); err != nil {
logging.Warn().Err(err).Msg("Failed to initialize authentication providers")
if !auth.IsEnabled() {
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
} else {
// Initialize authentication providers
if err := auth.Initialize(); err != nil {
logging.Fatal().Err(err).Msg("Failed to initialize authentication providers")
}
}

sig := make(chan os.Signal, 1)
Expand Down
5 changes: 2 additions & 3 deletions internal/api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
mux := ServeMux{http.NewServeMux()}
mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("POST", "/v1/login", auth.LoginHandler)
mux.HandleFunc("GET", "/v1/login/method", auth.AuthMethodHandler)
mux.HandleFunc("GET", "/v1/login/oidc", auth.OIDCLoginHandler)
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler)
mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler)
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
Expand Down
95 changes: 33 additions & 62 deletions internal/api/v1/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package auth

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"time"
Expand All @@ -25,51 +23,37 @@ type (
}
)

var (
ErrInvalidUsername = E.New("invalid username")
ErrInvalidPassword = E.New("invalid password")
)

func validatePassword(cred *Credentials) error {
if cred.Username != common.APIUser {
return ErrInvalidUsername.Subject(cred.Username)
}
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
return ErrInvalidPassword.Subject(cred.Password)
// Initialize sets up authentication providers.
func Initialize() error {
// Initialize OIDC if configured.
if common.OIDCIssuerURL != "" {
return InitOIDC(
common.OIDCIssuerURL,
common.OIDCClientID,
common.OIDCClientSecret,
common.OIDCRedirectURL,
)
}
return nil
}

func LoginHandler(w http.ResponseWriter, r *http.Request) {
var creds Credentials
err := json.NewDecoder(r.Body).Decode(&creds)
if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest)
return
}
if err := validatePassword(&creds); err != nil {
U.HandleErr(w, r, err, http.StatusUnauthorized)
return
}
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
func IsEnabled() bool {
return common.APIJWTSecret != nil || common.OIDCIssuerURL != ""
}

func AuthMethodHandler(w http.ResponseWriter, r *http.Request) {
// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration.
func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
switch {
case oauthConfig != nil:
RedirectOIDC(w, r)
return
case common.APIJWTSecret == nil:
U.WriteBody(w, []byte("skip"))
case common.OIDCIssuerURL != "":
U.WriteBody(w, []byte("oidc"))
case common.APIPasswordHash != nil:
U.WriteBody(w, []byte("password"))
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
return
default:
U.WriteBody(w, []byte("skip"))
w.WriteHeader(http.StatusOK)
}
w.WriteHeader(http.StatusOK)
}

func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
Expand All @@ -86,57 +70,44 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
return err
}
http.SetCookie(w, &http.Cookie{
Name: "token",
Name: CookieToken,
Value: tokenStr,
Expires: expiresAt,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
return nil
}

// LogoutHandler clear authentication cookie and redirect to login page.
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: "token",
Name: CookieToken,
Value: "",
Expires: time.Unix(0, 0),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
w.Header().Set("location", "/login")
w.WriteHeader(http.StatusTemporaryRedirect)
}

// Initialize sets up authentication providers.
func Initialize() error {
// Initialize OIDC if configured.
if common.OIDCIssuerURL != "" {
return InitOIDC(
common.OIDCIssuerURL,
common.OIDCClientID,
common.OIDCClientSecret,
common.OIDCRedirectURL,
)
}
return nil
AuthRedirectHandler(w, r)
}

func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
if common.IsDebugSkipAuth || common.APIJWTSecret == nil {
return next
}

return func(w http.ResponseWriter, r *http.Request) {
if checkToken(w, r) {
next(w, r)
if IsEnabled() {
return func(w http.ResponseWriter, r *http.Request) {
if checkToken(w, r) {
next(w, r)
}
}
}
return next
}

func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
tokenCookie, err := r.Cookie("token")
tokenCookie, err := r.Cookie(CookieToken)
if err != nil {
U.RespondError(w, E.New("missing token"), http.StatusUnauthorized)
return false
Expand Down
6 changes: 6 additions & 0 deletions internal/api/v1/auth/cookies.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package auth

const (
CookieToken = "token"
CookieOauthState = "oauth_state"
)
33 changes: 6 additions & 27 deletions internal/api/v1/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import (
"context"
"fmt"
"net/http"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
Expand Down Expand Up @@ -47,16 +45,16 @@ func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error {
return nil
}

// OIDCLoginHandler initiates the OIDC login flow.
func OIDCLoginHandler(w http.ResponseWriter, r *http.Request) {
// RedirectOIDC initiates the OIDC login flow.
func RedirectOIDC(w http.ResponseWriter, r *http.Request) {
if oauthConfig == nil {
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
return
}

state := common.GenerateRandomString(32)
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Name: CookieOauthState,
Value: state,
MaxAge: 300,
HttpOnly: true,
Expand Down Expand Up @@ -87,7 +85,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
return
}

state, err := r.Cookie("oauth_state")
state, err := r.Cookie(CookieOauthState)
if err != nil {
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
return
Expand Down Expand Up @@ -137,7 +135,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {

// handleTestCallback handles OIDC callback in test environment.
func handleTestCallback(w http.ResponseWriter, r *http.Request) {
state, err := r.Cookie("oauth_state")
state, err := r.Cookie(CookieOauthState)
if err != nil {
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
return
Expand All @@ -149,29 +147,10 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
}

// Create test JWT token
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
jwtClaims := &Claims{
Username: "test-user",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}

token := jwt.NewWithClaims(jwt.SigningMethodHS512, jwtClaims)
tokenStr, err := token.SignedString(common.APIJWTSecret)
if err != nil {
if err := setAuthenticatedCookie(w, "test-user"); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}

http.SetCookie(w, &http.Cookie{
Name: "token",
Value: tokenStr,
Expires: expiresAt,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
})

http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
}
4 changes: 2 additions & 2 deletions internal/api/v1/auth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ func TestOIDCLoginHandler(t *testing.T) {
oauthConfig = nil
}

req := httptest.NewRequest(http.MethodGet, "/login/oidc", nil)
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
w := httptest.NewRecorder()

OIDCLoginHandler(w, req)
RedirectOIDC(w, req)

if got := w.Code; got != tt.wantStatus {
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
Expand Down
45 changes: 45 additions & 0 deletions internal/api/v1/auth/userpass.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package auth

import (
"bytes"
"encoding/json"
"net/http"

U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)

var (
ErrInvalidUsername = E.New("invalid username")
ErrInvalidPassword = E.New("invalid password")
)

func validatePassword(cred *Credentials) error {
if cred.Username != common.APIUser {
return ErrInvalidUsername.Subject(cred.Username)
}
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
return ErrInvalidPassword.Subject(cred.Password)
}
return nil
}

// UserPassLoginHandler handles user login.
func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) {
var creds Credentials
err := json.NewDecoder(r.Body).Decode(&creds)
if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest)
return
}
if err := validatePassword(&creds); err != nil {
U.HandleErr(w, r, err, http.StatusUnauthorized)
return
}
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
6 changes: 5 additions & 1 deletion internal/api/v1/utils/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/yusing/go-proxy/internal/utils/strutils/ansi"
)

// HandleErr logs the error and returns an HTTP error response to the client.
// HandleErr logs the error and returns an error code to the client.
// If code is specified, it will be used as the HTTP status code; otherwise,
// http.StatusInternalServerError is used.
//
Expand All @@ -23,10 +23,14 @@ func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) {
http.Error(w, http.StatusText(code[0]), code[0])
}

// RespondError returns error details to the client.
// If code is specified, it will be used as the HTTP status code; otherwise,
// http.StatusBadRequest is used.
func RespondError(w http.ResponseWriter, err error, code ...int) {
if len(code) == 0 {
code = []int{http.StatusBadRequest}
}
// strip ANSI color codes added from Error.WithSubject
http.Error(w, ansi.StripANSI(err.Error()), code[0])
}

Expand Down
2 changes: 1 addition & 1 deletion internal/api/v1/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

func WriteBody(w http.ResponseWriter, body []byte) {
if _, err := w.Write(body); err != nil {
HandleErr(w, nil, err)
logging.Err(err).Msg("failed to write body")
}
}

Expand Down
11 changes: 5 additions & 6 deletions internal/common/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ import (
var (
prefixes = []string{"GODOXY_", "GOPROXY_", ""}

IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("DEBUG", IsTest)
IsDebugSkipAuth = GetEnvBool("DEBUG_SKIP_AUTH", false)
IsTrace = GetEnvBool("TRACE", false) && IsDebug
IsProduction = !IsTest && !IsDebug
IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("DEBUG", IsTest)
IsTrace = GetEnvBool("TRACE", false) && IsDebug
IsProduction = !IsTest && !IsDebug

ProxyHTTPAddr,
ProxyHTTPHost,
Expand Down Expand Up @@ -46,7 +45,7 @@ var (
APIUser = GetEnvString("API_USER", "admin")
APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password"))

// OIDC Configuration
// OIDC Configuration.
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
Expand Down
1 change: 1 addition & 0 deletions internal/route/provider/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func makeEntries(cont *types.Container, dockerHostIP ...string) route.RawEntries
} else {
host = client.DefaultDockerHost
}
p.name = "test"
entries := E.Must(p.entriesFromContainerLabels(D.FromDocker(cont, host)))
entries.RangeAll(func(k string, v *route.RawEntry) {
v.Finalize()
Expand Down

0 comments on commit 76fe534

Please sign in to comment.