diff --git a/CHANGELOG.md b/CHANGELOG.md index 5dd2dbc3fa..e3703bcb0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## HEAD ### Added +* Ability to configure signing key per oauth provider. ## 1.18.0 diff --git a/app/app.go b/app/app.go index 73927f848f..f1c7d37dce 100644 --- a/app/app.go +++ b/app/app.go @@ -6,12 +6,11 @@ import ( "github.com/go-redis/redis/v8" "github.com/jmoiron/sqlx" "github.com/keratin/authn-server/app/data" + dataRedis "github.com/keratin/authn-server/app/data/redis" "github.com/keratin/authn-server/lib/oauth" "github.com/keratin/authn-server/ops" "github.com/pkg/errors" "github.com/sirupsen/logrus" - - dataRedis "github.com/keratin/authn-server/app/data/redis" ) type pinger func() bool @@ -100,22 +99,7 @@ func NewApp(cfg *Config, logger logrus.FieldLogger) (*App, error) { ) } - oauthProviders := map[string]oauth.Provider{} - if cfg.GoogleOauthCredentials != nil { - oauthProviders["google"] = *oauth.NewGoogleProvider(cfg.GoogleOauthCredentials) - } - if cfg.GitHubOauthCredentials != nil { - oauthProviders["github"] = *oauth.NewGitHubProvider(cfg.GitHubOauthCredentials) - } - if cfg.FacebookOauthCredentials != nil { - oauthProviders["facebook"] = *oauth.NewFacebookProvider(cfg.FacebookOauthCredentials) - } - if cfg.DiscordOauthCredentials != nil { - oauthProviders["discord"] = *oauth.NewDiscordProvider(cfg.DiscordOauthCredentials) - } - if cfg.MicrosoftOauthCredientials != nil { - oauthProviders["microsoft"] = *oauth.NewMicrosoftProvider(cfg.MicrosoftOauthCredientials) - } + oauthProviders := initializeOAuthProviders(cfg) return &App{ // Provide access to root DB - useful when extending AccountStore functionality @@ -133,3 +117,23 @@ func NewApp(cfg *Config, logger logrus.FieldLogger) (*App, error) { Logger: logger, }, nil } + +func initializeOAuthProviders(cfg *Config) map[string]oauth.Provider { + oauthProviders := make(map[string]oauth.Provider) + if cfg.GoogleOauthCredentials != nil { + oauthProviders["google"] = *oauth.NewGoogleProvider(cfg.GoogleOauthCredentials) + } + if cfg.GitHubOauthCredentials != nil { + oauthProviders["github"] = *oauth.NewGitHubProvider(cfg.GitHubOauthCredentials) + } + if cfg.FacebookOauthCredentials != nil { + oauthProviders["facebook"] = *oauth.NewFacebookProvider(cfg.FacebookOauthCredentials) + } + if cfg.DiscordOauthCredentials != nil { + oauthProviders["discord"] = *oauth.NewDiscordProvider(cfg.DiscordOauthCredentials) + } + if cfg.MicrosoftOauthCredentials != nil { + oauthProviders["microsoft"] = *oauth.NewMicrosoftProvider(cfg.MicrosoftOauthCredentials) + } + return oauthProviders +} diff --git a/app/config.go b/app/config.go index 838709697e..5f51a903a1 100644 --- a/app/config.go +++ b/app/config.go @@ -76,7 +76,7 @@ type Config struct { GitHubOauthCredentials *oauth.Credentials FacebookOauthCredentials *oauth.Credentials DiscordOauthCredentials *oauth.Credentials - MicrosoftOauthCredientials *oauth.Credentials + MicrosoftOauthCredentials *oauth.Credentials RefreshTokenExplicitExpiry bool } @@ -86,7 +86,7 @@ func (c *Config) OAuthEnabled() bool { c.GitHubOauthCredentials != nil || c.FacebookOauthCredentials != nil || c.DiscordOauthCredentials != nil || - c.MicrosoftOauthCredientials != nil + c.MicrosoftOauthCredentials != nil } // SameSiteComputed returns either the specified http.SameSite, or a computed one from OAuth config @@ -586,11 +586,11 @@ var configurers = []configurer{ return nil }, - // GOOGLE_OAUTH_CREDENTIALS is a credential pair in the format `id:secret`. When specified, - // AuthN will enable routes for Google OAuth signin. + // GOOGLE_OAUTH_CREDENTIALS is a credential pair in the format `id:secret:signing_key(optional)`. + // When specified, AuthN will enable routes for Google OAuth signin. func(c *Config) error { if val, ok := os.LookupEnv("GOOGLE_OAUTH_CREDENTIALS"); ok { - credentials, err := oauth.NewCredentials(val) + credentials, err := oauth.NewCredentials(val, c.OAuthSigningKey) if err == nil { c.GoogleOauthCredentials = credentials } @@ -599,11 +599,11 @@ var configurers = []configurer{ return nil }, - // GITHUB_OAUTH_CREDENTIALS is a credential pair in the format `id:secret`. When specified, - // AuthN will enable routes for GitHub OAuth signin. + // GITHUB_OAUTH_CREDENTIALS is a credential pair in the format `id:secret:signing_key(optional)`. + // When specified, AuthN will enable routes for GitHub OAuth signin. func(c *Config) error { if val, ok := os.LookupEnv("GITHUB_OAUTH_CREDENTIALS"); ok { - credentials, err := oauth.NewCredentials(val) + credentials, err := oauth.NewCredentials(val, c.OAuthSigningKey) if err == nil { c.GitHubOauthCredentials = credentials } @@ -612,11 +612,11 @@ var configurers = []configurer{ return nil }, - // FACEBOOK_OAUTH_CREDENTIALS is a credential pair in the format `id:secret`. When specified, - // AuthN will enable routes for Facebook OAuth signin. + // FACEBOOK_OAUTH_CREDENTIALS is a credential pair in the format `id:secret:signing_key(optional)`. + // When specified, AuthN will enable routes for Facebook OAuth signin. func(c *Config) error { if val, ok := os.LookupEnv("FACEBOOK_OAUTH_CREDENTIALS"); ok { - credentials, err := oauth.NewCredentials(val) + credentials, err := oauth.NewCredentials(val, c.OAuthSigningKey) if err == nil { c.FacebookOauthCredentials = credentials } @@ -625,11 +625,11 @@ var configurers = []configurer{ return nil }, - // DISCORD_OAUTH_CREDENTIALS is a credential pair in the format `id:secret`. When specified, - // AuthN will enable routes for Discord OAuth signin. + // DISCORD_OAUTH_CREDENTIALS is a credential pair in the format `id:secret:signing_key(optional)`. + // When specified, AuthN will enable routes for Discord OAuth signin. func(c *Config) error { if val, ok := os.LookupEnv("DISCORD_OAUTH_CREDENTIALS"); ok { - credentials, err := oauth.NewCredentials(val) + credentials, err := oauth.NewCredentials(val, c.OAuthSigningKey) if err == nil { c.DiscordOauthCredentials = credentials } @@ -638,13 +638,13 @@ var configurers = []configurer{ return nil }, - // Microsoft_OAUTH_CREDENTIALS is a credential pair in the format `id:secret`. When specified, - // AuthN will enable routes for Discord OAuth signin. + // Microsoft_OAUTH_CREDENTIALS is a credential pair in the format `id:secret:signing_key(optional)`. + // When specified, AuthN will enable routes for Microsoft OAuth signin. func(c *Config) error { if val, ok := os.LookupEnv("MICROSOFT_OAUTH_CREDENTIALS"); ok { - credentials, err := oauth.NewCredentials(val) + credentials, err := oauth.NewCredentials(val, c.OAuthSigningKey) if err == nil { - c.MicrosoftOauthCredientials = credentials + c.MicrosoftOauthCredentials = credentials } return err } diff --git a/app/tokens/oauth/oauth.go b/app/tokens/oauth/oauth.go index 836f4edac6..b8d9f2af23 100644 --- a/app/tokens/oauth/oauth.go +++ b/app/tokens/oauth/oauth.go @@ -23,9 +23,9 @@ type Claims struct { } // Sign converts the claims into a serialized string, signed with HMAC. -func (c *Claims) Sign(hmacKey []byte) (string, error) { +func (c *Claims) Sign(signingKey jose.SigningKey) (string, error) { signer, err := jose.NewSigner( - jose.SigningKey{Algorithm: jose.HS256, Key: hmacKey}, + signingKey, (&jose.SignerOptions{}).WithType("JWT"), ) if err != nil { @@ -68,7 +68,7 @@ func Parse(tokenStr string, cfg *app.Config, nonce string) (*Claims, error) { // New creates Claims for a JWT suitable as a state parameter during an OAuth flow. func New(cfg *app.Config, nonce string, destination string) (*Claims, error) { return &Claims{ - Scope: scope, + Scope: scope, RequestForgeryProtection: nonce, Destination: destination, Claims: jwt.Claims{ diff --git a/app/tokens/oauth/oauth_test.go b/app/tokens/oauth/oauth_test.go index e3fb255e58..9c73dd3f40 100644 --- a/app/tokens/oauth/oauth_test.go +++ b/app/tokens/oauth/oauth_test.go @@ -8,6 +8,7 @@ import ( "github.com/keratin/authn-server/app/tokens/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" ) func TestOAuthToken(t *testing.T) { @@ -28,7 +29,7 @@ func TestOAuthToken(t *testing.T) { assert.True(t, token.Audience.Contains("https://authn.example.com")) assert.NotEmpty(t, token.IssuedAt) - tokenStr, err := token.Sign(cfg.OAuthSigningKey) + tokenStr, err := token.Sign(jose.SigningKey{Algorithm: jose.HS256, Key: cfg.OAuthSigningKey}) require.NoError(t, err) _, err = oauth.Parse(tokenStr, cfg, nonce) @@ -39,7 +40,7 @@ func TestOAuthToken(t *testing.T) { token, err := oauth.New(cfg, nonce, "https://app.example.com/return") require.NoError(t, err) - tokenStr, err := token.Sign(cfg.OAuthSigningKey) + tokenStr, err := token.Sign(jose.SigningKey{Algorithm: jose.HS256, Key: cfg.OAuthSigningKey}) require.NoError(t, err) _, err = oauth.Parse(tokenStr, cfg, "wrong") @@ -53,7 +54,7 @@ func TestOAuthToken(t *testing.T) { } token, err := oauth.New(cfg, nonce, "https://app.example.com/return") require.NoError(t, err) - tokenStr, err := token.Sign(oldCfg.OAuthSigningKey) + tokenStr, err := token.Sign(jose.SigningKey{Algorithm: jose.HS256, Key: oldCfg.OAuthSigningKey}) require.NoError(t, err) _, err = oauth.Parse(tokenStr, cfg, nonce) assert.Error(t, err) @@ -66,7 +67,7 @@ func TestOAuthToken(t *testing.T) { } token, err := oauth.New(&oldCfg, nonce, "https://app.example.com/return") require.NoError(t, err) - tokenStr, err := token.Sign(cfg.OAuthSigningKey) + tokenStr, err := token.Sign(jose.SigningKey{Algorithm: jose.HS256, Key: cfg.OAuthSigningKey}) require.NoError(t, err) _, err = oauth.Parse(tokenStr, cfg, nonce) assert.Error(t, err) diff --git a/docs/config.md b/docs/config.md index 38f37dc188..20bf146c9f 100644 --- a/docs/config.md +++ b/docs/config.md @@ -261,53 +261,55 @@ or * `https://www.example.com/authn/oauth/google/return` +By default OAuth providers will use a key derived from `SECRET_KEY_BASE`. To override you can provide a hex-encoded string as the third segment in the colon-delimited environment variable. + ### `FACEBOOK_OAUTH_CREDENTIALS` -| | | -| --------- | --- | -| Required? | No | -| Value | AppID:AppSecret | -| Default | nil | +| | | +|-----------|--------------------------------------| +| Required? | No | +| Value | AppID:AppSecret:SigningKey(optional) | +| Default | nil | Create a Facebook app at https://developers.facebook.com and enable the Facebook Login product. In the Quickstart, enter [AuthN's OAuth Return](api.md#oauth-return) as the Site URL. Then switch over to Settings and find the App ID and Secret. Join those together with a `:` and provide them to AuthN as a single variable. ### `GITHUB_OAUTH_CREDENTIALS` -| | | -| --------- | --- | -| Required? | No | -| Value | ClientID:ClientSecret | -| Default | nil | +| | | +|-----------|--------------------------------------| +| Required? | No | +| Value | AppID:AppSecret:SigningKey(optional) | +| Default | nil | Sign up for GitHub OAuth 2.0 credentials with the instructions here: https://developer.github.com/apps/building-oauth-apps. Your client's ID and secret must be joined together with a `:` and provided to AuthN as a single variable. ### `GOOGLE_OAUTH_CREDENTIALS` -| | | -| --------- | --- | -| Required? | No | -| Value | ClientID:ClientSecret | -| Default | nil | +| | | +|-----------|--------------------------------------| +| Required? | No | +| Value | AppID:AppSecret:SigningKey(optional) | +| Default | nil | Sign up for Google OAuth 2.0 credentials with the instructions here: https://developers.google.com/identity/protocols/OpenIDConnect. Your client's ID and secret must be joined together with a `:` and provided to AuthN as a single variable. ### `DISCORD_OAUTH_CREDENTIALS` -| | | -| --------- | --- | -| Required? | No | -| Value | ClientID:ClientSecret | -| Default | nil | +| | | +|-----------|--------------------------------------| +| Required? | No | +| Value | AppID:AppSecret:SigningKey(optional) | +| Default | nil | Sign up for Discord OAuth 2.0 credentials with the instructions here: https://discordapp.com/developers/docs/topics/oauth2. Your client's ID and secret must be joined together with a `:` and provided to AuthN as a single variable. ### `MICROSOFT_OAUTH_CREDENTIALS` -| | | -| --------- | --- | -| Required? | No | -| Value | ClientID:ClientSecret | -| Default | nil | +| | | +|-----------|--------------------------------------| +| Required? | No | +| Value | AppID:AppSecret:SigningKey(optional) | +| Default | nil | Sign up for Microsoft OAuth 2.0 credentials with the instructions here: https://docs.microsoft.com/fr-fr/graph/auth/. Your client's ID and secret must be joined together with a `:` and provided to AuthN as a single variable. diff --git a/lib/oauth/credentials.go b/lib/oauth/credentials.go index 4a1888af86..e4a1b64c28 100644 --- a/lib/oauth/credentials.go +++ b/lib/oauth/credentials.go @@ -1,25 +1,41 @@ package oauth import ( + "encoding/hex" "errors" + "fmt" "strings" ) // Credentials is a configuration struct for OAuth Providers type Credentials struct { - ID string - Secret string + ID string + Secret string + SigningKey []byte } -// NewCredentials parses a credential string in the format `id:string` and returns a Credentials -// suitable for OAuth Provider configuration. -func NewCredentials(credentials string) (*Credentials, error) { - if strings.Count(credentials, ":") != 1 { - return nil, errors.New("Credentials must be in the format `id:string`") +// NewCredentials parses a credential string in the format `id:string:signing_key(optional)` +// and returns a Credentials suitable for OAuth Provider configuration. If no signing key is +// provided the default key is used. +func NewCredentials(credentials string, defaultKey []byte) (*Credentials, error) { + if strings.Count(credentials, ":") < 1 { + return nil, errors.New("Credentials must be in the format `id:string:signing_key(optional)`") } - strs := strings.SplitN(credentials, ":", 2) - return &Credentials{ + strs := strings.SplitN(credentials, ":", 3) + + c := &Credentials{ ID: strs[0], Secret: strs[1], - }, nil + } + + if len(strs) == 3 { + key, err := hex.DecodeString(strs[2]) + if err != nil { + return nil, fmt.Errorf("failed to decode signing key: %w", err) + } + c.SigningKey = key + } else { + c.SigningKey = defaultKey + } + return c, nil } diff --git a/lib/oauth/credentials_test.go b/lib/oauth/credentials_test.go new file mode 100644 index 0000000000..e4405e196c --- /dev/null +++ b/lib/oauth/credentials_test.go @@ -0,0 +1,57 @@ +package oauth + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewCredentials(t *testing.T) { + defaultKey := []byte("key-a-reno") + + t.Run("invalid credentials", func(t *testing.T) { + credentials, err := NewCredentials("id", defaultKey) + assert.NotNil(t, err) + assert.Equal(t, "Credentials must be in the format `id:string:signing_key(optional)`", err.Error()) + assert.Nil(t, credentials) + }) + + t.Run("valid credentials", func(t *testing.T) { + id := uuid.NewString() + secret := uuid.NewString() + + credentials, err := NewCredentials(fmt.Sprintf("%s:%s", id, secret), defaultKey) + assert.Nil(t, err) + assert.NotNil(t, credentials) + assert.Equal(t, id, credentials.ID) + assert.Equal(t, secret, credentials.Secret) + assert.Equal(t, defaultKey, credentials.SigningKey) + }) + + t.Run("valid credentials with signing key", func(t *testing.T) { + id := uuid.NewString() + secret := uuid.NewString() + signingKey := []byte("key-override-a-reno") + + credentials, err := NewCredentials(fmt.Sprintf("%s:%s:%s", id, secret, hex.EncodeToString(signingKey)), defaultKey) + assert.Nil(t, err) + assert.NotNil(t, credentials) + assert.Equal(t, id, credentials.ID) + assert.Equal(t, secret, credentials.Secret) + assert.Equal(t, signingKey, credentials.SigningKey) + }) + + t.Run("invalid signing key", func(t *testing.T) { + id := uuid.NewString() + secret := uuid.NewString() + badKey := fmt.Sprintf("g%s", uuid.NewString()) // g is not a valid hex character + + credentials, err := NewCredentials(fmt.Sprintf("%s:%s:%s", id, secret, badKey), defaultKey) + assert.NotNil(t, err) + assert.Equal(t, "failed to decode signing key: encoding/hex: invalid byte: U+0067 'g'", err.Error()) + assert.Nil(t, credentials) + }) +} diff --git a/lib/oauth/discord.go b/lib/oauth/discord.go index 418ae4e0ee..9b19f23119 100644 --- a/lib/oauth/discord.go +++ b/lib/oauth/discord.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" ) // NewDiscordProvider returns a AuthN integration for Discord OAuth @@ -20,24 +21,21 @@ func NewDiscordProvider(credentials *Credentials) *Provider { }, } - return &Provider{ - config: config, - UserInfo: func(t *oauth2.Token) (*UserInfo, error) { - client := config.Client(context.TODO(), t) - resp, err := client.Get("https://discordapp.com/api/users/@me") - if err != nil { - return nil, err - } - defer resp.Body.Close() + return NewProvider(config, func(t *oauth2.Token) (*UserInfo, error) { + client := config.Client(context.TODO(), t) + resp, err := client.Get("https://discordapp.com/api/users/@me") + if err != nil { + return nil, err + } + defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } - var user UserInfo - err = json.Unmarshal(body, &user) - return &user, err - }, - } + var user UserInfo + err = json.Unmarshal(body, &user) + return &user, err + }, jose.SigningKey{Key: credentials.SigningKey, Algorithm: jose.HS256}) } diff --git a/lib/oauth/facebook.go b/lib/oauth/facebook.go index 220b8e6bd3..d8e51d27b6 100644 --- a/lib/oauth/facebook.go +++ b/lib/oauth/facebook.go @@ -7,6 +7,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/facebook" + "gopkg.in/square/go-jose.v2" ) // NewFacebookProvider returns a AuthN integration for Facebook OAuth @@ -18,24 +19,21 @@ func NewFacebookProvider(credentials *Credentials) *Provider { Endpoint: facebook.Endpoint, } - return &Provider{ - config: config, - UserInfo: func(t *oauth2.Token) (*UserInfo, error) { - client := config.Client(context.TODO(), t) - resp, err := client.Get("https://graph.facebook.com/me?fields=id,email") - if err != nil { - return nil, err - } - defer resp.Body.Close() + return NewProvider(config, func(t *oauth2.Token) (*UserInfo, error) { + client := config.Client(context.TODO(), t) + resp, err := client.Get("https://graph.facebook.com/me?fields=id,email") + if err != nil { + return nil, err + } + defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } - var user UserInfo - err = json.Unmarshal(body, &user) - return &user, err - }, - } + var user UserInfo + err = json.Unmarshal(body, &user) + return &user, err + }, jose.SigningKey{Key: credentials.SigningKey, Algorithm: jose.HS256}) } diff --git a/lib/oauth/github.go b/lib/oauth/github.go index 5464c6768e..02e2d2ca29 100644 --- a/lib/oauth/github.go +++ b/lib/oauth/github.go @@ -8,6 +8,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/github" + "gopkg.in/square/go-jose.v2" ) // NewGitHubProvider returns a AuthN integration for GitHub OAuth @@ -71,23 +72,20 @@ func NewGitHubProvider(credentials *Credentials) *Provider { return strconv.Itoa(user.ID), nil } - return &Provider{ - config: config, - UserInfo: func(t *oauth2.Token) (*UserInfo, error) { - id, err := getID(t) - if err != nil { - return nil, err - } + return NewProvider(config, func(t *oauth2.Token) (*UserInfo, error) { + id, err := getID(t) + if err != nil { + return nil, err + } - email, err := getPrimaryEmail(t) - if err != nil { - return nil, err - } + email, err := getPrimaryEmail(t) + if err != nil { + return nil, err + } - return &UserInfo{ - ID: id, - Email: email, - }, nil - }, - } + return &UserInfo{ + ID: id, + Email: email, + }, nil + }, jose.SigningKey{Key: credentials.SigningKey, Algorithm: jose.HS256}) } diff --git a/lib/oauth/google.go b/lib/oauth/google.go index 408bcb6278..842869d3ce 100644 --- a/lib/oauth/google.go +++ b/lib/oauth/google.go @@ -7,6 +7,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" + "gopkg.in/square/go-jose.v2" ) // NewGoogleProvider returns a AuthN integration for Google OAuth @@ -18,24 +19,21 @@ func NewGoogleProvider(credentials *Credentials) *Provider { Endpoint: google.Endpoint, } - return &Provider{ - config: config, - UserInfo: func(t *oauth2.Token) (*UserInfo, error) { - client := config.Client(context.TODO(), t) - resp, err := client.Get("https://www.googleapis.com/oauth2/v1/userinfo?alt=json") - if err != nil { - return nil, err - } - defer resp.Body.Close() + return NewProvider(config, func(t *oauth2.Token) (*UserInfo, error) { + client := config.Client(context.TODO(), t) + resp, err := client.Get("https://www.googleapis.com/oauth2/v1/userinfo?alt=json") + if err != nil { + return nil, err + } + defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } - var user UserInfo - err = json.Unmarshal(body, &user) - return &user, err - }, - } + var user UserInfo + err = json.Unmarshal(body, &user) + return &user, err + }, jose.SigningKey{Key: credentials.SigningKey, Algorithm: jose.HS256}) } diff --git a/lib/oauth/microsoft.go b/lib/oauth/microsoft.go index a14fb8caf6..3f20cfd527 100644 --- a/lib/oauth/microsoft.go +++ b/lib/oauth/microsoft.go @@ -7,6 +7,7 @@ import ( "io" "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" ) // NewMicrosoftProvider returns a AuthN integration for Microsoft OAuth @@ -21,32 +22,29 @@ func NewMicrosoftProvider(credentials *Credentials) *Provider { }, } - return &Provider{ - config: config, - UserInfo: func(t *oauth2.Token) (*UserInfo, error) { - var me struct { - Id string `json:"id"` - UserPrincipalName string `json:"userPrincipalName"` - } + return NewProvider(config, func(t *oauth2.Token) (*UserInfo, error) { + var me struct { + Id string `json:"id"` + UserPrincipalName string `json:"userPrincipalName"` + } - client := config.Client(context.TODO(), t) - resp, err := client.Get("https://graph.microsoft.com/v1.0/me") - if err != nil { - return nil, err - } - defer resp.Body.Close() + client := config.Client(context.TODO(), t) + resp, err := client.Get("https://graph.microsoft.com/v1.0/me") + if err != nil { + return nil, err + } + defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } - var user UserInfo - err = json.Unmarshal(body, &me) - user.ID = me.Id - user.Email = me.UserPrincipalName - fmt.Println(user) - return &user, err - }, - } + var user UserInfo + err = json.Unmarshal(body, &me) + user.ID = me.Id + user.Email = me.UserPrincipalName + fmt.Println(user) + return &user, err + }, jose.SigningKey{Key: credentials.SigningKey, Algorithm: jose.HS256}) } diff --git a/lib/oauth/provider.go b/lib/oauth/provider.go index 52ae109fe5..b03da71a99 100644 --- a/lib/oauth/provider.go +++ b/lib/oauth/provider.go @@ -2,12 +2,14 @@ package oauth import ( "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" ) // Provider is a struct wrapping the necessary bits to integrate an OAuth2 provider with AuthN type Provider struct { - config *oauth2.Config - UserInfo UserInfoFetcher + config *oauth2.Config + UserInfo UserInfoFetcher + signingKey jose.SigningKey } // UserInfo is the minimum necessary needed from an OAuth Provider to connect with AuthN accounts @@ -20,8 +22,8 @@ type UserInfo struct { type UserInfoFetcher = func(t *oauth2.Token) (*UserInfo, error) // NewProvider returns a properly configured Provider -func NewProvider(config *oauth2.Config, userInfo UserInfoFetcher) *Provider { - return &Provider{config, userInfo} +func NewProvider(config *oauth2.Config, userInfo UserInfoFetcher, signingKey jose.SigningKey) *Provider { + return &Provider{config: config, UserInfo: userInfo, signingKey: signingKey} } // Config returns a complete oauth2.Config after injecting the RedirectURL @@ -34,3 +36,8 @@ func (p *Provider) Config(redirectURL string) *oauth2.Config { RedirectURL: redirectURL, } } + +func (p *Provider) SigningKey() jose.SigningKey { + //TODO: allow override with dynamic calc for apple + return p.signingKey +} diff --git a/lib/oauth/test.go b/lib/oauth/test.go index deea56ec80..6f73472c14 100644 --- a/lib/oauth/test.go +++ b/lib/oauth/test.go @@ -4,11 +4,12 @@ import ( "net/http/httptest" "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" ) // NewTestProvider returns a special Provider for tests -func NewTestProvider(s *httptest.Server) *Provider { - return &Provider{ +func NewTestProvider(s *httptest.Server, signingKey []byte) *Provider { + return NewProvider( &oauth2.Config{ ClientID: "TEST", ClientSecret: "SECRET", @@ -23,6 +24,5 @@ func NewTestProvider(s *httptest.Server) *Provider { ID: t.AccessToken, Email: t.AccessToken, }, nil - }, - } + }, jose.SigningKey{Key: signingKey, Algorithm: jose.HS256}) } diff --git a/server/handlers/get_oauth.go b/server/handlers/get_oauth.go index 7e96d03856..bb06abb16f 100644 --- a/server/handlers/get_oauth.go +++ b/server/handlers/get_oauth.go @@ -39,14 +39,14 @@ func GetOauth(app *app.App, providerName string) http.HandlerFunc { } nonce := base64.StdEncoding.EncodeToString(bytes) http.SetCookie(w, nonceCookie(app.Config, string(nonce))) - + // save nonce and return URL into state param stateToken, err := oauth.New(app.Config, string(nonce), redirectURI) if err != nil { fail(err) return } - state, err := stateToken.Sign(app.Config.OAuthSigningKey) + state, err := stateToken.Sign(provider.SigningKey()) if err != nil { fail(err) return diff --git a/server/handlers/get_oauth_return_test.go b/server/handlers/get_oauth_return_test.go index ca0ceaa4b0..8828e08ccf 100644 --- a/server/handlers/get_oauth_return_test.go +++ b/server/handlers/get_oauth_return_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" "github.com/stretchr/testify/require" @@ -20,11 +21,11 @@ func TestGetOauthReturn(t *testing.T) { providerServer := httptest.NewServer(test.ProviderApp()) defer providerServer.Close() - // configure a client for the fake oauth provider - providerClient := oauthlib.NewTestProvider(providerServer) - // configure and start the authn test server app := test.App() + + // configure a client for the fake oauth provider + providerClient := oauthlib.NewTestProvider(providerServer, app.Config.OAuthSigningKey) app.OauthProviders["test"] = *providerClient server := test.Server(app) defer server.Close() @@ -41,7 +42,7 @@ func TestGetOauthReturn(t *testing.T) { token, err := oauthtoken.New(app.Config, nonce, "https://localhost:9999/return") require.NoError(t, err) - state, err := token.Sign(app.Config.OAuthSigningKey) + state, err := token.Sign(jose.SigningKey{Algorithm: jose.HS256, Key: app.Config.OAuthSigningKey}) require.NoError(t, err) t.Run("sign up new identity with new email", func(t *testing.T) { diff --git a/server/handlers/get_oauth_test.go b/server/handlers/get_oauth_test.go index 8e08d3858f..af6bcf5dbb 100644 --- a/server/handlers/get_oauth_test.go +++ b/server/handlers/get_oauth_test.go @@ -5,11 +5,10 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/assert" - - "github.com/keratin/authn-server/server/test" oauthlib "github.com/keratin/authn-server/lib/oauth" "github.com/keratin/authn-server/lib/route" + "github.com/keratin/authn-server/server/test" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,12 +17,13 @@ func TestGetOauth(t *testing.T) { providerServer := httptest.NewServer(test.ProviderApp()) defer providerServer.Close() - // configure a client for the fake oauth provider - providerClient := oauthlib.NewTestProvider(providerServer) - // configure and start the authn test server app := test.App() + + // configure a client for the fake oauth provider + providerClient := oauthlib.NewTestProvider(providerServer, app.Config.OAuthSigningKey) app.OauthProviders["test"] = *providerClient + server := test.Server(app) defer server.Close() diff --git a/server/test/app.go b/server/test/app.go index d874fbf3ce..5897485ffd 100644 --- a/server/test/app.go +++ b/server/test/app.go @@ -40,6 +40,7 @@ func App() *app.App { EnableSignup: true, SameSite: http.SameSiteDefaultMode, PasswordChangeLogout: false, + OAuthSigningKey: []byte("key-a-reno"), } //Create mock blob stores for the totp cache object (TODO: Create an interface?)