Skip to content

Commit

Permalink
Use configured TLS certs for OIDC operations (#40)
Browse files Browse the repository at this point in the history
Fixes #39
  • Loading branch information
kalafut authored Apr 5, 2019
1 parent 7ca4cef commit c05fb7d
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 28 deletions.
2 changes: 1 addition & 1 deletion backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (b *jwtAuthBackend) reset() {
b.l.Unlock()
}

func (b *jwtAuthBackend) getProvider(ctx context.Context, config *jwtConfig) (*oidc.Provider, error) {
func (b *jwtAuthBackend) getProvider(config *jwtConfig) (*oidc.Provider, error) {
b.l.RLock()
unlockFunc := b.l.RUnlock
defer func() { unlockFunc() }()
Expand Down
37 changes: 25 additions & 12 deletions path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,29 @@ func (b *jwtAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque
}

func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, error) {
var certPool *x509.CertPool
if config.OIDCDiscoveryCAPEM != "" {
certPool = x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok {
return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully")
}
oidcCtx, err := b.createOIDCContext(b.providerCtx, config)
if err != nil {
return nil, errwrap.Wrapf("error creating provider: {{err}}", err)
}

provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL)
if err != nil {
return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err)
}

return provider, nil
}

// createOIDCContext returns a context with custom TLS client, configured with the root certificates
// from oidc_discovery_ca_pem. If no certificates are configured, the original context is returned.
func (b *jwtAuthBackend) createOIDCContext(ctx context.Context, config *jwtConfig) (context.Context, error) {
if config.OIDCDiscoveryCAPEM == "" {
return ctx, nil
}

certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(config.OIDCDiscoveryCAPEM)); !ok {
return nil, errors.New("could not parse 'oidc_discovery_ca_pem' value successfully")
}

tr := cleanhttp.DefaultPooledTransport()
Expand All @@ -216,14 +233,10 @@ func (b *jwtAuthBackend) createProvider(config *jwtConfig) (*oidc.Provider, erro
tc := &http.Client{
Transport: tr,
}
oidcCtx := context.WithValue(b.providerCtx, oauth2.HTTPClient, tc)

provider, err := oidc.NewProvider(oidcCtx, config.OIDCDiscoveryURL)
if err != nil {
return nil, errwrap.Wrapf("error creating provider with given values: {{err}}", err)
}
oidcCtx := context.WithValue(ctx, oauth2.HTTPClient, tc)

return provider, nil
return oidcCtx, nil
}

type jwtConfig struct {
Expand Down
2 changes: 1 addition & 1 deletion path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (b *jwtAuthBackend) pathLoginRenew(ctx context.Context, req *logical.Reques
func (b *jwtAuthBackend) verifyOIDCToken(ctx context.Context, config *jwtConfig, role *jwtRole, rawToken string) (map[string]interface{}, error) {
allClaims := make(map[string]interface{})

provider, err := b.getProvider(ctx, config)
provider, err := b.getProvider(config)
if err != nil {
return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err)
}
Expand Down
15 changes: 10 additions & 5 deletions path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,14 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("request originated from invalid CIDR"), nil
}

provider, err := b.getProvider(ctx, config)
provider, err := b.getProvider(config)
if err != nil {
return nil, errwrap.Wrapf(errLoginFailed+" Error getting provider for login operation: {{err}}", err)
return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err)
}

oidcCtx, err := b.createOIDCContext(ctx, config)
if err != nil {
return nil, errwrap.Wrapf("error preparing context for login operation: {{err}}", err)
}

var oauth2Config = oauth2.Config{
Expand All @@ -121,7 +126,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(errLoginFailed + " OAuth code parameter not provided"), nil
}

oauth2Token, err := oauth2Config.Exchange(ctx, code)
oauth2Token, err := oauth2Config.Exchange(oidcCtx, code)
if err != nil {
return logical.ErrorResponse(errLoginFailed+" Error exchanging oidc code: %q.", err.Error()), nil
}
Expand All @@ -146,7 +151,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
// Attempt to fetch information from the /userinfo endpoint and merge it with
// the existing claims data. A failure to fetch additional information from this
// endpoint will not invalidate the authorization flow.
if userinfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)); err == nil {
if userinfo, err := provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil {
_ = userinfo.Claims(&allClaims)
} else {
logFunc := b.Logger().Warn
Expand Down Expand Up @@ -246,7 +251,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f
return resp, nil
}

provider, err := b.getProvider(ctx, config)
provider, err := b.getProvider(config)
if err != nil {
logger.Warn("error getting provider for login operation", "error", err)
return resp, nil
Expand Down
28 changes: 21 additions & 7 deletions path_oidc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwtauth

import (
"bytes"
"context"
"crypto/x509"
"encoding/json"
Expand Down Expand Up @@ -217,14 +218,27 @@ func TestOIDC_Callback(t *testing.T) {
s.clientID = "abc"
s.clientSecret = "def"

// save test server root cert to config in PEM format
cert := s.server.Certificate()
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}

pemBuf := new(bytes.Buffer)
if err := pem.Encode(pemBuf, block); err != nil {
t.Fatal(err)
}

// Configure backend
data := map[string]interface{}{
"oidc_discovery_url": s.server.URL,
"oidc_client_id": "abc",
"oidc_client_secret": "def",
"default_role": "test",
"bound_issuer": "http://vault.example.com/",
"jwt_supported_algs": []string{"ES256"},
"oidc_discovery_url": s.server.URL,
"oidc_client_id": "abc",
"oidc_client_secret": "def",
"oidc_discovery_ca_pem": pemBuf.String(),
"default_role": "test",
"bound_issuer": "http://vault.example.com/",
"jwt_supported_algs": []string{"ES256"},
}

// basic configuration
Expand Down Expand Up @@ -758,7 +772,7 @@ type oidcProvider struct {
func newOIDCProvider(t *testing.T) *oidcProvider {
o := new(oidcProvider)
o.t = t
o.server = httptest.NewServer(o)
o.server = httptest.NewTLSServer(o)

return o
}
Expand Down
4 changes: 2 additions & 2 deletions scripts/local_dev.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env bash
set -e

MNT_PATH="jwt"
MNT_PATH="oidc"
PLUGIN_NAME="vault-plugin-auth-jwt"
PLUGIN_CATALOG_NAME="jwt"
PLUGIN_CATALOG_NAME="oidc"

#
# Helper script for local development. Automatically builds and registers the
Expand Down

0 comments on commit c05fb7d

Please sign in to comment.