From bafa65c5da35cc75793edf29602e2e9817ec6200 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 26 Oct 2023 15:43:54 +0800 Subject: [PATCH 1/4] Refactor Find Sources and fix bug when view a user who belongs to an unactive auth source --- cmd/admin_auth.go | 2 +- models/activities/statistic.go | 2 +- models/auth/oauth2.go | 9 ---- models/auth/source.go | 49 ++++++++----------- routers/web/admin/auths.go | 6 +-- routers/web/admin/users.go | 10 ++-- routers/web/user/setting/security/security.go | 6 ++- services/auth/signin.go | 5 +- services/auth/source/oauth2/init.go | 9 +++- services/auth/source/oauth2/providers.go | 22 ++++++--- services/auth/sspi.go | 5 +- services/auth/sync.go | 2 +- 12 files changed, 69 insertions(+), 58 deletions(-) diff --git a/cmd/admin_auth.go b/cmd/admin_auth.go index 3b308d77f7987..014ddf329f94d 100644 --- a/cmd/admin_auth.go +++ b/cmd/admin_auth.go @@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error { return err } - authSources, err := auth_model.Sources(ctx) + authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{}) if err != nil { return err } diff --git a/models/activities/statistic.go b/models/activities/statistic.go index 009c8c5ab474e..e9dab6fc10b6b 100644 --- a/models/activities/statistic.go +++ b/models/activities/statistic.go @@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) { stats.Counter.Follow, _ = e.Count(new(user_model.Follow)) stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror)) stats.Counter.Release, _ = e.Count(new(repo_model.Release)) - stats.Counter.AuthSource = auth.CountSources(ctx) + stats.Counter.AuthSource = auth.CountSources(ctx, auth.FindSourcesOptions{}) stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook)) stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone)) stats.Counter.Label, _ = e.Count(new(issues_model.Label)) diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index d73ad6965d2f0..76a4e9d835bcd 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -631,15 +631,6 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error { return util.ErrNotExist } -// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources -func GetActiveOAuth2ProviderSources(ctx context.Context) ([]*Source, error) { - sources := make([]*Source, 0, 1) - if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil { - return nil, err - } - return sources, nil -} - // GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) { authSource := new(Source) diff --git a/models/auth/source.go b/models/auth/source.go index 0f57d1702a774..b3f3262cc206c 100644 --- a/models/auth/source.go +++ b/models/auth/source.go @@ -14,6 +14,7 @@ import ( "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/util" + "xorm.io/builder" "xorm.io/xorm" "xorm.io/xorm/convert" ) @@ -240,37 +241,26 @@ func CreateSource(ctx context.Context, source *Source) error { return err } -// Sources returns a slice of all login sources found in DB. -func Sources(ctx context.Context) ([]*Source, error) { - auths := make([]*Source, 0, 6) - return auths, db.GetEngine(ctx).Find(&auths) +type FindSourcesOptions struct { + IsActive util.OptionalBool + LoginType Type } -// SourcesByType returns all sources of the specified type -func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) { - sources := make([]*Source, 0, 1) - if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil { - return nil, err +func (opts FindSourcesOptions) ToConds() builder.Cond { + conds := builder.NewCond() + if !opts.IsActive.IsNone() { + conds = conds.And(builder.Eq{"is_active": opts.IsActive.IsTrue()}) } - return sources, nil -} - -// AllActiveSources returns all active sources -func AllActiveSources(ctx context.Context) ([]*Source, error) { - sources := make([]*Source, 0, 5) - if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil { - return nil, err + if opts.LoginType != NoType { + conds = conds.And(builder.Eq{"`type`": opts.LoginType}) } - return sources, nil + return conds } -// ActiveSources returns all active sources of the specified type -func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) { - sources := make([]*Source, 0, 1) - if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil { - return nil, err - } - return sources, nil +// FindSources returns a slice of login sources found in DB according to given conditions. +func FindSources(ctx context.Context, opts FindSourcesOptions) ([]*Source, error) { + auths := make([]*Source, 0, 6) + return auths, db.GetEngine(ctx).Where(opts.ToConds()).Find(&auths) } // IsSSPIEnabled returns true if there is at least one activated login @@ -279,7 +269,10 @@ func IsSSPIEnabled(ctx context.Context) bool { if !db.HasEngine { return false } - sources, err := ActiveSources(ctx, SSPI) + sources, err := FindSources(ctx, FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + LoginType: SSPI, + }) if err != nil { log.Error("ActiveSources: %v", err) return false @@ -354,8 +347,8 @@ func UpdateSource(ctx context.Context, source *Source) error { } // CountSources returns number of login sources. -func CountSources(ctx context.Context) int64 { - count, _ := db.GetEngine(ctx).Count(new(Source)) +func CountSources(ctx context.Context, opts FindSourcesOptions) int64 { + count, _ := db.GetEngine(ctx).Where(opts.ToConds()).Count(new(Source)) return count } diff --git a/routers/web/admin/auths.go b/routers/web/admin/auths.go index da91e31efe11c..550def1d64d6f 100644 --- a/routers/web/admin/auths.go +++ b/routers/web/admin/auths.go @@ -48,13 +48,13 @@ func Authentications(ctx *context.Context) { ctx.Data["PageIsAdminAuthentications"] = true var err error - ctx.Data["Sources"], err = auth.Sources(ctx) + ctx.Data["Sources"], err = auth.FindSources(ctx, auth.FindSourcesOptions{}) if err != nil { ctx.ServerError("auth.Sources", err) return } - ctx.Data["Total"] = auth.CountSources(ctx) + ctx.Data["Total"] = auth.CountSources(ctx, auth.FindSourcesOptions{}) ctx.HTML(http.StatusOK, tplAuths) } @@ -284,7 +284,7 @@ func NewAuthSourcePost(ctx *context.Context) { ctx.RenderWithErr(err.Error(), tplAuthNew, form) return } - existing, err := auth.SourcesByType(ctx, auth.SSPI) + existing, err := auth.FindSources(ctx, auth.FindSourcesOptions{LoginType: auth.SSPI}) if err != nil || len(existing) > 0 { ctx.Data["Err_Type"] = true ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_of_type_exist"), tplAuthNew, form) diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index 91a578fb55482..630d739836beb 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -90,7 +90,9 @@ func NewUser(ctx *context.Context) { ctx.Data["login_type"] = "0-0" - sources, err := auth.Sources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + }) if err != nil { ctx.ServerError("auth.Sources", err) return @@ -109,7 +111,9 @@ func NewUserPost(ctx *context.Context) { ctx.Data["DefaultUserVisibilityMode"] = setting.Service.DefaultUserVisibilityMode ctx.Data["AllowedUserVisibilityModes"] = setting.Service.AllowedUserVisibilityModesSlice.ToVisibleTypeSlice() - sources, err := auth.Sources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + }) if err != nil { ctx.ServerError("auth.Sources", err) return @@ -230,7 +234,7 @@ func prepareUserInfo(ctx *context.Context) *user_model.User { ctx.Data["LoginSource"] = &auth.Source{} } - sources, err := auth.Sources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{}) if err != nil { ctx.ServerError("auth.Sources", err) return nil diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index c687f7314d401..ae0b4bc3a974c 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/services/auth/source/oauth2" ) @@ -105,9 +106,10 @@ func loadSecurityData(ctx *context.Context) { } ctx.Data["AccountLinks"] = sources - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + // here we need to load all possible auth sources because a linked account maybe is using an unactive auth source + orderedOAuth2Names, oauth2Providers, err := oauth2.GetOAuth2ProvidersMap(ctx, util.OptionalBoolNone) if err != nil { - ctx.ServerError("GetActiveOAuth2Providers", err) + ctx.ServerError("GetOAuth2ProvidersMap", err) return } ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names diff --git a/services/auth/signin.go b/services/auth/signin.go index 5fdf6d2bd7ab2..2e534536817e4 100644 --- a/services/auth/signin.go +++ b/services/auth/signin.go @@ -11,6 +11,7 @@ import ( "code.gitea.io/gitea/models/db" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/services/auth/source/oauth2" "code.gitea.io/gitea/services/auth/source/smtp" @@ -85,7 +86,9 @@ func UserSignIn(ctx context.Context, username, password string) (*user_model.Use } } - sources, err := auth.AllActiveSources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + }) if err != nil { return nil, nil, err } diff --git a/services/auth/source/oauth2/init.go b/services/auth/source/oauth2/init.go index cfaddaa35d55c..0ebbdaebd411f 100644 --- a/services/auth/source/oauth2/init.go +++ b/services/auth/source/oauth2/init.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/models/auth" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" "github.com/google/uuid" "github.com/gorilla/sessions" @@ -63,7 +64,13 @@ func ResetOAuth2(ctx context.Context) error { // initOAuth2Sources is used to load and register all active OAuth2 providers func initOAuth2Sources(ctx context.Context) error { - authSources, _ := auth.GetActiveOAuth2ProviderSources(ctx) + authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + LoginType: auth.OAuth2, + }) + if err != nil { + return err + } for _, source := range authSources { oauth2Source, ok := source.Cfg.(*Source) if !ok { diff --git a/services/auth/source/oauth2/providers.go b/services/auth/source/oauth2/providers.go index cd158614a2e4e..56fabb264d21a 100644 --- a/services/auth/source/oauth2/providers.go +++ b/services/auth/source/oauth2/providers.go @@ -15,6 +15,7 @@ import ( "code.gitea.io/gitea/models/auth" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" "github.com/markbates/goth" ) @@ -95,13 +96,12 @@ func GetOAuth2Providers() []Provider { return providers } -// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers -// key is used as technical name (like in the callbackURL) -// values to display -func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provider, error) { - // Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type - - authSources, err := auth.GetActiveOAuth2ProviderSources(ctx) +// GetOAuth2ProvidersMap returns the map of configured OAuth2 providers +func GetOAuth2ProvidersMap(ctx context.Context, isActive util.OptionalBool) ([]string, map[string]Provider, error) { + authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: isActive, + LoginType: auth.OAuth2, + }) if err != nil { return nil, nil, err } @@ -124,6 +124,14 @@ func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provide return orderedKeys, providers, nil } +// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers +// key is used as technical name (like in the callbackURL) +// values to display +func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provider, error) { + // Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type + return GetOAuth2ProvidersMap(ctx, util.OptionalBoolTrue) +} + // RegisterProviderWithGothic register a OAuth2 provider in goth lib func RegisterProviderWithGothic(providerName string, source *Source) error { provider, err := createProvider(providerName, source) diff --git a/services/auth/sspi.go b/services/auth/sspi.go index 573d94b42c2c0..bc8ec948f29cd 100644 --- a/services/auth/sspi.go +++ b/services/auth/sspi.go @@ -130,7 +130,10 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore, // getConfig retrieves the SSPI configuration from login sources func (s *SSPI) getConfig(ctx context.Context) (*sspi.Source, error) { - sources, err := auth.ActiveSources(ctx, auth.SSPI) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + LoginType: auth.SSPI, + }) if err != nil { return nil, err } diff --git a/services/auth/sync.go b/services/auth/sync.go index 25b9460b9921f..11a59d41ae1b4 100644 --- a/services/auth/sync.go +++ b/services/auth/sync.go @@ -15,7 +15,7 @@ import ( func SyncExternalUsers(ctx context.Context, updateExisting bool) error { log.Trace("Doing: SyncExternalUsers") - ls, err := auth.Sources(ctx) + ls, err := auth.FindSources(ctx, auth.FindSourcesOptions{}) if err != nil { log.Error("SyncExternalUsers: %v", err) return err From dae54476e97a6ad1ca6dccbbd79dd4b0a114c882 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 26 Oct 2023 18:59:05 +0800 Subject: [PATCH 2/4] Only list active auth source names --- routers/web/user/setting/security/security.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index ae0b4bc3a974c..1abc90f60c040 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -107,7 +107,7 @@ func loadSecurityData(ctx *context.Context) { ctx.Data["AccountLinks"] = sources // here we need to load all possible auth sources because a linked account maybe is using an unactive auth source - orderedOAuth2Names, oauth2Providers, err := oauth2.GetOAuth2ProvidersMap(ctx, util.OptionalBoolNone) + orderedOAuth2Names, oauth2Providers, err := oauth2.GetOAuth2ProvidersMap(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("GetOAuth2ProvidersMap", err) return From f3570c41d4f93bbd8ab84840c9fefbb142d60831 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 26 Oct 2023 19:22:52 +0800 Subject: [PATCH 3/4] Fix bug --- routers/web/user/setting/security/security.go | 24 +++++++++++++++++-- services/auth/source/oauth2/providers.go | 19 ++++++++++----- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index 1abc90f60c040..5dcf0c9b4363e 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -6,6 +6,7 @@ package security import ( "net/http" + "sort" auth_model "code.gitea.io/gitea/models/auth" user_model "code.gitea.io/gitea/models/user" @@ -106,12 +107,31 @@ func loadSecurityData(ctx *context.Context) { } ctx.Data["AccountLinks"] = sources - // here we need to load all possible auth sources because a linked account maybe is using an unactive auth source - orderedOAuth2Names, oauth2Providers, err := oauth2.GetOAuth2ProvidersMap(ctx, util.OptionalBoolTrue) + authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{ + IsActive: util.OptionalBoolNone, + LoginType: auth_model.OAuth2, + }) if err != nil { ctx.ServerError("GetOAuth2ProvidersMap", err) return } + + var orderedOAuth2Names []string + oauth2Providers := make(map[string]oauth2.Provider) + for _, source := range authSources { + provider, err := oauth2.CreateProviderFromSource(source) + if err != nil { + ctx.ServerError("CreateProviderFromSource", err) + return + } + oauth2Providers[source.Name] = provider + if source.IsActive { + orderedOAuth2Names = append(orderedOAuth2Names, source.Name) + } + } + + sort.Strings(orderedOAuth2Names) + ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers diff --git a/services/auth/source/oauth2/providers.go b/services/auth/source/oauth2/providers.go index 56fabb264d21a..cedc723f3ae72 100644 --- a/services/auth/source/oauth2/providers.go +++ b/services/auth/source/oauth2/providers.go @@ -96,6 +96,15 @@ func GetOAuth2Providers() []Provider { return providers } +func CreateProviderFromSource(source *auth.Source) (Provider, error) { + oauth2Cfg, ok := source.Cfg.(*Source) + if !ok { + return nil, fmt.Errorf("invalid OAuth2 source config: %v", oauth2Cfg) + } + gothProv := gothProviders[oauth2Cfg.Provider] + return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil +} + // GetOAuth2ProvidersMap returns the map of configured OAuth2 providers func GetOAuth2ProvidersMap(ctx context.Context, isActive util.OptionalBool) ([]string, map[string]Provider, error) { authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ @@ -109,13 +118,11 @@ func GetOAuth2ProvidersMap(ctx context.Context, isActive util.OptionalBool) ([]s var orderedKeys []string providers := make(map[string]Provider) for _, source := range authSources { - oauth2Cfg, ok := source.Cfg.(*Source) - if !ok { - log.Error("Invalid OAuth2 source config: %v", oauth2Cfg) - continue + provider, err := CreateProviderFromSource(source) + if err != nil { + return nil, nil, err } - gothProv := gothProviders[oauth2Cfg.Provider] - providers[source.Name] = &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL} + providers[source.Name] = provider orderedKeys = append(orderedKeys, source.Name) } From cc527e43302ac9444428907e303c838e3b1b3d38 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 3 Nov 2023 00:11:27 +0800 Subject: [PATCH 4/4] remove copmlicated code --- routers/web/admin/auths.go | 8 ++--- routers/web/auth/auth.go | 12 +++---- routers/web/user/setting/security/security.go | 2 +- services/auth/source/oauth2/providers.go | 32 +++++++------------ templates/user/auth/signin_inner.tmpl | 7 ++-- templates/user/auth/signup_inner.tmpl | 7 ++-- 6 files changed, 27 insertions(+), 41 deletions(-) diff --git a/routers/web/admin/auths.go b/routers/web/admin/auths.go index 550def1d64d6f..23946d64afa33 100644 --- a/routers/web/admin/auths.go +++ b/routers/web/admin/auths.go @@ -99,7 +99,7 @@ func NewAuthSource(ctx *context.Context) { ctx.Data["AuthSources"] = authSources ctx.Data["SecurityProtocols"] = securityProtocols ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers ctx.Data["SSPIAutoCreateUsers"] = true @@ -242,7 +242,7 @@ func NewAuthSourcePost(ctx *context.Context) { ctx.Data["AuthSources"] = authSources ctx.Data["SecurityProtocols"] = securityProtocols ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers ctx.Data["SSPIAutoCreateUsers"] = true @@ -334,7 +334,7 @@ func EditAuthSource(ctx *context.Context) { ctx.Data["SecurityProtocols"] = securityProtocols ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers source, err := auth.GetSourceByID(ctx, ctx.ParamsInt64(":authid")) @@ -368,7 +368,7 @@ func EditAuthSourcePost(ctx *context.Context) { ctx.Data["PageIsAdminAuthentications"] = true ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers source, err := auth.GetSourceByID(ctx, ctx.ParamsInt64(":authid")) diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index e27307ef1afc6..0ea91fc759a9a 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -160,12 +160,11 @@ func SignIn(ctx *context.Context) { return } - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignIn", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers ctx.Data["Title"] = ctx.Tr("sign_in") ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" @@ -184,12 +183,11 @@ func SignIn(ctx *context.Context) { func SignInPost(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("sign_in") - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignIn", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers ctx.Data["Title"] = ctx.Tr("sign_in") ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" @@ -408,13 +406,12 @@ func SignUp(ctx *context.Context) { ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/sign_up" - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignUp", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers context.SetCaptchaData(ctx) @@ -438,13 +435,12 @@ func SignUpPost(ctx *context.Context) { ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/sign_up" - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignUp", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers context.SetCaptchaData(ctx) diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index 5dcf0c9b4363e..ec269776e2b6b 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -112,7 +112,7 @@ func loadSecurityData(ctx *context.Context) { LoginType: auth_model.OAuth2, }) if err != nil { - ctx.ServerError("GetOAuth2ProvidersMap", err) + ctx.ServerError("FindSources", err) return } diff --git a/services/auth/source/oauth2/providers.go b/services/auth/source/oauth2/providers.go index cedc723f3ae72..3b45b252f7099 100644 --- a/services/auth/source/oauth2/providers.go +++ b/services/auth/source/oauth2/providers.go @@ -81,10 +81,10 @@ func RegisterGothProvider(provider GothProvider) { gothProviders[provider.Name()] = provider } -// GetOAuth2Providers returns the map of unconfigured OAuth2 providers +// GetSupportedOAuth2Providers returns the map of unconfigured OAuth2 providers // key is used as technical name (like in the callbackURL) // values to display -func GetOAuth2Providers() []Provider { +func GetSupportedOAuth2Providers() []Provider { providers := make([]Provider, 0, len(gothProviders)) for _, provider := range gothProviders { @@ -105,38 +105,30 @@ func CreateProviderFromSource(source *auth.Source) (Provider, error) { return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil } -// GetOAuth2ProvidersMap returns the map of configured OAuth2 providers -func GetOAuth2ProvidersMap(ctx context.Context, isActive util.OptionalBool) ([]string, map[string]Provider, error) { +// GetOAuth2Providers returns the list of configured OAuth2 providers +func GetOAuth2Providers(ctx context.Context, isActive util.OptionalBool) ([]Provider, error) { authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ IsActive: isActive, LoginType: auth.OAuth2, }) if err != nil { - return nil, nil, err + return nil, err } - var orderedKeys []string - providers := make(map[string]Provider) + providers := make([]Provider, 0, len(authSources)) for _, source := range authSources { provider, err := CreateProviderFromSource(source) if err != nil { - return nil, nil, err + return nil, err } - providers[source.Name] = provider - orderedKeys = append(orderedKeys, source.Name) + providers = append(providers, provider) } - sort.Strings(orderedKeys) - - return orderedKeys, providers, nil -} + sort.Slice(providers, func(i, j int) bool { + return providers[i].Name() < providers[j].Name() + }) -// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers -// key is used as technical name (like in the callbackURL) -// values to display -func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provider, error) { - // Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type - return GetOAuth2ProvidersMap(ctx, util.OptionalBoolTrue) + return providers, nil } // RegisterProviderWithGothic register a OAuth2 provider in goth lib diff --git a/templates/user/auth/signin_inner.tmpl b/templates/user/auth/signin_inner.tmpl index f38b0a26087f5..7f744b24d87bf 100644 --- a/templates/user/auth/signin_inner.tmpl +++ b/templates/user/auth/signin_inner.tmpl @@ -52,16 +52,15 @@ {{end}} - {{if and .OrderedOAuth2Names .OAuth2Providers}} + {{if .OAuth2Providers}}
{{ctx.Locale.Tr "sign_in_or"}}
- {{range $key := .OrderedOAuth2Names}} - {{$provider := index $.OAuth2Providers $key}} - diff --git a/templates/user/auth/signup_inner.tmpl b/templates/user/auth/signup_inner.tmpl index 068ccbc6182c1..c75e33a18a0e9 100644 --- a/templates/user/auth/signup_inner.tmpl +++ b/templates/user/auth/signup_inner.tmpl @@ -56,16 +56,15 @@ {{end}} {{end}} - {{if and .OrderedOAuth2Names .OAuth2Providers}} + {{if .OAuth2Providers}}
{{ctx.Locale.Tr "sign_in_or"}}