diff --git a/.changelog/8e6a01197da848c88aaab5adb296abc1.json b/.changelog/8e6a01197da848c88aaab5adb296abc1.json new file mode 100644 index 00000000000..6a3c4a03652 --- /dev/null +++ b/.changelog/8e6a01197da848c88aaab5adb296abc1.json @@ -0,0 +1,8 @@ +{ + "id": "8e6a0119-7da8-48c8-8aaa-b5adb296abc1", + "type": "bugfix", + "description": "Improve uniqueness of default S3Express sesssion credentials cache keying to prevent collision in multi-credential scenarios.", + "modules": [ + "service/s3" + ] +} \ No newline at end of file diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/auth/S3ExpressAuthScheme.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/auth/S3ExpressAuthScheme.java index 95839530ea5..61ec442ec08 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/auth/S3ExpressAuthScheme.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/customization/auth/S3ExpressAuthScheme.java @@ -43,6 +43,7 @@ import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; +import static software.amazon.smithy.go.codegen.SymbolUtils.buildPackageSymbol; public class S3ExpressAuthScheme implements GoIntegration { private static final ConfigField s3ExpressCredentials = @@ -67,6 +68,14 @@ public class S3ExpressAuthScheme implements GoIntegration { .withClientInput(true) .build(); + private static final ConfigFieldResolver s3ExpressCredentialsOperationFinalizer = + ConfigFieldResolver.builder() + .location(ConfigFieldResolver.Location.OPERATION) + .target(ConfigFieldResolver.Target.FINALIZATION) + .resolver(buildPackageSymbol("finalizeOperationExpressCredentials")) + .withClientInput(true) + .build(); + @Override public void writeAdditionalFiles( GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator @@ -84,6 +93,7 @@ public List getClientPlugins() { .addConfigField(s3ExpressCredentials) .addConfigFieldResolver(s3ExpressCredentialsResolver) .addConfigFieldResolver(s3ExpressCredentialsClientFinalizer) + .addConfigFieldResolver(s3ExpressCredentialsOperationFinalizer) .addAuthSchemeDefinition(SigV4S3ExpressTrait.ID, new SigV4S3Express()) .build() ); diff --git a/service/s3/api_client.go b/service/s3/api_client.go index 6649d914fdf..9daad080d25 100644 --- a/service/s3/api_client.go +++ b/service/s3/api_client.go @@ -111,6 +111,8 @@ func (c *Client) invokeOperation(ctx context.Context, opID string, params interf resolveCredentialProvider(&options) + finalizeOperationExpressCredentials(&options, *c) + finalizeOperationEndpointAuthResolver(&options) for _, fn := range stackFns { diff --git a/service/s3/express_default.go b/service/s3/express_default.go index 3a55e26bc93..3b35a3e5748 100644 --- a/service/s3/express_default.go +++ b/service/s3/express_default.go @@ -2,7 +2,10 @@ package s3 import ( "context" + "crypto/hmac" + "crypto/sha256" "errors" + "fmt" "sync" "time" @@ -17,18 +20,49 @@ const s3ExpressCacheCap = 100 const s3ExpressRefreshWindow = 1 * time.Minute +type cacheKey struct { + CredentialsHash string // hmac(sigv4 akid, sigv4 secret) + Bucket string +} + +func (c cacheKey) Slug() string { + return fmt.Sprintf("%s%s", c.CredentialsHash, c.Bucket) +} + +type sessionCredsCache struct { + mu sync.Mutex + cache cache.Cache +} + +func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + if v, ok := c.cache.Get(key); ok { + return v.(*aws.Credentials), true + } + return nil, false +} + +func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) { + c.mu.Lock() + defer c.mu.Unlock() + + c.cache.Put(key, creds) +} + // The default S3Express provider uses an LRU cache with a capacity of 100. // // Credentials will be refreshed asynchronously when a Retrieve() call is made // for cached credentials within an expiry window (1 minute, currently // non-configurable). type defaultS3ExpressCredentialsProvider struct { - mu sync.Mutex sf singleflight.Group client createSessionAPIClient - credsCache cache.Cache + cache *sessionCredsCache refreshWindow time.Duration + v4creds aws.CredentialsProvider // underlying credentials used for CreateSession } type createSessionAPIClient interface { @@ -37,35 +71,54 @@ type createSessionAPIClient interface { func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider { return &defaultS3ExpressCredentialsProvider{ - credsCache: lru.New(s3ExpressCacheCap), + cache: &sessionCredsCache{ + cache: lru.New(s3ExpressCacheCap), + }, refreshWindow: s3ExpressRefreshWindow, } } +// returns a cloned provider using new base credentials, used when per-op +// config mutations change the credentials provider +func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider { + return &defaultS3ExpressCredentialsProvider{ + client: p.client, + cache: p.cache, + refreshWindow: p.refreshWindow, + v4creds: v4creds, + } +} + func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) { - p.mu.Lock() - defer p.mu.Unlock() + v4creds, err := p.v4creds.Retrieve(ctx) + if err != nil { + return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err) + } - creds, ok := p.getCacheCredentials(bucket) + key := cacheKey{ + CredentialsHash: gethmac(v4creds.AccessKeyID, v4creds.SecretAccessKey), + Bucket: bucket, + } + creds, ok := p.cache.Get(key) if !ok || creds.Expired() { - return p.awaitDoChanRetrieve(ctx, bucket) + return p.awaitDoChanRetrieve(ctx, key) } if creds.Expires.Sub(sdk.NowTime()) <= p.refreshWindow { - p.doChanRetrieve(ctx, bucket) + p.doChanRetrieve(ctx, key) } return *creds, nil } -func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, bucket string) <-chan singleflight.Result { - return p.sf.DoChan(bucket, func() (interface{}, error) { - return p.retrieve(ctx, bucket) +func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result { + return p.sf.DoChan(key.Slug(), func() (interface{}, error) { + return p.retrieve(ctx, key) }) } -func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, bucket string) (aws.Credentials, error) { - ch := p.doChanRetrieve(ctx, bucket) +func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) { + ch := p.doChanRetrieve(ctx, key) select { case r := <-ch: @@ -75,9 +128,9 @@ func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Co } } -func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, bucket string) (aws.Credentials, error) { +func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) { resp, err := p.client.CreateSession(ctx, &CreateSessionInput{ - Bucket: aws.String(bucket), + Bucket: aws.String(key.Bucket), }) if err != nil { return aws.Credentials{}, err @@ -88,22 +141,10 @@ func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, buck return aws.Credentials{}, err } - p.putCacheCredentials(bucket, creds) + p.cache.Put(key, creds) return *creds, nil } -func (p *defaultS3ExpressCredentialsProvider) getCacheCredentials(bucket string) (*aws.Credentials, bool) { - if v, ok := p.credsCache.Get(bucket); ok { - return v.(*aws.Credentials), true - } - - return nil, false -} - -func (p *defaultS3ExpressCredentialsProvider) putCacheCredentials(bucket string, creds *aws.Credentials) { - p.credsCache.Put(bucket, creds) -} - func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) { if o.Credentials == nil { return nil, errors.New("s3express session credentials unset") @@ -121,3 +162,9 @@ func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) { Expires: *o.Credentials.Expiration, }, nil } + +func gethmac(p, key string) string { + hash := hmac.New(sha256.New, []byte(key)) + hash.Write([]byte(p)) + return string(hash.Sum(nil)) +} diff --git a/service/s3/express_resolve.go b/service/s3/express_resolve.go index 2c357a2df30..18d6c06ada0 100644 --- a/service/s3/express_resolve.go +++ b/service/s3/express_resolve.go @@ -13,11 +13,26 @@ func resolveExpressCredentials(o *Options) { } } -// Config finalizer: if we're using the default S3Express implementation, -// grab a reference to the client for its CreateSession API. +// Config finalizer: if we're using the default S3Express implementation, grab +// a reference to the client for its CreateSession API, and the underlying +// sigv4 credentials provider for cache keying. func finalizeExpressCredentials(o *Options, c *Client) { if p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider); ok { p.client = c + p.v4creds = o.Credentials + } +} + +// Operation config finalizer: update the sigv4 credentials on the default +// express provider if it changed to ensure different cache keys +func finalizeOperationExpressCredentials(o *Options, c Client) { + p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider) + if !ok { + return + } + + if c.options.Credentials != o.Credentials { + o.ExpressCredentials = p.CloneWithBaseCredentials(o.Credentials) } } diff --git a/service/s3/express_test.go b/service/s3/express_test.go index 1a2f95bef75..733e27a9ad0 100644 --- a/service/s3/express_test.go +++ b/service/s3/express_test.go @@ -2,6 +2,7 @@ package s3 import ( "context" + "net/http" "sync" "testing" "time" @@ -9,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/internal/sdk" "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/smithy-go/middleware" ) type mockCreateSession struct { @@ -36,6 +38,22 @@ func (m *mockCreateSession) CreateSession(context.Context, *CreateSessionInput, return o.output, o.err } +type mockCreds struct { + akid, secret, session string +} + +func newMockCreds(akid, secret, session string) *mockCreds { + return &mockCreds{akid: akid, secret: secret, session: session} +} + +func (m *mockCreds) Retrieve(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: m.akid, + SecretAccessKey: m.secret, + SessionToken: m.session, + }, nil +} + func TestS3Express_Retrieve(t *testing.T) { sdk.NowTime = func() time.Time { return time.Unix(0, 0) @@ -67,6 +85,7 @@ func TestS3Express_Retrieve(t *testing.T) { c := newDefaultS3ExpressCredentialsProvider() c.client = mockClient + c.v4creds = newMockCreds("AKID", "SECRET", "SESSION") mockClient.wg.Add(3) c0, err := c.Retrieve(context.Background(), "bucket-0") @@ -141,6 +160,7 @@ func TestS3Express_AsyncRefresh(t *testing.T) { c := newDefaultS3ExpressCredentialsProvider() c.client = mockClient + c.v4creds = newMockCreds("AKID", "SECRET", "SESSION") mockClient.wg.Add(2) c0, err := c.Retrieve(context.Background(), "bucket-0") @@ -172,3 +192,110 @@ func TestS3Express_AsyncRefresh(t *testing.T) { t.Errorf("expected credentials %v, got %v", expected, c1) } } + +type mockHTTP struct{} + +func (*mockHTTP) Do(*http.Request) (*http.Response, error) { + return &http.Response{}, nil +} + +func TestS3Express_OperationCredentialOverride(t *testing.T) { + sdk.NowTime = func() time.Time { + return time.Unix(0, 0) + } + + createSessionClient := &mockCreateSession{ + calls: []mockCreateSessionCall{ + { + output: &CreateSessionOutput{ + Credentials: &types.SessionCredentials{ + AccessKeyId: aws.String("EXPRESS_AKID0"), + SecretAccessKey: aws.String("EXPRESS_SECRET0"), + SessionToken: aws.String("EXPRESS_TOKEN0"), + Expiration: aws.Time(time.Unix(3600, 0).UTC()), + }, + }, + }, + { + output: &CreateSessionOutput{ + Credentials: &types.SessionCredentials{ + AccessKeyId: aws.String("EXPRESS_AKID1"), + SecretAccessKey: aws.String("EXPRESS_SECRET1"), + SessionToken: aws.String("EXPRESS_TOKEN1"), + Expiration: aws.Time(time.Unix(3600, 0).UTC()), + }, + }, + }, + }, + } + createSessionClient.wg.Add(2) + + svc := New(Options{ + Region: "us-west-2", + Credentials: newMockCreds("AKID0", "SECRET0", "SESSION0"), + HTTPClient: &mockHTTP{}, + APIOptions: []func(*middleware.Stack) error{ + func(stack *middleware.Stack) error { + stack.Deserialize.Clear() + return stack.Deserialize.Add( + middleware.DeserializeMiddlewareFunc( + "mockResponse", + func(context.Context, middleware.DeserializeInput, middleware.DeserializeHandler) (middleware.DeserializeOutput, middleware.Metadata, error) { + out := middleware.DeserializeOutput{ + Result: &GetObjectOutput{}, + } + return out, middleware.Metadata{}, nil + }, + ), + middleware.After, + ) + }, + }, + }) + + expressProvider, _ := svc.options.ExpressCredentials.(*defaultS3ExpressCredentialsProvider) + expressProvider.client = createSessionClient + + _, err := svc.GetObject(context.Background(), &GetObjectInput{ + Bucket: aws.String("bucket--usw2-az1--x-s3"), + Key: aws.String("key"), + }) + if err != nil { + t.Errorf("get object: %v", err) + } + + // there should be one set of credentials in the cache + key0 := cacheKey{ + CredentialsHash: gethmac("AKID0", "SECRET0"), + Bucket: "bucket--usw2-az1--x-s3", + } + _, ok := expressProvider.cache.Get(key0) + if !ok { + t.Errorf("creds for AKID0/SECRET0 are missing") + } + + _, err = svc.GetObject(context.Background(), &GetObjectInput{ + Bucket: aws.String("bucket--usw2-az1--x-s3"), + Key: aws.String("key"), + }, func(o *Options) { + o.Credentials = newMockCreds("AKID1", "SECRET1", "SESSION1") + }) + if err != nil { + t.Errorf("get object: %v", err) + } + + // checking two things here: + // - we have a new cache entry since creds changed + // - note we're still using the original pointer, the operation finalizer + // should have copied it and passed the cache along + key1 := cacheKey{ + CredentialsHash: gethmac("AKID1", "SECRET1"), + Bucket: "bucket--usw2-az1--x-s3", + } + _, ok = expressProvider.cache.Get(key1) + if !ok { + t.Errorf("creds for AKID1/SECRET1 are missing") + } + + createSessionClient.expectCalled(t, 2) +}