Skip to content

Commit

Permalink
Merge pull request #78 from MicahParks/multiple_jwks
Browse files Browse the repository at this point in the history
Add support for multiple JWK Sets
  • Loading branch information
MicahParks authored Dec 23, 2022
2 parents fb3c60d + eaceb56 commit f76c64f
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ These features can be configured by populating fields in the
parsing JWTs. For an example, see the `examples/custom` directory.
* The remote JWKS resource can be refreshed manually using the `.Refresh` method. This can bypass the rate limit, if the
option is set.
* There is support for creating one `jwt.Keyfunc` from multiple JWK Sets through the use of the `keyfunc.GetMultiple`.

## Notes
Trailing padding is required to be removed from base64url encoded keys inside a JWKS. This is because RFC 7517 defines
Expand Down
26 changes: 18 additions & 8 deletions keyfunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,33 @@ var (

// Keyfunc matches the signature of github.com/golang-jwt/jwt/v4's jwt.Keyfunc function.
func (j *JWKS) Keyfunc(token *jwt.Token) (interface{}, error) {
kid, alg, err := kidAlg(token)
if err != nil {
return nil, err
}
return j.getKey(alg, kid)
}

func (m *MultipleJWKS) Keyfunc(token *jwt.Token) (interface{}, error) {
return m.keySelector(m, token)
}

func kidAlg(token *jwt.Token) (kid, alg string, err error) {
kidInter, ok := token.Header["kid"]
if !ok {
return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKID)
return "", "", fmt.Errorf("%w: could not find kid in JWT header", ErrKID)
}
kid, ok := kidInter.(string)
kid, ok = kidInter.(string)
if !ok {
return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID)
return "", "", fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKID)
}

alg, ok := token.Header["alg"].(string)
alg, ok = token.Header["alg"].(string)
if !ok {
// For test coverage purposes, this should be impossible to reach because the JWT package rejects a token
// without an alg parameter in the header before calling jwt.Keyfunc.
return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch)
return "", "", fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrJWKAlgMismatch)
}

return j.getKey(alg, kid)
return kid, alg, nil
}

// base64urlTrailingPadding removes trailing padding before decoding a string from base64url. Some non-RFC compliant
Expand Down
69 changes: 69 additions & 0 deletions multiple.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package keyfunc

import (
"errors"
"fmt"

"github.com/golang-jwt/jwt/v4"
)

// ErrMultipleJWKSSize is returned when the number of JWKS given are not enough to make a MultipleJWKS.
var ErrMultipleJWKSSize = errors.New("multiple JWKS must have two or more remote JWK Set resources")

// MultipleJWKS manages multiple JWKS and has a field for jwt.Keyfunc.
type MultipleJWKS struct {
keySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error)
sets map[string]*JWKS // No lock is required because this map is read-only after initialization.
}

// GetMultiple creates a new MultipleJWKS. A map of length two or more JWKS URLs to Options is required.
//
// Be careful when choosing Options for each JWKS in the map. If RefreshUnknownKID is set to true for all JWKS in the
// map then many refresh requests would take place each time a JWT is processed, this should be rate limited by
// RefreshRateLimit.
func GetMultiple(multiple map[string]Options, options MultipleOptions) (multiJWKS *MultipleJWKS, err error) {
if multiple == nil || len(multiple) < 2 {
return nil, fmt.Errorf("multiple JWKS must have two or more remote JWK Set resources: %w", ErrMultipleJWKSSize)
}

if options.KeySelector == nil {
options.KeySelector = KeySelectorFirst
}

multiJWKS = &MultipleJWKS{
sets: make(map[string]*JWKS, len(multiple)),
keySelector: options.KeySelector,
}

for u, opts := range multiple {
jwks, err := Get(u, opts)
if err != nil {
return nil, fmt.Errorf("failed to get JWKS from %q: %w", u, err)
}
multiJWKS.sets[u] = jwks
}

return multiJWKS, nil
}

func (m *MultipleJWKS) JWKSets() map[string]*JWKS {
sets := make(map[string]*JWKS, len(m.sets))
for u, jwks := range m.sets {
sets[u] = jwks
}
return sets
}

func KeySelectorFirst(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) {
kid, alg, err := kidAlg(token)
if err != nil {
return nil, err
}
for _, jwks := range multiJWKS.sets {
key, err = jwks.getKey(alg, kid)
if err == nil {
return key, nil
}
}
return nil, fmt.Errorf("failed to find key ID in multiple JWKS: %w", ErrKIDNotFound)
}
67 changes: 67 additions & 0 deletions multiple_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package keyfunc_test

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/golang-jwt/jwt/v4"

"github.com/MicahParks/keyfunc"
)

const (
jwks1 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"uniqueKID","kty":"OKP","x":"1IlXuWBIkjYbAXm5Hk5mvsbPq0skO3-G_hX1Cw7CY-8"},{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}`
jwks2 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}`
)

func TestMultipleJWKS(t *testing.T) {
server1 := createTestServer([]byte(jwks1))
defer server1.Close()

server2 := createTestServer([]byte(jwks2))
defer server2.Close()

const (
collisionJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6ImNvbGxpc2lvbktJRCIsInR5cCI6IkpXVCJ9.e30.WXKmhyHjHQFXZ8dXfj07RvwKAgHB3EdGU1jeKUEY-wajgsRsHuhnotX1WqDSlngwGerEitnIcdMGViW_HNUCAA"
uniqueJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6InVuaXF1ZUtJRCIsInR5cCI6IkpXVCJ9.e30.egdT5_vXYKIM7UfsyewYaR63tS9T9JvKwUJs7Srj6wG9JHXMvN9Ftq0rJGem07ESVtN5OtlcJOaMgSbtxnc6Bg"
)

m := map[string]keyfunc.Options{
server1.URL: {},
server2.URL: {},
}

multiJWKS, err := keyfunc.GetMultiple(m, keyfunc.MultipleOptions{})
if err != nil {
t.Fatalf("failed to get multiple JWKS: %v", err)
}

token, err := jwt.Parse(collisionJWT, multiJWKS.Keyfunc)
if err != nil {
t.Fatalf("failed to parse collision JWT: %v", err)
}
if !token.Valid {
t.Fatalf("collision JWT is invalid")
}

token, err = jwt.Parse(uniqueJWT, multiJWKS.Keyfunc)
if err != nil {
t.Fatalf("failed to parse unique JWT: %v", err)
}
if !token.Valid {
t.Fatalf("unique JWT is invalid")
}

sets := multiJWKS.JWKSets()
if len(sets) != 2 {
t.Fatalf("expected 2 JWKS, got %d", len(sets))
}
}

func createTestServer(body []byte) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(body)
}))
}
15 changes: 15 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"io"
"net/http"
"time"

"github.com/golang-jwt/jwt/v4"
)

// ErrInvalidHTTPStatusCode indicates that the HTTP status code is invalid.
Expand Down Expand Up @@ -70,6 +72,9 @@ type Options struct {
// This is done through a background goroutine. Without specifying a RefreshInterval a malicious client could
// self-sign X JWTs, send them to this service, then cause potentially high network usage proportional to X. Make
// sure to call the JWKS.EndBackground method to end this goroutine when it's no longer needed.
//
// It is recommended this option is not used when in MultipleJWKS. This is because KID collisions SHOULD be uncommon
// meaning nearly any JWT SHOULD trigger a refresh for the number of JWKS in the MultipleJWKS minus one.
RefreshUnknownKID bool

// RequestFactory creates HTTP requests for the remote JWKS resource located at the given url. For example, an
Expand All @@ -81,6 +86,16 @@ type Options struct {
ResponseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
}

// MultipleOptions is used to configure the behavior when multiple JWKS are used by MultipleJWKS.
type MultipleOptions struct {
// KeySelector is a function that selects the key to use for a given token. It will be used in the implementation
// for jwt.Keyfunc. If implementing this custom selector extract the key ID and algorithm from the token's header.
// Use the key ID to select a token and confirm the key's algorithm before returning it.
//
// This value defaults to KeySelectorFirst.
KeySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error)
}

// RefreshOptions are used to specify manual refresh behavior.
type RefreshOptions struct {
IgnoreRateLimit bool
Expand Down

0 comments on commit f76c64f

Please sign in to comment.