diff --git a/handlers/adfs/adfs.go b/handlers/adfs/adfs.go index 4280e839..f738b339 100644 --- a/handlers/adfs/adfs.go +++ b/handlers/adfs/adfs.go @@ -14,6 +14,8 @@ import ( "strings" ) +type Handler struct{} + type adfsTokenRes struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` @@ -26,7 +28,7 @@ var ( ) // More info: https://docs.microsoft.com/en-us/windows-server/identity/ad-fs/overview/ad-fs-scenarios-for-developers#supported-scenarios -func GetUserInfoFromADFS(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { +func (Handler) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { code := r.URL.Query().Get("code") log.Debugf("code: %s", code) diff --git a/handlers/github/github.go b/handlers/github/github.go index db8a7367..501996ed 100644 --- a/handlers/github/github.go +++ b/handlers/github/github.go @@ -11,14 +11,18 @@ import ( "strings" ) +type Handler struct { + PrepareTokensAndClient func(*http.Request, *structs.PTokens, bool) (error, *http.Client, *oauth2.Token) +} + var ( log = cfg.Cfg.Logger ) // github // https://developer.github.com/apps/building-integrations/setting-up-and-registering-oauth-apps/about-authorization-options-for-oauth-apps/ -func GetUserInfoFromGitHub(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { - err, client, ptoken := common.PrepareTokensAndClient(r, ptokens, true) +func (me Handler) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { + err, client, ptoken := me.PrepareTokensAndClient(r, ptokens, true) if err != nil { // http.Error(w, err.Error(), http.StatusBadRequest) return err diff --git a/handlers/github/github_test.go b/handlers/github/github_test.go index 8f7c2c2d..f4ec0f10 100644 --- a/handlers/github/github_test.go +++ b/handlers/github/github_test.go @@ -2,15 +2,14 @@ package github import ( "encoding/json" + mockhttp "github.com/karupanerura/go-mock-http-response" + "github.com/stretchr/testify/assert" "github.com/vouch/vouch-proxy/pkg/cfg" "github.com/vouch/vouch-proxy/pkg/domains" "github.com/vouch/vouch-proxy/pkg/structs" "golang.org/x/oauth2" "net/http" "regexp" - - mockhttp "github.com/karupanerura/go-mock-http-response" - "github.com/stretchr/testify/assert" "testing" ) @@ -156,7 +155,7 @@ func TestGetOrgMembershipStateFromGitHubNoOrgAccess(t *testing.T) { assertUrlCalled(t, expectedOrgPublicMembershipUrl) } -func TestGetUserInfoFromGitHub(t *testing.T) { +func TestGetUserInfo(t *testing.T) { setUp() userInfoContent, _ := json.Marshal(structs.GitHubUser{ @@ -178,7 +177,10 @@ func TestGetUserInfoFromGitHub(t *testing.T) { mockResponse(regexMatcher(".*teams.*"), http.StatusOK, map[string]string{}, []byte("{\"state\": \"active\"}")) mockResponse(regexMatcher(".*members.*"), http.StatusNoContent, map[string]string{}, []byte("")) - err := GetUserInfoFromGitHub(client, user, &structs.CustomClaims{}, token) + handler := Handler{PrepareTokensAndClient: func(_ *http.Request, _ *structs.PTokens, _ bool) (error, *http.Client, *oauth2.Token) { + return nil, client, token + }} + err := handler.GetUserInfo(nil, user, &structs.CustomClaims{}, &structs.PTokens{}) assert.Nil(t, err) assert.Equal(t, "myusername", user.Username) diff --git a/handlers/google/google.go b/handlers/google/google.go index aecd5485..cfcd0d18 100644 --- a/handlers/google/google.go +++ b/handlers/google/google.go @@ -9,11 +9,13 @@ import ( "net/http" ) +type Handler struct{} + var ( log = cfg.Cfg.Logger ) -func GetUserInfoFromGoogle(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { +func (Handler) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { err, client, _ := common.PrepareTokensAndClient(r, ptokens, true) if err != nil { return err diff --git a/handlers/handlers.go b/handlers/handlers.go index 302ca65b..45f94222 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -3,6 +3,7 @@ package handlers import ( "fmt" "github.com/vouch/vouch-proxy/handlers/adfs" + "github.com/vouch/vouch-proxy/handlers/common" "github.com/vouch/vouch-proxy/handlers/github" "github.com/vouch/vouch-proxy/handlers/google" "github.com/vouch/vouch-proxy/handlers/homeassistant" @@ -43,6 +44,10 @@ type AuthError struct { JWT string } +type Handler interface { + GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) error +} + const ( base64Bytes = 32 ) @@ -527,32 +532,30 @@ func CallbackHandler(w http.ResponseWriter, r *http.Request) { renderIndex(w, "/auth "+tokenstring) } -// TODO: put all getUserInfo logic into its own pkg - func getUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) error { + return getHandler().GetUserInfo(r, user, customClaims, ptokens) +} - // indieauth sends the "me" setting in json back to the callback, so just pluck it from the callback - if cfg.GenOAuth.Provider == cfg.Providers.IndieAuth { - return indieauth.GetUserInfoFromIndieAuth(r, user, customClaims, ptokens) - } else if cfg.GenOAuth.Provider == cfg.Providers.ADFS { - return adfs.GetUserInfoFromADFS(r, user, customClaims, ptokens) - } - if cfg.GenOAuth.Provider == cfg.Providers.HomeAssistant { - return homeassistant.GetUserInfoFromHomeAssistant(r, user, customClaims, ptokens) - } - if cfg.GenOAuth.Provider == cfg.Providers.OpenStax { - return openstax.GetUserInfoFromOpenStax(r, user, customClaims, ptokens) - } - - if cfg.GenOAuth.Provider == cfg.Providers.Google { - return google.GetUserInfoFromGoogle(r, user, customClaims, ptokens) - } else if cfg.GenOAuth.Provider == cfg.Providers.GitHub { - return github.GetUserInfoFromGitHub(r, user, customClaims, ptokens) - } else if cfg.GenOAuth.Provider == cfg.Providers.OIDC { - return openid.GetUserInfoFromOpenID(r, user, customClaims, ptokens) +func getHandler() Handler { + switch cfg.GenOAuth.Provider { + case cfg.Providers.IndieAuth: + return indieauth.Handler{} + case cfg.Providers.ADFS: + return adfs.Handler{} + case cfg.Providers.HomeAssistant: + return homeassistant.Handler{} + case cfg.Providers.OpenStax: + return openstax.Handler{} + case cfg.Providers.Google: + return google.Handler{} + case cfg.Providers.GitHub: + return github.Handler{common.PrepareTokensAndClient} + case cfg.Providers.OIDC: + return openid.Handler{} + default: + log.Error("we don't know how to look up the user info") + return nil } - log.Error("we don't know how to look up the user info") - return nil } // the standard error diff --git a/handlers/homeassistant/homeassistant.go b/handlers/homeassistant/homeassistant.go index 1695fae7..597e4582 100644 --- a/handlers/homeassistant/homeassistant.go +++ b/handlers/homeassistant/homeassistant.go @@ -6,8 +6,10 @@ import ( "net/http" ) +type Handler struct{} + // More info: https://developers.home-assistant.io/docs/en/auth_api.html -func GetUserInfoFromHomeAssistant(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { +func (Handler) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { err, _, providerToken := common.PrepareTokensAndClient(r, ptokens, false) if err != nil { return err diff --git a/handlers/indieauth/indieauth.go b/handlers/indieauth/indieauth.go index f0b97135..53897d58 100644 --- a/handlers/indieauth/indieauth.go +++ b/handlers/indieauth/indieauth.go @@ -11,12 +11,14 @@ import ( "net/http" ) +type Handler struct{} + var ( log = cfg.Cfg.Logger ) -func GetUserInfoFromIndieAuth(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { - +func (Handler) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { + // indieauth sends the "me" setting in json back to the callback, so just pluck it from the callback code := r.URL.Query().Get("code") log.Errorf("ptoken.AccessToken: %s", code) var b bytes.Buffer diff --git a/handlers/openid/openid.go b/handlers/openid/openid.go index 61a9a0b4..3b6025a3 100644 --- a/handlers/openid/openid.go +++ b/handlers/openid/openid.go @@ -9,11 +9,13 @@ import ( "net/http" ) +type Handler struct{} + var ( log = cfg.Cfg.Logger ) -func GetUserInfoFromOpenID(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { +func (Handler) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { err, client, _ := common.PrepareTokensAndClient(r, ptokens, true) if err != nil { return err diff --git a/handlers/openstax/openstax.go b/handlers/openstax/openstax.go index 583cc47a..60212668 100644 --- a/handlers/openstax/openstax.go +++ b/handlers/openstax/openstax.go @@ -9,11 +9,13 @@ import ( "net/http" ) +type Handler struct{} + var ( log = cfg.Cfg.Logger ) -func GetUserInfoFromOpenStax(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { +func (Handler) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens) (rerr error) { err, client, _ := common.PrepareTokensAndClient(r, ptokens, false) if err != nil { return err