diff --git a/client.go b/client.go index cda0210..9ecab4f 100644 --- a/client.go +++ b/client.go @@ -432,6 +432,48 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo return false, nil } +// ErrorPropagatedRetryPolicy is the same as DefaultRetryPolicy, except it +// propagates errors back instead of returning nil. This allows you to inspect +// why it decided to retry or not. +func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { + // do not retry on context.Canceled or context.DeadlineExceeded + if ctx.Err() != nil { + return false, ctx.Err() + } + + if err != nil { + if v, ok := err.(*url.Error); ok { + // Don't retry if the error was due to too many redirects. + if redirectsErrorRe.MatchString(v.Error()) { + return false, v + } + + // Don't retry if the error was due to an invalid protocol scheme. + if schemeErrorRe.MatchString(v.Error()) { + return false, v + } + + // Don't retry if the error was due to TLS cert verification failure. + if _, ok := v.Err.(x509.UnknownAuthorityError); ok { + return false, v + } + } + + // The error is likely recoverable so retry. + return true, nil + } + + // Check the response code. We retry on 500-range responses to allow + // the server time to recover, as 500's are typically not permanent + // errors and may relate to outages on the server side. This will catch + // invalid response codes as well, like 0 and 999. + if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != 501) { + return true, fmt.Errorf("unexpected HTTP status %s", resp.Status) + } + + return false, nil +} + // DefaultBackoff provides a default callback for Client.Backoff which // will perform exponential backoff based on the attempt number and limited // by the provided minimum and maximum durations. @@ -509,9 +551,13 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } var resp *http.Response - var err error + var attempt int + var shouldRetry bool + var doErr, checkErr error for i := 0; ; i++ { + attempt++ + var code int // HTTP response code // Always rewind the request body when non-nil. @@ -540,20 +586,20 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } // Attempt the request - resp, err = c.HTTPClient.Do(req.Request) + resp, doErr = c.HTTPClient.Do(req.Request) if resp != nil { code = resp.StatusCode } // Check if we should continue with retries. - checkOK, checkErr := c.CheckRetry(req.Context(), resp, err) + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) - if err != nil { + if doErr != nil { switch v := logger.(type) { case Logger: - v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err) + v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr) case LeveledLogger: - v.Error("request failed", "error", err, "method", req.Method, "url", req.URL) + v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL) } } else { // Call this here to maintain the behavior of logging all requests, @@ -571,13 +617,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } } - // Now decide if we should continue. - if !checkOK { - if checkErr != nil { - err = checkErr - } - c.HTTPClient.CloseIdleConnections() - return resp, err + if !shouldRetry { + break } // We do this before drainBody because there's no need for the I/O if @@ -588,7 +629,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } // We're going to retry, consume any response to reuse the connection. - if err == nil && resp != nil { + if doErr == nil { c.drainBody(resp.Body) } @@ -613,19 +654,37 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } } + // this is the closest we have to success criteria + if doErr == nil && checkErr == nil && !shouldRetry { + return resp, nil + } + + defer c.HTTPClient.CloseIdleConnections() + + err := doErr + if checkErr != nil { + err = checkErr + } + if c.ErrorHandler != nil { - c.HTTPClient.CloseIdleConnections() - return c.ErrorHandler(resp, err, c.RetryMax+1) + return c.ErrorHandler(resp, err, attempt) } // By default, we close the response body and return an error without // returning the response if resp != nil { - resp.Body.Close() + c.drainBody(resp.Body) } - c.HTTPClient.CloseIdleConnections() - return nil, fmt.Errorf("%s %s giving up after %d attempts", - req.Method, req.URL, c.RetryMax+1) + + // this means CheckRetry thought the request was a failure, but didn't + // communicate why + if err == nil { + return nil, fmt.Errorf("%s %s giving up after %d attempt(s)", + req.Method, req.URL, attempt) + } + + return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w", + req.Method, req.URL, attempt, err) } // Try to read the response body so we can reuse this connection. diff --git a/client_test.go b/client_test.go index dcacb01..27442e0 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "io/ioutil" "net" @@ -255,22 +256,44 @@ func TestClient_Do_fails(t *testing.T) { })) defer ts.Close() - // Create the client. Use short retry windows so we fail faster. - client := NewClient() - client.RetryWaitMin = 10 * time.Millisecond - client.RetryWaitMax = 10 * time.Millisecond - client.RetryMax = 2 - - // Create the request - req, err := NewRequest("POST", ts.URL, nil) - if err != nil { - t.Fatalf("err: %v", err) + tests := []struct { + name string + cr CheckRetry + err string + }{ + { + name: "default_retry_policy", + cr: DefaultRetryPolicy, + err: "giving up after 3 attempt(s)", + }, + { + name: "error_propagated_retry_policy", + cr: ErrorPropagatedRetryPolicy, + err: "giving up after 3 attempt(s): unexpected HTTP status 500 Internal Server Error", + }, } - // Send the request. - _, err = client.Do(req) - if err == nil || !strings.Contains(err.Error(), "giving up") { - t.Fatalf("expected giving up error, got: %#v", err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.CheckRetry = tt.cr + client.RetryMax = 2 + + // Create the request + req, err := NewRequest("POST", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Send the request. + _, err = client.Do(req) + if err == nil || !strings.HasSuffix(err.Error(), tt.err) { + t.Fatalf("expected giving up error, got: %#v", err) + } + }) } } @@ -462,8 +485,10 @@ func TestClient_RequestWithContext(t *testing.T) { t.Fatalf("CheckRetry called %d times, expected 1", called) } - if err != context.Canceled { - t.Fatalf("Expected context.Canceled err, got: %v", err) + e := fmt.Sprintf("GET %s giving up after 1 attempt(s): %s", ts.URL, context.Canceled.Error()) + + if err.Error() != e { + t.Fatalf("Expected err to contain %s, got: %v", e, err) } } @@ -493,10 +518,9 @@ func TestClient_CheckRetry(t *testing.T) { t.Fatalf("CheckRetry called %d times, expected 1", called) } - if err != retryErr { + if err.Error() != fmt.Sprintf("GET %s giving up after 2 attempt(s): retryError", ts.URL) { t.Fatalf("Expected retryError, got:%v", err) } - } func TestClient_DefaultRetryPolicy_TLS(t *testing.T) {