diff --git a/get.go b/get.go index 5dd754b..00f2f2f 100644 --- a/get.go +++ b/get.go @@ -53,7 +53,10 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) { } if jwks.refreshInterval != 0 || jwks.refreshUnknownKID { - jwks.ctx, jwks.cancel = context.WithCancel(context.Background()) + if jwks.ctx == nil { + jwks.ctx = context.Background() + } + jwks.ctx, jwks.cancel = context.WithCancel(jwks.ctx) jwks.refreshRequests = make(chan refreshRequest, 1) go jwks.backgroundRefresh() } diff --git a/get_test.go b/get_test.go index 30fbbe6..6624b84 100644 --- a/get_test.go +++ b/get_test.go @@ -80,3 +80,85 @@ func TestJWKS_RefreshUsingBackgroundGoroutine(t *testing.T) { t.Fatalf("Expected 2 refreshes, got %d.", count) } } + +func TestJWKS_RefreshCancelCtx(t *testing.T) { + tests := map[string]struct { + provideOptionsCtx bool + cancelOptionsCtx bool + expectedRefreshes int + }{ + "cancel Options.Ctx": { + provideOptionsCtx: true, + cancelOptionsCtx: true, + expectedRefreshes: 2, + }, + "do not cancel Options.Ctx": { + provideOptionsCtx: true, + cancelOptionsCtx: false, + expectedRefreshes: 3, + }, + "do not provide Options.Ctx": { + provideOptionsCtx: false, + cancelOptionsCtx: false, + expectedRefreshes: 3, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + var ( + ctx context.Context + cancel = func() {} + ) + if tc.provideOptionsCtx { + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + } + + var counter uint64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint64(&counter, 1) + _, err := w.Write([]byte(jwksJSON)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + })) + defer server.Close() + + jwksURL := server.URL + opts := keyfunc.Options{ + Ctx: ctx, + RefreshInterval: 1 * time.Second, + } + jwks, err := keyfunc.Get(jwksURL, opts) + if err != nil { + t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err) + } + + // Wait for the first refresh to occur to ensure the + // JWKS gets refreshed at least once. + time.Sleep(1100 * time.Millisecond) + + if tc.cancelOptionsCtx { + cancel() + } + + // Wait for another refresh cycle to occur to ensure the + // JWKS either did or did not get refreshed depending on + // whether the passed Options.Ctx has been canceled. + time.Sleep(1101 * time.Millisecond) + + jwks.EndBackground() + + // Wait for another refresh cycle to occur to verify that + // the JWKS did not get refreshed after EndBackground() + // has been called. + time.Sleep(1100 * time.Millisecond) + + count := atomic.LoadUint64(&counter) + if count != uint64(tc.expectedRefreshes) { + t.Fatalf("Expected %d refreshes, got %d.", tc.expectedRefreshes, count) + } + }) + } +}