diff --git a/leaktest_test.go b/leaktest_test.go index 7ffbffc..6cecb0b 100644 --- a/leaktest_test.go +++ b/leaktest_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "net/http/httptest" "sync" "testing" "time" @@ -20,6 +21,9 @@ func (tr *testReporter) Errorf(format string, args ...interface{}) { tr.msg = fmt.Sprintf(format, args...) } +// Client for the TestServer +var testServer *httptest.Server + func TestCheck(t *testing.T) { leakyFuncs := []struct { f func() @@ -81,7 +85,7 @@ func TestCheck(t *testing.T) { DisableKeepAlives: true, } client := &http.Client{Transport: tr} - _, err := client.Get("http://localhost:8091") + _, err := client.Get(testServer.URL) if err != nil { t.Error(err) } @@ -95,7 +99,7 @@ func TestCheck(t *testing.T) { DisableKeepAlives: false, } client := &http.Client{Transport: tr} - _, err := client.Get("http://localhost:8091") + _, err := client.Get(testServer.URL) if err != nil { t.Error(err) } @@ -106,7 +110,7 @@ func TestCheck(t *testing.T) { // Start our keep alive server for keep alive tests ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go startKeepAliveEnabledServer(ctx) + testServer = startKeepAliveEnabledServer(ctx) // this works because the running goroutine is left running at the // start of the next test case - so the previous leaks don't affect the @@ -134,7 +138,7 @@ func TestCheck(t *testing.T) { // be based on time after the test finishes rather than time after the test's // start. func TestSlowTest(t *testing.T) { - defer CheckTimeout(t, 1000*time.Millisecond)() + defer CheckTimeout(t, 1000 * time.Millisecond)() go time.Sleep(1500 * time.Millisecond) time.Sleep(750 * time.Millisecond) diff --git a/leaktest_utils_test.go b/leaktest_utils_test.go index 77a1d1f..78dccbc 100644 --- a/leaktest_utils_test.go +++ b/leaktest_utils_test.go @@ -2,8 +2,8 @@ package leaktest import ( "context" - "log" "net/http" + "net/http/httptest" "time" ) @@ -13,30 +13,18 @@ func index() http.Handler { }) } -func startKeepAliveEnabledServer(ctx context.Context) { - router := http.NewServeMux() - router.Handle("/", index()) - - server := &http.Server{ - Addr: ":8091", - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 15 * time.Second, - } +func startKeepAliveEnabledServer(ctx context.Context) *httptest.Server { + server := httptest.NewUnstartedServer(index()) + server.Config.ReadTimeout = 5 * time.Second + server.Config.WriteTimeout = 10 * time.Second + server.Config.IdleTimeout = 15 * time.Second + server.Config.SetKeepAlivesEnabled(true) + server.Start() go func() { <-ctx.Done() - - server.SetKeepAlivesEnabled(false) - if err := server.Shutdown(ctx); err != nil { - log.Fatalf("Could not gracefully shutdown the server: %v\n", err) - } + server.Close() }() - log.Println("Server is ready to handle requests at", server.Addr) - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Could not listen on %s: %v\n", server.Addr, err) - } - - log.Println("Server stopped") + return server }