-
-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #78 from MicahParks/multiple_jwks
Add support for multiple JWK Sets
- Loading branch information
Showing
5 changed files
with
170 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package keyfunc | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
|
||
"github.com/golang-jwt/jwt/v4" | ||
) | ||
|
||
// ErrMultipleJWKSSize is returned when the number of JWKS given are not enough to make a MultipleJWKS. | ||
var ErrMultipleJWKSSize = errors.New("multiple JWKS must have two or more remote JWK Set resources") | ||
|
||
// MultipleJWKS manages multiple JWKS and has a field for jwt.Keyfunc. | ||
type MultipleJWKS struct { | ||
keySelector func(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) | ||
sets map[string]*JWKS // No lock is required because this map is read-only after initialization. | ||
} | ||
|
||
// GetMultiple creates a new MultipleJWKS. A map of length two or more JWKS URLs to Options is required. | ||
// | ||
// Be careful when choosing Options for each JWKS in the map. If RefreshUnknownKID is set to true for all JWKS in the | ||
// map then many refresh requests would take place each time a JWT is processed, this should be rate limited by | ||
// RefreshRateLimit. | ||
func GetMultiple(multiple map[string]Options, options MultipleOptions) (multiJWKS *MultipleJWKS, err error) { | ||
if multiple == nil || len(multiple) < 2 { | ||
return nil, fmt.Errorf("multiple JWKS must have two or more remote JWK Set resources: %w", ErrMultipleJWKSSize) | ||
} | ||
|
||
if options.KeySelector == nil { | ||
options.KeySelector = KeySelectorFirst | ||
} | ||
|
||
multiJWKS = &MultipleJWKS{ | ||
sets: make(map[string]*JWKS, len(multiple)), | ||
keySelector: options.KeySelector, | ||
} | ||
|
||
for u, opts := range multiple { | ||
jwks, err := Get(u, opts) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to get JWKS from %q: %w", u, err) | ||
} | ||
multiJWKS.sets[u] = jwks | ||
} | ||
|
||
return multiJWKS, nil | ||
} | ||
|
||
func (m *MultipleJWKS) JWKSets() map[string]*JWKS { | ||
sets := make(map[string]*JWKS, len(m.sets)) | ||
for u, jwks := range m.sets { | ||
sets[u] = jwks | ||
} | ||
return sets | ||
} | ||
|
||
func KeySelectorFirst(multiJWKS *MultipleJWKS, token *jwt.Token) (key interface{}, err error) { | ||
kid, alg, err := kidAlg(token) | ||
if err != nil { | ||
return nil, err | ||
} | ||
for _, jwks := range multiJWKS.sets { | ||
key, err = jwks.getKey(alg, kid) | ||
if err == nil { | ||
return key, nil | ||
} | ||
} | ||
return nil, fmt.Errorf("failed to find key ID in multiple JWKS: %w", ErrKIDNotFound) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package keyfunc_test | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/golang-jwt/jwt/v4" | ||
|
||
"github.com/MicahParks/keyfunc" | ||
) | ||
|
||
const ( | ||
jwks1 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"uniqueKID","kty":"OKP","x":"1IlXuWBIkjYbAXm5Hk5mvsbPq0skO3-G_hX1Cw7CY-8"},{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}` | ||
jwks2 = `{"keys":[{"alg":"EdDSA","crv":"Ed25519","kid":"collisionKID","kty":"OKP","x":"IbQyt_GPqUJImuAgStdixWdadZGvzTPS_mKlOjmuOYU"}]}` | ||
) | ||
|
||
func TestMultipleJWKS(t *testing.T) { | ||
server1 := createTestServer([]byte(jwks1)) | ||
defer server1.Close() | ||
|
||
server2 := createTestServer([]byte(jwks2)) | ||
defer server2.Close() | ||
|
||
const ( | ||
collisionJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6ImNvbGxpc2lvbktJRCIsInR5cCI6IkpXVCJ9.e30.WXKmhyHjHQFXZ8dXfj07RvwKAgHB3EdGU1jeKUEY-wajgsRsHuhnotX1WqDSlngwGerEitnIcdMGViW_HNUCAA" | ||
uniqueJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6InVuaXF1ZUtJRCIsInR5cCI6IkpXVCJ9.e30.egdT5_vXYKIM7UfsyewYaR63tS9T9JvKwUJs7Srj6wG9JHXMvN9Ftq0rJGem07ESVtN5OtlcJOaMgSbtxnc6Bg" | ||
) | ||
|
||
m := map[string]keyfunc.Options{ | ||
server1.URL: {}, | ||
server2.URL: {}, | ||
} | ||
|
||
multiJWKS, err := keyfunc.GetMultiple(m, keyfunc.MultipleOptions{}) | ||
if err != nil { | ||
t.Fatalf("failed to get multiple JWKS: %v", err) | ||
} | ||
|
||
token, err := jwt.Parse(collisionJWT, multiJWKS.Keyfunc) | ||
if err != nil { | ||
t.Fatalf("failed to parse collision JWT: %v", err) | ||
} | ||
if !token.Valid { | ||
t.Fatalf("collision JWT is invalid") | ||
} | ||
|
||
token, err = jwt.Parse(uniqueJWT, multiJWKS.Keyfunc) | ||
if err != nil { | ||
t.Fatalf("failed to parse unique JWT: %v", err) | ||
} | ||
if !token.Valid { | ||
t.Fatalf("unique JWT is invalid") | ||
} | ||
|
||
sets := multiJWKS.JWKSets() | ||
if len(sets) != 2 { | ||
t.Fatalf("expected 2 JWKS, got %d", len(sets)) | ||
} | ||
} | ||
|
||
func createTestServer(body []byte) *httptest.Server { | ||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.Header().Set("Content-Type", "application/json") | ||
_, _ = w.Write(body) | ||
})) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters