From db684ac488d8258a96b06658712facac1df682f0 Mon Sep 17 00:00:00 2001 From: Harshavardhana Date: Tue, 20 Jun 2017 09:44:58 -0700 Subject: [PATCH] api: ListBuckets() should have a default region for AWS S3. Fixes #717 --- api.go | 25 +++++++++++++++++-------- utils.go | 21 +++++++++++++++++++++ utils_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 8 deletions(-) diff --git a/api.go b/api.go index e2479805a7..5ea8abdd9c 100644 --- a/api.go +++ b/api.go @@ -558,15 +558,24 @@ func (c Client) newRequest(method string, metadata requestMetadata) (req *http.R method = "POST" } - var location string - // Gather location only if bucketName is present. - if metadata.bucketName != "" && metadata.bucketLocation == "" { - location, err = c.getBucketLocation(metadata.bucketName) - if err != nil { - return nil, err + location := metadata.bucketLocation + if location == "" { + switch { + case metadata.bucketName != "": + // Gather location only if bucketName is present. + location, err = c.getBucketLocation(metadata.bucketName) + if err != nil { + if ToErrorResponse(err).Code != "AccessDenied" { + return nil, err + } + } + // Upon AccessDenied error on fetching bucket location, default + // to possible locations based on endpoint URL. This can usually + // happen when GetBucketLocation() is disabled using IAM policies. + } + if location == "" { + location = getDefaultLocation(c.endpointURL, c.region) } - } else { - location = metadata.bucketLocation } // Construct a new target URL. diff --git a/utils.go b/utils.go index d06f1f52c0..7531149cbe 100644 --- a/utils.go +++ b/utils.go @@ -192,3 +192,24 @@ func redactSignature(origAuth string) string { // Strip out 256-bit signature from: Signature=<256-bit signature> return regSign.ReplaceAllString(newAuth, "Signature=**REDACTED**") } + +// Get default location returns the location based on the input +// URL `u`, if none of the standard URL matches then if regionOverride +// is given location gets defaulted to that instead. +// +// If no other cases match then the location is set to `us-east-1` +// as a last resort. +func getDefaultLocation(u url.URL, regionOverride string) (location string) { + // Default to location to 'us-east-1'. + switch { + case s3utils.IsAmazonChinaEndpoint(u): + location = "cn-north-1" + case s3utils.IsAmazonGovCloudEndpoint(u): + location = "us-gov-west-1" + case regionOverride != "": + location = regionOverride + default: + location = "us-east-1" + } + return +} diff --git a/utils_test.go b/utils_test.go index 25622398a3..bd3fbbb0c1 100644 --- a/utils_test.go +++ b/utils_test.go @@ -128,6 +128,8 @@ func TestIsValidEndpointURL(t *testing.T) { {"/", nil, true}, {"https://s3.am1;4205;0cazonaws.com", nil, true}, {"https://s3.cn-north-1.amazonaws.com.cn", nil, true}, + {"https://s3.us-gov-west-1.amazonaws.com", nil, true}, + {"https://s3-fips-us-gov-west-1.amazonaws.com", nil, true}, {"https://s3.amazonaws.com/", nil, true}, {"https://storage.googleapis.com/", nil, true}, {"192.168.1.1", ErrInvalidArgument("Endpoint url cannot have fully qualified paths."), false}, @@ -165,6 +167,46 @@ func TestIsValidEndpointURL(t *testing.T) { } } +func TestDefaultBucketLocation(t *testing.T) { + testCases := []struct { + endpointURL url.URL + regionOverride string + expectedLocation string + }{ + // Region override is ignored. - Test 1. + { + endpointURL: url.URL{Host: "s3-fips-us-gov-west-1.amazonaws.com"}, + regionOverride: "us-west-1", + expectedLocation: "us-gov-west-1", + }, + // Region override is honored - Test 2. + { + endpointURL: url.URL{Host: "s3.amazonaws.com"}, + regionOverride: "us-west-1", + expectedLocation: "us-west-1", + }, + // China region should be honored, region override not provided. - Test 3. + { + endpointURL: url.URL{Host: "s3.cn-north-1.amazonaws.com.cn"}, + regionOverride: "", + expectedLocation: "cn-north-1", + }, + // No region provided, no standard region strings provided as well. -- Test 4. + { + endpointURL: url.URL{Host: "s3.amazonaws.com"}, + regionOverride: "", + expectedLocation: "us-east-1", + }, + } + + for i, testCase := range testCases { + retLocation := getDefaultLocation(testCase.endpointURL, testCase.regionOverride) + if testCase.expectedLocation != retLocation { + t.Errorf("Test %d: Expected location %s, got %s", i+1, testCase.expectedLocation, retLocation) + } + } +} + // Tests validate the expiry time validator. func TestIsValidExpiry(t *testing.T) { testCases := []struct {