Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: include underlying credentials in s3express credentials cache keys #2414

Merged
merged 1 commit into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/8e6a01197da848c88aaab5adb296abc1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "8e6a0119-7da8-48c8-8aaa-b5adb296abc1",
"type": "bugfix",
"description": "Improve uniqueness of default S3Express sesssion credentials cache keying to prevent collision in multi-credential scenarios.",
"modules": [
"service/s3"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SymbolUtils.buildPackageSymbol;

public class S3ExpressAuthScheme implements GoIntegration {
private static final ConfigField s3ExpressCredentials =
Expand All @@ -67,6 +68,14 @@ public class S3ExpressAuthScheme implements GoIntegration {
.withClientInput(true)
.build();

private static final ConfigFieldResolver s3ExpressCredentialsOperationFinalizer =
ConfigFieldResolver.builder()
.location(ConfigFieldResolver.Location.OPERATION)
.target(ConfigFieldResolver.Target.FINALIZATION)
.resolver(buildPackageSymbol("finalizeOperationExpressCredentials"))
.withClientInput(true)
.build();

@Override
public void writeAdditionalFiles(
GoSettings settings, Model model, SymbolProvider symbolProvider, GoDelegator goDelegator
Expand All @@ -84,6 +93,7 @@ public List<RuntimeClientPlugin> getClientPlugins() {
.addConfigField(s3ExpressCredentials)
.addConfigFieldResolver(s3ExpressCredentialsResolver)
.addConfigFieldResolver(s3ExpressCredentialsClientFinalizer)
.addConfigFieldResolver(s3ExpressCredentialsOperationFinalizer)
.addAuthSchemeDefinition(SigV4S3ExpressTrait.ID, new SigV4S3Express())
.build()
);
Expand Down
2 changes: 2 additions & 0 deletions service/s3/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

103 changes: 75 additions & 28 deletions service/s3/express_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package s3

import (
"context"
"crypto/hmac"
"crypto/sha256"
"errors"
"fmt"
"sync"
"time"

Expand All @@ -17,18 +20,49 @@ const s3ExpressCacheCap = 100

const s3ExpressRefreshWindow = 1 * time.Minute

type cacheKey struct {
CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
Bucket string
}

func (c cacheKey) Slug() string {
return fmt.Sprintf("%s%s", c.CredentialsHash, c.Bucket)
}

type sessionCredsCache struct {
mu sync.Mutex
cache cache.Cache
}

func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) {
c.mu.Lock()
defer c.mu.Unlock()

if v, ok := c.cache.Get(key); ok {
return v.(*aws.Credentials), true
}
return nil, false
}

func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) {
c.mu.Lock()
defer c.mu.Unlock()

c.cache.Put(key, creds)
}

// The default S3Express provider uses an LRU cache with a capacity of 100.
//
// Credentials will be refreshed asynchronously when a Retrieve() call is made
// for cached credentials within an expiry window (1 minute, currently
// non-configurable).
type defaultS3ExpressCredentialsProvider struct {
mu sync.Mutex
sf singleflight.Group

client createSessionAPIClient
credsCache cache.Cache
cache *sessionCredsCache
refreshWindow time.Duration
v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
}

type createSessionAPIClient interface {
Expand All @@ -37,35 +71,54 @@ type createSessionAPIClient interface {

func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider {
return &defaultS3ExpressCredentialsProvider{
credsCache: lru.New(s3ExpressCacheCap),
cache: &sessionCredsCache{
cache: lru.New(s3ExpressCacheCap),
},
refreshWindow: s3ExpressRefreshWindow,
}
}

// returns a cloned provider using new base credentials, used when per-op
// config mutations change the credentials provider
func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider {
return &defaultS3ExpressCredentialsProvider{
client: p.client,
cache: p.cache,
refreshWindow: p.refreshWindow,
v4creds: v4creds,
}
}

func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
p.mu.Lock()
defer p.mu.Unlock()
v4creds, err := p.v4creds.Retrieve(ctx)
if err != nil {
return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err)
}

creds, ok := p.getCacheCredentials(bucket)
key := cacheKey{
CredentialsHash: gethmac(v4creds.AccessKeyID, v4creds.SecretAccessKey),
Bucket: bucket,
}
creds, ok := p.cache.Get(key)
if !ok || creds.Expired() {
return p.awaitDoChanRetrieve(ctx, bucket)
return p.awaitDoChanRetrieve(ctx, key)
}

if creds.Expires.Sub(sdk.NowTime()) <= p.refreshWindow {
p.doChanRetrieve(ctx, bucket)
p.doChanRetrieve(ctx, key)
}

return *creds, nil
}

func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, bucket string) <-chan singleflight.Result {
return p.sf.DoChan(bucket, func() (interface{}, error) {
return p.retrieve(ctx, bucket)
func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result {
return p.sf.DoChan(key.Slug(), func() (interface{}, error) {
return p.retrieve(ctx, key)
})
}

func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
ch := p.doChanRetrieve(ctx, bucket)
func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
ch := p.doChanRetrieve(ctx, key)

select {
case r := <-ch:
Expand All @@ -75,9 +128,9 @@ func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Co
}
}

func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
resp, err := p.client.CreateSession(ctx, &CreateSessionInput{
Bucket: aws.String(bucket),
Bucket: aws.String(key.Bucket),
})
if err != nil {
return aws.Credentials{}, err
Expand All @@ -88,22 +141,10 @@ func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, buck
return aws.Credentials{}, err
}

p.putCacheCredentials(bucket, creds)
p.cache.Put(key, creds)
return *creds, nil
}

func (p *defaultS3ExpressCredentialsProvider) getCacheCredentials(bucket string) (*aws.Credentials, bool) {
if v, ok := p.credsCache.Get(bucket); ok {
return v.(*aws.Credentials), true
}

return nil, false
}

func (p *defaultS3ExpressCredentialsProvider) putCacheCredentials(bucket string, creds *aws.Credentials) {
p.credsCache.Put(bucket, creds)
}

func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
if o.Credentials == nil {
return nil, errors.New("s3express session credentials unset")
Expand All @@ -121,3 +162,9 @@ func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
Expires: *o.Credentials.Expiration,
}, nil
}

func gethmac(p, key string) string {
hash := hmac.New(sha256.New, []byte(key))
hash.Write([]byte(p))
return string(hash.Sum(nil))
}
19 changes: 17 additions & 2 deletions service/s3/express_resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,26 @@ func resolveExpressCredentials(o *Options) {
}
}

// Config finalizer: if we're using the default S3Express implementation,
// grab a reference to the client for its CreateSession API.
// Config finalizer: if we're using the default S3Express implementation, grab
// a reference to the client for its CreateSession API, and the underlying
// sigv4 credentials provider for cache keying.
func finalizeExpressCredentials(o *Options, c *Client) {
if p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider); ok {
p.client = c
p.v4creds = o.Credentials
}
}

// Operation config finalizer: update the sigv4 credentials on the default
// express provider if it changed to ensure different cache keys
func finalizeOperationExpressCredentials(o *Options, c Client) {
p, ok := o.ExpressCredentials.(*defaultS3ExpressCredentialsProvider)
if !ok {
return
}

if c.options.Credentials != o.Credentials {
o.ExpressCredentials = p.CloneWithBaseCredentials(o.Credentials)
}
}

Expand Down
Loading
Loading