From 5a2e692d45aeb3856583003924cf73e32b62aaf6 Mon Sep 17 00:00:00 2001 From: souleb Date: Tue, 17 Jan 2023 13:34:24 +0100 Subject: [PATCH] feat!: retryable http client (#398) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If implemented, this will provide a default http client with retry. The retry function is an exponential back off 0.25s * 2^n ± 10% and max 5 attempts. The client is the default client of `auth.Client` BREAKING CHANGE: `auth.DefaultClient` uses `retry.DefaultClient` instead of `http.DefaultClient` Fixes: #147 Co-authored-by: Shiwei Zhang Signed-off-by: Soule BA --- example_test.go | 2 +- registry/example_test.go | 2 +- registry/remote/auth/client.go | 7 ++ registry/remote/example_test.go | 2 +- registry/remote/retry/client.go | 114 ++++++++++++++++++++ registry/remote/retry/client_test.go | 97 +++++++++++++++++ registry/remote/retry/policy.go | 154 +++++++++++++++++++++++++++ registry/remote/retry/policy_test.go | 64 +++++++++++ 8 files changed, 439 insertions(+), 3 deletions(-) create mode 100644 registry/remote/retry/client.go create mode 100644 registry/remote/retry/client_test.go create mode 100644 registry/remote/retry/policy.go create mode 100644 registry/remote/retry/policy_test.go diff --git a/example_test.go b/example_test.go index c1e4c200..82fa1260 100644 --- a/example_test.go +++ b/example_test.go @@ -182,7 +182,7 @@ func TestMain(m *testing.M) { panic(err) } remoteHost = u.Host - http.DefaultClient = httpsServer.Client() + http.DefaultTransport = httpsServer.Client().Transport os.Exit(m.Run()) } diff --git a/registry/example_test.go b/registry/example_test.go index 70f102c1..2463667d 100644 --- a/registry/example_test.go +++ b/registry/example_test.go @@ -65,7 +65,7 @@ func TestMain(m *testing.M) { panic(err) } host = u.Host - http.DefaultClient = ts.Client() + http.DefaultTransport = ts.Client().Transport os.Exit(m.Run()) } diff --git a/registry/remote/auth/client.go b/registry/remote/auth/client.go index ef071b98..9e8947e6 100644 --- a/registry/remote/auth/client.go +++ b/registry/remote/auth/client.go @@ -28,10 +28,12 @@ import ( "strings" "oras.land/oras-go/v2/registry/remote/internal/errutil" + "oras.land/oras-go/v2/registry/remote/retry" ) // DefaultClient is the default auth-decorated client. var DefaultClient = &Client{ + Client: retry.DefaultClient, Header: http.Header{ "User-Agent": {"oras-go"}, }, @@ -68,6 +70,11 @@ type Client struct { // Client is the underlying HTTP client used to access the remote // server. // If nil, http.DefaultClient is used. + // It is possible to use the default retry client from the package + // `oras.land/oras-go/v2/registry/remote/retry`. That client is already available + // in the DefaultClient. + // It is also possible to use a custom client. For example, github.com/hashicorp/go-retryablehttp + // is a popular HTTP client that supports retries. Client *http.Client // Header contains the custom headers to be added to each request. diff --git a/registry/remote/example_test.go b/registry/remote/example_test.go index bbd292eb..370486d7 100644 --- a/registry/remote/example_test.go +++ b/registry/remote/example_test.go @@ -214,7 +214,7 @@ func TestMain(m *testing.M) { panic(err) } host = u.Host - http.DefaultClient = ts.Client() + http.DefaultTransport = ts.Client().Transport os.Exit(m.Run()) } diff --git a/registry/remote/retry/client.go b/registry/remote/retry/client.go new file mode 100644 index 00000000..5e986ea0 --- /dev/null +++ b/registry/remote/retry/client.go @@ -0,0 +1,114 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retry + +import ( + "net/http" + "time" +) + +// DefaultClient is a client with the default retry policy. +var DefaultClient = NewClient() + +// NewClient creates an HTTP client with the default retry policy. +func NewClient() *http.Client { + return &http.Client{ + Transport: NewTransport(nil), + } +} + +// Transport is an HTTP transport with retry policy. +type Transport struct { + // Base is the underlying HTTP transport to use. + // If nil, http.DefaultTransport is used for round trips. + Base http.RoundTripper + + // Policy returns a retry Policy to use for the request. + // If nil, DefaultPolicy is used to determine if the request should be retried. + Policy func() Policy +} + +// NewTransport creates an HTTP Transport with the default retry policy. +func NewTransport(base http.RoundTripper) *Transport { + return &Transport{ + Base: base, + } +} + +// RoundTrip executes a single HTTP transaction, returning a Response for the +// provided Request. +// It relies on the configured Policy to determine if the request should be +// retried and to backoff. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + policy := t.policy() + attempt := 0 + for { + resp, respErr := t.roundTrip(req) + duration, err := policy.Retry(attempt, resp, respErr) + if err != nil { + if respErr == nil { + resp.Body.Close() + } + return nil, err + } + if duration < 0 { + return resp, respErr + } + + // rewind the body if possible + if req.Body != nil { + if req.GetBody == nil { + // body can't be rewound, so we can't retry + return resp, respErr + } + body, err := req.GetBody() + if err != nil { + // failed to rewind the body, so we can't retry + return resp, respErr + } + req.Body = body + } + + // close the response body if needed + if respErr == nil { + resp.Body.Close() + } + + timer := time.NewTimer(duration) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + attempt++ + } +} + +func (t *Transport) roundTrip(req *http.Request) (*http.Response, error) { + if t.Base == nil { + return http.DefaultTransport.RoundTrip(req) + } + return t.Base.RoundTrip(req) +} + +func (t *Transport) policy() Policy { + if t.Policy == nil { + return DefaultPolicy + } + return t.Policy() +} diff --git a/registry/remote/retry/client_test.go b/registry/remote/retry/client_test.go new file mode 100644 index 00000000..34a566ad --- /dev/null +++ b/registry/remote/retry/client_test.go @@ -0,0 +1,97 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retry + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" +) + +func Test_Client(t *testing.T) { + testCases := []struct { + name string + attempts int + retryAfter bool + StatusCode int + expectedErr bool + }{ + { + name: "successful request with 0 retry", + attempts: 1, retryAfter: false, StatusCode: http.StatusOK, expectedErr: false, + }, + { + name: "successful request with 1 retry caused by rate limit", + // 1 request + 1 retry = 2 attempts + attempts: 2, retryAfter: true, StatusCode: http.StatusTooManyRequests, expectedErr: false, + }, + { + name: "successful request with 1 retry caused by 408", + // 1 request + 1 retry = 2 attempts + attempts: 2, retryAfter: false, StatusCode: http.StatusRequestTimeout, expectedErr: false, + }, + { + name: "successful request with 2 retries caused by 429", + // 1 request + 2 retries = 3 attempts + attempts: 3, retryAfter: false, StatusCode: http.StatusTooManyRequests, expectedErr: false, + }, + { + name: "unsuccessful request with 6 retries caused by too many retries", + // 1 request + 6 retries = 7 attempts + attempts: 7, retryAfter: false, StatusCode: http.StatusServiceUnavailable, expectedErr: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + count := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + if count < tc.attempts { + if tc.retryAfter { + w.Header().Set("Retry-After", "1") + } + http.Error(w, "error", tc.StatusCode) + return + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewReader([]byte("test"))) + if err != nil { + t.Fatalf("failed to create test request: %v", err) + } + + resp, err := DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to do test request: %v", err) + } + if tc.expectedErr { + if count != (tc.attempts - 1) { + t.Errorf("expected attempts %d, got %d", tc.attempts, count) + } + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("expected status code %d, got %d", http.StatusServiceUnavailable, resp.StatusCode) + } + return + } + if tc.attempts != count { + t.Errorf("expected attempts %d, got %d", tc.attempts, count) + } + }) + } +} diff --git a/registry/remote/retry/policy.go b/registry/remote/retry/policy.go new file mode 100644 index 00000000..fe7fadee --- /dev/null +++ b/registry/remote/retry/policy.go @@ -0,0 +1,154 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retry + +import ( + "hash/maphash" + "math" + "math/rand" + "net" + "net/http" + "strconv" + "time" +) + +// headerRetryAfter is the header key for Retry-After. +const headerRetryAfter = "Retry-After" + +// DefaultPolicy is a policy with fine-tuned retry parameters. +// It uses an exponential backoff with jitter. +var DefaultPolicy Policy = &GenericPolicy{ + Retryable: DefaultPredicate, + Backoff: DefaultBackoff, + MinWait: 200 * time.Millisecond, + MaxWait: 3 * time.Second, + MaxRetry: 5, +} + +// DefaultPredicate is a predicate that retries on 5xx errors, 429 Too Many +// Requests, 408 Request Timeout and on network dial timeout. +var DefaultPredicate Predicate = func(resp *http.Response, err error) (bool, error) { + if err != nil { + // retry on Dial timeout + if err, ok := err.(net.Error); ok && err.Timeout() { + return true, nil + } + return false, err + } + + if resp.StatusCode == http.StatusRequestTimeout || resp.StatusCode == http.StatusTooManyRequests { + return true, nil + } + + if resp.StatusCode == 0 || resp.StatusCode >= 500 { + return true, nil + } + + return false, nil +} + +// DefaultBackoff is a backoff that uses an exponential backoff with jitter. +// It uses a base of 250ms, a factor of 2 and a jitter of 10%. +var DefaultBackoff Backoff = ExponentialBackoff(250*time.Millisecond, 2, 0.1) + +// Policy is a retry policy. +type Policy interface { + // Retry returns the duration to wait before retrying the request. + // It returns a negative value if the request should not be retried. + // The attempt is used to: + // - calculate the backoff duration, the default backoff is an exponential backoff. + // - determine if the request should be retried. + // The attempt starts at 0 and should be less than MaxRetry for the request to + // be retried. + Retry(attempt int, resp *http.Response, err error) (time.Duration, error) +} + +// Predicate is a function that returns true if the request should be retried. +type Predicate func(resp *http.Response, err error) (bool, error) + +// Backoff is a function that returns the duration to wait before retrying the +// request. The attempt, is the next attempt number. The response is the +// response from the previous request. +type Backoff func(attempt int, resp *http.Response) time.Duration + +// ExponentialBackoff returns a Backoff that uses an exponential backoff with +// jitter. The backoff is calculated as: +// +// temp = backoff * factor ^ attempt +// interval = temp * (1 - jitter) + rand.Int63n(2 * jitter * temp) +// +// The HTTP response is checked for a Retry-After header. If it is present, the +// value is used as the backoff duration. +func ExponentialBackoff(backoff time.Duration, factor, jitter float64) Backoff { + return func(attempt int, resp *http.Response) time.Duration { + var h maphash.Hash + h.SetSeed(maphash.MakeSeed()) + rand := rand.New(rand.NewSource(int64(h.Sum64()))) + + // check Retry-After + if resp != nil && resp.StatusCode == http.StatusTooManyRequests { + if v := resp.Header.Get(headerRetryAfter); v != "" { + if retryAfter, _ := strconv.ParseInt(v, 10, 64); retryAfter > 0 { + return time.Duration(retryAfter) * time.Second + } + } + } + + // do exponential backoff with jitter + temp := float64(backoff) * math.Pow(factor, float64(attempt)) + return time.Duration(temp*(1-jitter)) + time.Duration(rand.Int63n(int64(2*jitter*temp))) + } +} + +// GenericPolicy is a generic retry policy. +type GenericPolicy struct { + // Retryable is a predicate that returns true if the request should be + // retried. + Retryable Predicate + + // Backoff is a function that returns the duration to wait before retrying. + Backoff Backoff + + // MinWait is the minimum duration to wait before retrying. + MinWait time.Duration + + // MaxWait is the maximum duration to wait before retrying. + MaxWait time.Duration + + // MaxRetry is the maximum number of retries. + MaxRetry int +} + +// Retry returns the duration to wait before retrying the request. +// It returns -1 if the request should not be retried. +func (p *GenericPolicy) Retry(attempt int, resp *http.Response, err error) (time.Duration, error) { + if attempt >= p.MaxRetry { + return -1, nil + } + if ok, err := p.Retryable(resp, err); err != nil { + return -1, err + } else if !ok { + return -1, nil + } + backoff := p.Backoff(attempt, resp) + if backoff < p.MinWait { + backoff = p.MinWait + } + if backoff > p.MaxWait { + backoff = p.MaxWait + } + return backoff, nil +} diff --git a/registry/remote/retry/policy_test.go b/registry/remote/retry/policy_test.go new file mode 100644 index 00000000..3a925940 --- /dev/null +++ b/registry/remote/retry/policy_test.go @@ -0,0 +1,64 @@ +/* +Copyright The ORAS Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retry + +import ( + "testing" + "time" +) + +func Test_ExponentialBackoff(t *testing.T) { + testCases := []struct { + name string + attempt int + expectedBackoff time.Duration + }{ + { + name: "attempt 0 should have a backoff of 0,25s ± 10%", + attempt: 0, expectedBackoff: 250 * time.Millisecond, + }, + { + name: "attempt 1 should have a backoff of 0,5s ± 10%", + attempt: 1, expectedBackoff: 500 * time.Millisecond, + }, + { + name: "attempt 2 should have a backoff of 1s ± 10%", + attempt: 2, expectedBackoff: 1 * time.Second, + }, + { + name: "attempt 3 should have a backoff of 2s ± 10%", + attempt: 3, expectedBackoff: 2 * time.Second, + }, + { + name: "attempt 4 should have a backoff of 4s ± 10%", + attempt: 4, expectedBackoff: 4 * time.Second, + }, + { + name: "attempt 5 should have a backoff of 8s ± 10%", + attempt: 5, expectedBackoff: 8 * time.Second, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + b := DefaultBackoff(tc.attempt, nil) + base := float64(tc.expectedBackoff) + if !(b >= time.Duration(base*0.9) && b <= time.Duration(base+base*0.1)) { + t.Errorf("expected backoff to be %s + jitter, got %s", time.Duration(base), b) + } + }) + } +}