diff --git a/get.go b/get.go index 00f2f2f..7ba613f 100644 --- a/get.go +++ b/get.go @@ -49,7 +49,14 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) { err = jwks.refresh() if err != nil { - return nil, err + if options.TolerateInitialJWKHTTPError { + if jwks.refreshErrorHandler != nil { + jwks.refreshErrorHandler(err) + } + jwks.keys = make(map[string]parsedJWK) + } else { + return nil, err + } } if jwks.refreshInterval != 0 || jwks.refreshUnknownKID { diff --git a/options.go b/options.go index b4ab35c..6291ffb 100644 --- a/options.go +++ b/options.go @@ -84,6 +84,14 @@ type Options struct { // ResponseExtractor consumes a *http.Response and produces the raw JSON for the JWKS. By default, the // ResponseExtractorStatusOK function is used. The default behavior changed in v1.4.0. ResponseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error) + + // TolerateInitialJWKHTTPError will tolerate any error from the initial HTTP JWKS request. If an error occurs, + // the RefreshErrorHandler will be given the error. The program will continue to run as if the error did not occur + // and a valid JWK Set with no keys was received in the response. This allows for the background goroutine to + // request the JWKS at a later time. + // + // It does not make sense to mark this field as true unless the background refresh goroutine is active. + TolerateInitialJWKHTTPError bool } // MultipleOptions is used to configure the behavior when multiple JWKS are used by MultipleJWKS. diff --git a/options_test.go b/options_test.go index 3d7abad..6bf4845 100644 --- a/options_test.go +++ b/options_test.go @@ -2,6 +2,7 @@ package keyfunc_test import ( "errors" + "github.com/golang-jwt/jwt/v5" "net/http" "net/http/httptest" "sync" @@ -77,3 +78,52 @@ func TestResponseExtractorStatusAny(t *testing.T) { t.Fatalf("Expected error no error for 500 status code.\nError: %s", err) } } + +func TestTolerateStartupFailure(t *testing.T) { + var mux sync.Mutex + shouldError := true + + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + mux.Lock() + defer mux.Unlock() + if shouldError { + writer.WriteHeader(http.StatusInternalServerError) + } else { + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte(jwksJSON)) + } + })) + defer server.Close() + + options := keyfunc.Options{ + TolerateInitialJWKHTTPError: true, + RefreshUnknownKID: true, + } + jwks, err := keyfunc.Get(server.URL, options) + if err != nil { + t.Fatalf("TolerateInitialJWKHTTPError should not return error on bad HTTP startup.\nError: %s", err) + } + + if len(jwks.ReadOnlyKeys()) != 0 { + t.Fatalf("Expected JWK Set to have no keys.") + } + + const token = "eyJhbGciOiJFUzI1NiIsInR5cCIgOiAiSldUIiwia2lkIiA6ICJDR3QwWldTNExjNWZhaUtTZGkwdFUwZmpDQWR2R1JPUVJHVTlpUjd0VjBBIn0.eyJleHAiOjE2MTU0MDY4NjEsImlhdCI6MTYxNTQwNjgwMSwianRpIjoiYWVmOWQ5YjItN2EyYy00ZmQ4LTk4MzktODRiMzQ0Y2VmYzZhIiwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDo4MDgwL2F1dGgvcmVhbG1zL21hc3RlciIsImF1ZCI6ImFjY291bnQiLCJzdWIiOiJhZDEyOGRmMS0xMTQwLTRlNGMtYjA5Ny1hY2RjZTcwNWJkOWIiLCJ0eXAiOiJCZWFyZXIiLCJhenAiOiJ0b2tlbmRlbG1lIiwiYWNyIjoiMSIsInJlYWxtX2FjY2VzcyI6eyJyb2xlcyI6WyJvZmZsaW5lX2FjY2VzcyIsInVtYV9hdXRob3JpemF0aW9uIl19LCJyZXNvdXJjZV9hY2Nlc3MiOnsiYWNjb3VudCI6eyJyb2xlcyI6WyJtYW5hZ2UtYWNjb3VudCIsIm1hbmFnZS1hY2NvdW50LWxpbmtzIiwidmlldy1wcm9maWxlIl19fSwic2NvcGUiOiJlbWFpbCBwcm9maWxlIiwiY2xpZW50SG9zdCI6IjE3Mi4yMC4wLjEiLCJjbGllbnRJZCI6InRva2VuZGVsbWUiLCJlbWFpbF92ZXJpZmllZCI6ZmFsc2UsInByZWZlcnJlZF91c2VybmFtZSI6InNlcnZpY2UtYWNjb3VudC10b2tlbmRlbG1lIiwiY2xpZW50QWRkcmVzcyI6IjE3Mi4yMC4wLjEifQ.iQ77QGoPDNjR2oWLu3zT851mswP8J-h_nrGhs3fpa_tFB3FT1deKPGkjef9JOTYFI-CIVxdCFtW3KODOaw9Nrw" + _, err = jwt.Parse(token, jwks.Keyfunc) + if !errors.Is(err, keyfunc.ErrKIDNotFound) { + t.Fatalf("Expected error to be ErrKIDNotFound.\nError: %s", err) + } + + mux.Lock() + shouldError = false + mux.Unlock() + + _, err = jwt.Parse(token, jwks.Keyfunc) + if !errors.Is(err, jwt.ErrTokenExpired) { + t.Fatalf("Expected error to be jwt.ErrTokenExpired.\nError: %s", err) + } + + if len(jwks.ReadOnlyKeys()) == 0 { + t.Fatalf("Expected JWK Set to have keys.") + } +}