diff --git a/multiple.go b/multiple.go index 08946b0..c6666c5 100644 --- a/multiple.go +++ b/multiple.go @@ -22,8 +22,8 @@ type MultipleJWKS struct { // map then many refresh requests would take place each time a JWT is processed, this should be rate limited by // RefreshRateLimit. func GetMultiple(multiple map[string]Options, options MultipleOptions) (multiJWKS *MultipleJWKS, err error) { - if multiple == nil || len(multiple) < 2 { - return nil, fmt.Errorf("multiple JWKS must have two or more remote JWK Set resources: %w", ErrMultipleJWKSSize) + if len(multiple) < 1 { + return nil, fmt.Errorf("multiple JWKS must have one or more remote JWK Set resources: %w", ErrMultipleJWKSSize) } if options.KeySelector == nil { diff --git a/multiple_test.go b/multiple_test.go index 8227864..4ec37a9 100644 --- a/multiple_test.go +++ b/multiple_test.go @@ -59,6 +59,46 @@ func TestMultipleJWKS(t *testing.T) { } } +func TestMultipleJWKSSingle(t *testing.T) { + server1 := createTestServer([]byte(jwks1)) + defer server1.Close() + + const ( + collisionJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6ImNvbGxpc2lvbktJRCIsInR5cCI6IkpXVCJ9.e30.WXKmhyHjHQFXZ8dXfj07RvwKAgHB3EdGU1jeKUEY-wajgsRsHuhnotX1WqDSlngwGerEitnIcdMGViW_HNUCAA" + uniqueJWT = "eyJhbGciOiJFZERTQSIsImtpZCI6InVuaXF1ZUtJRCIsInR5cCI6IkpXVCJ9.e30.egdT5_vXYKIM7UfsyewYaR63tS9T9JvKwUJs7Srj6wG9JHXMvN9Ftq0rJGem07ESVtN5OtlcJOaMgSbtxnc6Bg" + ) + + m := map[string]keyfunc.Options{ + server1.URL: {}, + } + + multiJWKS, err := keyfunc.GetMultiple(m, keyfunc.MultipleOptions{}) + if err != nil { + t.Fatalf("failed to get multiple JWKS: %v", err) + } + + token, err := jwt.Parse(collisionJWT, multiJWKS.Keyfunc) + if err != nil { + t.Fatalf("failed to parse collision JWT: %v", err) + } + if !token.Valid { + t.Fatalf("collision JWT is invalid") + } + + token, err = jwt.Parse(uniqueJWT, multiJWKS.Keyfunc) + if err != nil { + t.Fatalf("failed to parse unique JWT: %v", err) + } + if !token.Valid { + t.Fatalf("unique JWT is invalid") + } + + sets := multiJWKS.JWKSets() + if len(sets) != 1 { + t.Fatalf("expected 2 JWKS, got %d", len(sets)) + } +} + func createTestServer(body []byte) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json")