diff --git a/fs/remote/resolver.go b/fs/remote/resolver.go index 5ecfb4e42..ba46e93ee 100644 --- a/fs/remote/resolver.go +++ b/fs/remote/resolver.go @@ -194,6 +194,7 @@ func newHTTPFetcher(ctx context.Context, fc *fetcherConfig) (*httpFetcher, error rt.Client.RetryWaitMax = fc.maxWait rt.Client.Backoff = socihttp.BackoffStrategy rt.Client.CheckRetry = socihttp.RetryStrategy + rt.Client.ErrorHandler = socihttp.HandleHTTPError timeout = rt.Client.HTTPClient.Timeout } diff --git a/util/http/log/redact_http_query_values.go b/util/http/log/redact_http_query_values.go new file mode 100644 index 000000000..039d3d2a7 --- /dev/null +++ b/util/http/log/redact_http_query_values.go @@ -0,0 +1,51 @@ +/* + Copyright The Soci Snapshotter Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package log + +import ( + "errors" + "net/url" +) + +// RedactHTTPQueryValuesFromError is a log utility to parse an error as a URL error and redact +// HTTP query values to prevent leaking sensitive information like encoded credentials or tokens. +func RedactHTTPQueryValuesFromError(err error) error { + var urlErr *url.Error + + if err != nil && errors.As(err, &urlErr) { + url, urlParseErr := url.Parse(urlErr.URL) + if urlParseErr == nil { + RedactHTTPQueryValuesFromURL(url) + urlErr.URL = url.Redacted() + return urlErr + } + } + + return err +} + +// RedactHTTPQueryValuesFromURL redacts HTTP query values from a URL. +func RedactHTTPQueryValuesFromURL(url *url.URL) { + if url != nil { + if query := url.Query(); len(query) > 0 { + for k := range query { + query.Set(k, "redacted") + } + url.RawQuery = query.Encode() + } + } +} diff --git a/util/http/log/redact_http_query_values_test.go b/util/http/log/redact_http_query_values_test.go new file mode 100644 index 000000000..968dcbbc8 --- /dev/null +++ b/util/http/log/redact_http_query_values_test.go @@ -0,0 +1,223 @@ +/* + Copyright The Soci Snapshotter Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package log + +import ( + "bytes" + "errors" + "net/url" + "strings" + "testing" + + "github.com/sirupsen/logrus" +) + +const ( + // mockURL is a fake URL modeling soci-snapshotter fetching content from S3. + mockURL = "https://s3.us-east-1.amazonaws.com/981ebdad55863b3631dce86a228a3ea230dc87673a06a7d216b1275d4dd707c9/12d7153d7eee2fd595a25e5378384f1ae4b6a1658298a54c5bd3f951ec50b7cb" + + // mockQuery is a fake HTTP query with sensitive information which should be redacted. + mockQuery = "?username=admin&password=admin" + + // redactedQuery is the expected result of redacting mockQuery. + // The query values will be sorted by key as a side-effect of encoding the URL query string back into the URL. + // See https://pkg.go.dev/net/url#Values.Encode + redactedQuery = "?password=redacted&username=redacted" +) + +func TestRedactHTTPQueryValuesFromError(t *testing.T) { + testcases := []struct { + Name string + Description string + Err error + Assert func(*testing.T, error) + }{ + { + Name: "NilError", + Description: "Utility should handle nil error gracefully", + Err: nil, + Assert: func(t *testing.T, actual error) { + if actual != nil { + t.Fatalf("Expected nil error, got '%v'", actual) + } + }, + }, + { + Name: "NonURLError", + Description: "Utility should not modify an error if error is not a URL error", + Err: errors.New("this error is not a URL error"), + Assert: func(t *testing.T, actual error) { + const expected = "this error is not a URL error" + if strings.Compare(expected, actual.Error()) != 0 { + t.Fatalf("Expected '%s', got '%v'", expected, actual) + } + }, + }, + { + Name: "ErrorWithNoHTTPQuery", + Description: "Utility should not modify an error if no HTTP queries are present.", + Err: &url.Error{ + Op: "GET", + URL: mockURL, + Err: errors.New("connect: connection refused"), + }, + Assert: func(t *testing.T, actual error) { + const expected = "GET \"" + mockURL + "\": connect: connection refused" + if strings.Compare(expected, actual.Error()) != 0 { + t.Fatalf("Expected '%s', got '%v'", expected, actual) + } + }, + }, + { + Name: "ErrorWithHTTPQuery", + Description: "Utility should redact HTTP query values in errors to prevent logging sensitive information.", + Err: &url.Error{ + Op: "GET", + URL: mockURL + mockQuery, + Err: errors.New("connect: connection refused"), + }, + Assert: func(t *testing.T, actual error) { + const expected = "GET \"" + mockURL + redactedQuery + "\": connect: connection refused" + if strings.Compare(expected, actual.Error()) != 0 { + t.Fatalf("Expected '%s', got '%v'", expected, actual) + } + }, + }, + } + + for _, testcase := range testcases { + t.Run(testcase.Name, func(t *testing.T) { + actual := RedactHTTPQueryValuesFromError(testcase.Err) + testcase.Assert(t, actual) + }) + } +} + +func TestRedactHTTPQueryValuesFromURL(t *testing.T) { + testcases := []struct { + Name string + Description string + URL *url.URL + Assert func(*testing.T, *url.URL) + }{ + { + Name: "NilURL", + Description: "Utility should handle nil URL gracefully", + URL: nil, + Assert: func(t *testing.T, url *url.URL) { + if url != nil { + t.Fatalf("Expected got '%v'", url) + } + }, + }, + { + Name: "URLWithNoQueries", + Description: "Utility should not modify a URL with no queries", + URL: &url.URL{}, + Assert: func(t *testing.T, url *url.URL) { + if len(url.RawQuery) != 0 { + t.Fatalf("Expected '' got '%s'", url.RawQuery) + } + }, + }, + { + Name: "URLWithQueries", + Description: "Utility should not modify a URL with no queries", + URL: &url.URL{RawQuery: "key1=value1&key2=value2"}, + Assert: func(t *testing.T, url *url.URL) { + const expected = "key1=redacted&key2=redacted" + if strings.Compare(expected, url.RawQuery) != 0 { + t.Fatalf("Expected '%s', got '%s'", expected, url.RawQuery) + } + }, + }, + } + + for _, testcase := range testcases { + t.Run(testcase.Name, func(t *testing.T) { + RedactHTTPQueryValuesFromURL(testcase.URL) + testcase.Assert(t, testcase.URL) + }) + } +} + +func BenchmarkRedactHTTPQueryValuesOverhead(b *testing.B) { + benchmarks := []struct { + Name string + Description string + Err error + Log func(*logrus.Entry, error) + }{ + { + Name: "BaselineLogging", + Description: "Log a message to memory without redaction to measure baseline.", + Err: &url.Error{ + Op: "GET", + URL: mockURL + mockQuery, + Err: errors.New("connect: connection refused"), + }, + Log: func(logger *logrus.Entry, err error) { + logger.WithError(err).Info("Error on HTTP Get") + }, + }, + { + Name: "WithoutReplacement", + Description: "Log a message with no HTTP query values to memory with redaction utility to measure the flat overhead.", + Err: &url.Error{ + Op: "GET", + URL: mockURL, + Err: errors.New("connect: connection refused"), + }, + Log: func(logger *logrus.Entry, err error) { + logger.WithError(RedactHTTPQueryValuesFromError(err)).Info("Error on HTTP Get") + }, + }, + { + Name: "WithErrorReplacement", + Description: "Log a message with HTTP query values to memory with redaction utility to measure replacement overhead.", + Err: &url.Error{ + Op: "GET", + URL: mockURL + mockQuery, + Err: errors.New("connect: connection refused"), + }, + Log: func(logger *logrus.Entry, err error) { + logger.WithError(RedactHTTPQueryValuesFromError(err)).Info("Error on HTTP Get") + }, + }, + } + + setupUUT := func() *logrus.Entry { + entry := &logrus.Entry{ + Logger: logrus.New(), + } + + entry.Logger.Out = bytes.NewBuffer([]byte{}) + + return entry + } + + for _, benchmark := range benchmarks { + b.Run(benchmark.Name, func(b *testing.B) { + uut := setupUUT() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmark.Log(uut, benchmark.Err) + } + }) + } +} diff --git a/util/http/retry.go b/util/http/retry.go index a66288234..215ac9d53 100644 --- a/util/http/retry.go +++ b/util/http/retry.go @@ -19,12 +19,14 @@ package http import ( "context" "fmt" + "io" "math/rand" "net" "net/http" "time" "github.com/awslabs/soci-snapshotter/config" + logutil "github.com/awslabs/soci-snapshotter/util/http/log" "github.com/awslabs/soci-snapshotter/version" "github.com/containerd/log" rhttp "github.com/hashicorp/go-retryablehttp" @@ -49,6 +51,7 @@ func NewRetryableClient(config config.RetryableHTTPClientConfig) *http.Client { rhttpClient.Backoff = BackoffStrategy rhttpClient.CheckRetry = RetryStrategy rhttpClient.HTTPClient.Timeout = time.Duration(config.RequestTimeoutMsec) * time.Millisecond + rhttpClient.ErrorHandler = HandleHTTPError // set timeouts innerTransport := rhttpClient.HTTPClient.Transport @@ -83,9 +86,50 @@ func RetryStrategy(ctx context.Context, resp *http.Response, err error) (bool, e retry, err2 := rhttp.DefaultRetryPolicy(ctx, resp, err) if retry { log.G(ctx).WithFields(logrus.Fields{ - "error": err, + "error": logutil.RedactHTTPQueryValuesFromError(err), "response": resp, }).Debugf("retrying request") } - return retry, err2 + return retry, logutil.RedactHTTPQueryValuesFromError(err2) +} + +// HandleHTTPError implements retryablehttp client's ErrorHandler to ensure returned errors +// have HTTP query values redacted to prevent leaking sensitive information like encoded credentials or tokens. +func HandleHTTPError(resp *http.Response, err error, attempts int) (*http.Response, error) { + var ( + method = "unknown" + url = "unknown" + ) + + if resp != nil { + drain(resp.Body) + + if resp.Request != nil { + + method = resp.Request.Method + + if resp.Request.URL != nil { + logutil.RedactHTTPQueryValuesFromURL(resp.Request.URL) + url = resp.Request.URL.Redacted() + } + } + } + + if err == nil { + return nil, fmt.Errorf("%s \"%s\": giving up request after %d attempt(s)", method, url, attempts) + } + + err = logutil.RedactHTTPQueryValuesFromError(err) + return nil, fmt.Errorf("%s \"%s\": giving up request after %d attempt(s): %w", method, url, attempts, err) +} + +// Try to read and discard the response body so the connection can be reused. +// See https://pkg.go.dev/net/http#Response for more information. +func drain(body io.ReadCloser) { + defer body.Close() + + // We want to consume response bodies to maintain HTTP connections, + // but also want to limit the size read. 4KiB is arbitirary but reasonable. + const responseReadLimit = int64(4096) + _, _ = io.Copy(io.Discard, io.LimitReader(body, responseReadLimit)) } diff --git a/util/http/retry_test.go b/util/http/retry_test.go new file mode 100644 index 000000000..0c65db1f4 --- /dev/null +++ b/util/http/retry_test.go @@ -0,0 +1,160 @@ +/* + Copyright The Soci Snapshotter Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package http + +import ( + "errors" + "io" + "net/http" + "net/url" + "strings" + "testing" +) + +const ( + // mockURL is a fake URL modeling soci-snapshotter fetching content from S3. + mockURL = "https://s3.us-east-1.amazonaws.com/981ebdad55863b3631dce86a228a3ea230dc87673a06a7d216b1275d4dd707c9/12d7153d7eee2fd595a25e5378384f1ae4b6a1658298a54c5bd3f951ec50b7cb" + + // mockQuery is a fake HTTP query with sensitive information which should be redacted. + mockQuery = "?username=admin&password=admin" + + // redactedQuery is the expected result of redacting mockQuery. + // The query values will be sorted by key as a side-effect of encoding the URL query string back into the URL. + // See https://pkg.go.dev/net/url#Values.Encode + redactedQuery = "?password=redacted&username=redacted" +) + +func TestHandleHTTPErrorRedactsHTTPQueries(t *testing.T) { + createHTTPResponse := func(path string, query string) *http.Response { + url, err := url.Parse(path + query) + if err != nil { + panic(err) + } + return &http.Response{ + Body: &mockBody{}, + Request: &http.Request{ + Method: "GET", + URL: url, + }, + } + } + + testcases := []struct { + Name string + Description string + Response *http.Response + Err error + Attempts int + Assert func(*testing.T, *http.Response, error) + }{ + { + Name: "NilResponseOrError", + Description: "Handler should gracefully handle a nil response or error", + Response: nil, + Err: nil, + Attempts: 10, + Assert: func(t *testing.T, response *http.Response, err error) { + if response != nil { + t.Fatalf("Expected nil response, got '%v'", response) + } + + const expected = "unknown \"unknown\": giving up request after 10 attempt(s)" + if strings.Compare(expected, err.Error()) != 0 { + t.Fatalf("Expected '%s', got '%s'", expected, err.Error()) + } + }, + }, + { + Name: "RedactURLInResponse", + Description: "Handler should redact HTTP queries in response", + Response: createHTTPResponse(mockURL, mockQuery), + Err: errors.New("connect: connection refused"), + Attempts: 10, + Assert: func(t *testing.T, response *http.Response, err error) { + if response != nil { + t.Fatalf("Expected nil response, got '%v'", response) + } + + const expected = "GET \"" + mockURL + redactedQuery + "\": giving up request after 10 attempt(s): connect: connection refused" + if strings.Compare(expected, err.Error()) != 0 { + t.Fatalf("Expected '%s', got '%s'", expected, err.Error()) + } + }, + }, + { + Name: "RedactURLInError", + Description: "Handler should redact HTTP queries in error", + Response: createHTTPResponse(mockURL, ""), + Err: &url.Error{ + Op: "GET", + URL: mockURL + mockQuery, + Err: errors.New("connect: connection refused"), + }, + Attempts: 10, + Assert: func(t *testing.T, response *http.Response, err error) { + if response != nil { + t.Fatalf("Expected nil response, got '%v'", response) + } + + const expected = "GET \"" + mockURL + "\": giving up request after 10 attempt(s): GET \"" + mockURL + redactedQuery + "\": connect: connection refused" + if strings.Compare(expected, err.Error()) != 0 { + t.Fatalf("Expected '%s', got '%s'", expected, err.Error()) + } + }, + }, + } + + for _, testcase := range testcases { + t.Run(testcase.Name, func(t *testing.T) { + response, err := HandleHTTPError(testcase.Response, testcase.Err, testcase.Attempts) + testcase.Assert(t, response, err) + }) + } +} + +func TestHandleHTTPErrorReadsAndClosesResponseBody(t *testing.T) { + body := &mockBody{} + response := &http.Response{ + Body: body, + } + err := errors.New("connect: connection refused") + + _, _ = HandleHTTPError(response, err, 0) + + if !body.WasRead { + t.Fatalf("The response body was not read by handler") + } + + if !body.Closed { + t.Fatalf("The response body was not closed by handler") + } +} + +type mockBody struct { + Closed bool + WasRead bool +} + +func (b *mockBody) Read(_ []byte) (int, error) { + b.WasRead = true + return 0, io.EOF +} + +func (b *mockBody) Close() error { + b.Closed = true + return nil +}