Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional client_nonce for OIDC logins #104

Merged
merged 2 commits into from
Mar 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/base62"
)

const defaultMount = "oidc"
Expand Down Expand Up @@ -73,13 +74,13 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro

role := m["role"]

authURL, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
if err != nil {
return nil, err
}

// Set up callback handler
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, doneCh))
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))

listener, err := net.Listen("tcp", listenAddress+":"+port)
if err != nil {
Expand Down Expand Up @@ -112,7 +113,7 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
}
}

func callbackHandler(c *api.Client, mount string, doneCh chan<- loginResp) http.HandlerFunc {
func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh chan<- loginResp) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
var response string
var secret *api.Secret
Expand All @@ -126,9 +127,10 @@ func callbackHandler(c *api.Client, mount string, doneCh chan<- loginResp) http.
// Pull any parameters from either the body or query parameters.
// FormValue prioritizes body values, if found.
data := map[string][]string{
"state": {req.FormValue("state")},
"code": {req.FormValue("code")},
"id_token": {req.FormValue("id_token")},
"state": {req.FormValue("state")},
"code": {req.FormValue("code")},
"id_token": {req.FormValue("id_token")},
"client_nonce": {clientNonce},
}

// If this is a POST, then the form_post response_mode is being used and the flow
Expand Down Expand Up @@ -158,28 +160,34 @@ func callbackHandler(c *api.Client, mount string, doneCh chan<- loginResp) http.
}
}

func fetchAuthURL(c *api.Client, role, mount, callbackport string, callbackMethod string, callbackHost string) (string, error) {
func fetchAuthURL(c *api.Client, role, mount, callbackport string, callbackMethod string, callbackHost string) (string, string, error) {
var authURL string

clientNonce, err := base62.Random(20)
if err != nil {
return "", "", err
}

data := map[string]interface{}{
"role": role,
"redirect_uri": fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackport),
"client_nonce": clientNonce,
}

secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
if err != nil {
return "", err
return "", "", err
}

if secret != nil {
authURL = secret.Data["auth_url"].(string)
}

if authURL == "" {
return "", fmt.Errorf("Unable to authorize role %q. Check Vault logs for more information.", role)
return "", "", fmt.Errorf("Unable to authorize role %q. Check Vault logs for more information.", role)
}

return authURL, nil
return authURL, clientNonce, nil
}

// isWSL tests if the binary is being run in Windows Subsystem for Linux
Expand Down
26 changes: 24 additions & 2 deletions path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type oidcState struct {
redirectURI string
code string
idToken string

// clientNonce is used between Vault and the client/application (e.g. CLI) making the request,
// and is unrelated to the OIDC nonce above. It is optional.
clientNonce string
}

func pathOIDC(b *jwtAuthBackend) []*framework.Path {
Expand All @@ -56,6 +60,9 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path {
"id_token": {
Type: framework.TypeString,
},
"client_nonce": {
Type: framework.TypeString,
},
},

Operations: map[logical.Operation]framework.OperationHandler{
Expand Down Expand Up @@ -86,6 +93,10 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path {
Type: framework.TypeString,
Description: "The OAuth redirect_uri to use in the authorization URL.",
},
"client_nonce": {
Type: framework.TypeString,
Description: "Optional client-provided nonce that must match during callback, if present.",
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
Expand Down Expand Up @@ -158,6 +169,14 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil
}

clientNonce := d.Get("client_nonce").(string)

// If a client_nonce was provided at the start of the auth process as part of the auth_url
// request, require that it is present and matching during the callback phase.
if state.clientNonce != "" && clientNonce != state.clientNonce {
return logical.ErrorResponse("invalid client_nonce"), nil
}

roleName := state.rolename
role, err := b.role(ctx, req.Storage, roleName)
if err != nil {
Expand Down Expand Up @@ -341,6 +360,8 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f
return logical.ErrorResponse("missing redirect_uri"), nil
}

clientNonce := d.Get("client_nonce").(string)

role, err := b.role(ctx, req.Storage, roleName)
if err != nil {
return nil, err
Expand Down Expand Up @@ -381,7 +402,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f
Scopes: scopes,
}

stateID, nonce, err := b.createState(roleName, redirectURI)
stateID, nonce, err := b.createState(roleName, redirectURI, clientNonce)
if err != nil {
logger.Warn("error generating OAuth state", "error", err)
return resp, nil
Expand Down Expand Up @@ -421,7 +442,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f
// createState make an expiring state object, associated with a random state ID
// that is passed throughout the OAuth process. A nonce is also included in the
// auth process, and for simplicity will be identical in length/format as the state ID.
func (b *jwtAuthBackend) createState(rolename, redirectURI string) (string, string, error) {
func (b *jwtAuthBackend) createState(rolename, redirectURI, clientNonce string) (string, string, error) {
// Get enough bytes for 2 160-bit IDs (per rfc6749#section-10.10)
bytes, err := uuid.GenerateRandomBytes(2 * 20)
if err != nil {
Expand All @@ -435,6 +456,7 @@ func (b *jwtAuthBackend) createState(rolename, redirectURI string) (string, stri
rolename: rolename,
nonce: nonce,
redirectURI: redirectURI,
clientNonce: clientNonce,
})

return stateID, nonce, nil
Expand Down
89 changes: 89 additions & 0 deletions path_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,95 @@ func TestOIDC_Callback(t *testing.T) {
t.Fatalf("expected invalid client_id error, got : %v", *resp)
}
})

t.Run("client_nonce", func(t *testing.T) {
b, storage, s := getBackendAndServer(t, false)
defer s.server.Close()

// General behavior is that if a client_nonce is provided during the authURL phase
// it must be provided during the callback phase.
tests := map[string]struct {
authURLNonce string
callbackNonce string
errExpected bool
}{
"default, no nonces": {
errExpected: false,
},
"matching nonces": {
authURLNonce: "abc123",
callbackNonce: "abc123",
errExpected: false,
},
"mismatched nonces": {
authURLNonce: "abc123",
callbackNonce: "abc123xyz",
errExpected: true,
},
"missing nonce": {
authURLNonce: "abc123",
errExpected: true,
},
"ignore unexpected callback nonce": {
callbackNonce: "abc123",
errExpected: false,
},
}

for name, test := range tests {
// get auth_url
data := map[string]interface{}{
"role": "test",
"redirect_uri": "https://example.com",
"client_nonce": test.authURLNonce,
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "oidc/auth_url",
Storage: storage,
Data: data,
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}

authURL := resp.Data["auth_url"].(string)

state := getQueryParam(t, authURL, "state")
nonce := getQueryParam(t, authURL, "nonce")

// set provider claims that will be returned by the mock server
s.customClaims = sampleClaims(nonce)

// set mock provider's expected code
s.code = "abc"

// invoke the callback, which will try to exchange the code
// with the mock provider.
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "oidc/callback",
Storage: storage,
Data: map[string]interface{}{
"state": state,
"code": "abc",
"client_nonce": test.callbackNonce,
},
}

resp, err = b.HandleRequest(context.Background(), req)

if err != nil {
t.Fatal(err)
}

if test.errExpected != resp.IsError() {
t.Fatalf("%s: unexpected error response, expected: %v, got: %v", name, test.errExpected, resp.Data)
}
}
})
}

// oidcProvider is local server the mocks the basis endpoints used by the
Expand Down