Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use proper HTTP client for fetching credentials #2041

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api-presigned.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (c *Client) PresignedPostPolicy(ctx context.Context, p *PostPolicy) (u *url
}

// Get credentials from the configured credentials provider.
credValues, err := c.credsProvider.Get()
credValues, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, nil, err
}
Expand Down
19 changes: 15 additions & 4 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,9 @@ func (c *Client) executeMethod(ctx context.Context, method string, metadata requ
return nil, errors.New(c.endpointURL.String() + " is offline.")
}

var retryable bool // Indicates if request can be retried.
var bodySeeker io.Seeker // Extracted seeker from io.Reader.
var reqRetry = c.maxRetries // Indicates how many times we can retry the request
var retryable bool // Indicates if request can be retried.
var bodySeeker io.Seeker // Extracted seeker from io.Reader.
reqRetry := c.maxRetries // Indicates how many times we can retry the request
Comment on lines -603 to +605
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was changed by my IDE to follow coding guidelines.


if metadata.contentBody != nil {
// Check if body is seekable then it is retryable.
Expand Down Expand Up @@ -808,7 +808,7 @@ func (c *Client) newRequest(ctx context.Context, method string, metadata request
}

// Get credentials from the configured credentials provider.
value, err := c.credsProvider.Get()
value, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1018,3 +1018,14 @@ func (c *Client) isVirtualHostStyleRequest(url url.URL, bucketName string) bool
// path style requests
return s3utils.IsVirtualHostSupported(url, bucketName)
}

// CredContext returns the context for fetching credentials
func (c *Client) CredContext() *credentials.CredContext {
httpClient := c.httpClient
if httpClient == nil {
httpClient = http.DefaultClient
}
return &credentials.CredContext{
Client: httpClient,
}
}
2 changes: 1 addition & 1 deletion bucket-cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (c *Client) getBucketLocationRequest(ctx context.Context, bucketName string
c.setUserAgent(req)

// Get credentials from the configured credentials provider.
value, err := c.credsProvider.Get()
value, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion bucket-cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestGetBucketLocationRequest(t *testing.T) {
c.setUserAgent(req)

// Get credentials from the configured credentials provider.
value, err := c.credsProvider.Get()
value, err := c.credsProvider.GetWithContext(c.CredContext())
if err != nil {
return nil, err
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/credentials/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ type AssumeRoleResult struct {
type STSAssumeRole struct {
Expiry

// Required http Client to use when connecting to MinIO STS service.
// Optional http Client to use when connecting to MinIO STS service
// (overrides default client in CredContext)
Client *http.Client

// STS endpoint to fetch STS credentials.
Expand Down Expand Up @@ -115,9 +116,6 @@ func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentia
return nil, errors.New("AssumeRole credentials access/secretkey is mandatory")
}
return New(&STSAssumeRole{
Client: &http.Client{
Transport: http.DefaultTransport,
},
STSEndpoint: stsEndpoint,
Options: opts,
}), nil
Expand Down Expand Up @@ -224,8 +222,12 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve() (Value, error) {
a, err := getAssumeRoleCredentials(m.Client, m.STSEndpoint, m.Options)
func (m *STSAssumeRole) Retrieve(cc *CredContext) (Value, error) {
client := m.Client
if client == nil {
client = cc.Client
}
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
if err != nil {
return Value{}, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/credentials/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func NewChainCredentials(providers []Provider) *Credentials {
//
// If a provider is found with credentials, it will be cached and any calls
// to IsExpired() will return the expired state of the cached provider.
func (c *Chain) Retrieve() (Value, error) {
func (c *Chain) Retrieve(cc *CredContext) (Value, error) {
for _, p := range c.Providers {
creds, _ := p.Retrieve()
creds, _ := p.Retrieve(cc)
// Always prioritize non-anonymous providers, if any.
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
continue
Expand Down
10 changes: 5 additions & 5 deletions pkg/credentials/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type testCredProvider struct {
err error
}

func (s *testCredProvider) Retrieve() (Value, error) {
func (s *testCredProvider) Retrieve(_ *CredContext) (Value, error) {
s.expired = false
return s.creds, s.err
}
Expand Down Expand Up @@ -59,7 +59,7 @@ func TestChainGet(t *testing.T) {
},
}

creds, err := p.Retrieve()
creds, err := p.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func TestChainIsExpired(t *testing.T) {
t.Fatal("Expected expired to be true before any Retrieve")
}

_, err := p.Retrieve()
_, err := p.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -112,7 +112,7 @@ func TestChainWithNoProvider(t *testing.T) {
if !p.IsExpired() {
t.Fatal("Expected to be expired with no providers")
}
_, err := p.Retrieve()
_, err := p.Retrieve(defaultCredContext)
if err != nil {
if err.Error() != "No valid providers found []" {
t.Error(err)
Expand All @@ -136,7 +136,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
t.Fatal("Expected to be expired with no providers")
}

_, err := p.Retrieve()
_, err := p.Retrieve(defaultCredContext)
if err != nil {
if err.Error() != "No valid providers found [FirstError SecondError]" {
t.Error(err)
Expand Down
34 changes: 32 additions & 2 deletions pkg/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package credentials

import (
"net/http"
"sync"
"time"
)
Expand All @@ -30,6 +31,10 @@ const (
defaultExpiryWindow = 0.8
)

// defaultCredContext is used when the credential context doesn't
// actually matter or the default context is suitable.
var defaultCredContext = &CredContext{Client: http.DefaultClient}

// A Value is the S3 credentials value for individual credential fields.
type Value struct {
// S3 Access key ID
Expand All @@ -54,13 +59,21 @@ type Value struct {
type Provider interface {
// Retrieve returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
Retrieve() (Value, error)
Retrieve(cc *CredContext) (Value, error)

// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
IsExpired() bool
}

// CredContext is passed to the Retrieve function of a provider to provide
// some additional context to retrieve credentials.
type CredContext struct {
// Client specifies the HTTP client that should be used if an HTTP
// request is to be made to fetch the credentials.
Client *http.Client
}

// A Expiry provides shared expiration logic to be used by credentials
// providers to implement expiry functionality.
//
Expand Down Expand Up @@ -146,7 +159,24 @@ func New(provider Provider) *Credentials {
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
//
// Deprecated: Get() exists for historical compatibility and should not be
// used. To get new credentials use the Credentials.GetWithContext function
// to ensure the proper context (i.e. HTTP client) will be used.
func (c *Credentials) Get() (Value, error) {
return c.GetWithContext(defaultCredContext)
}

// GetWithContext returns the credentials value, or error if the
// credentials Value failed to be retrieved.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) {
if c == nil {
return Value{}, nil
}
Expand All @@ -155,7 +185,7 @@ func (c *Credentials) Get() (Value, error) {
defer c.Unlock()

if c.isExpired() {
creds, err := c.provider.Retrieve()
creds, err := c.provider.Retrieve(cc)
if err != nil {
return Value{}, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type credProvider struct {
err error
}

func (s *credProvider) Retrieve() (Value, error) {
func (s *credProvider) Retrieve(_ *CredContext) (Value, error) {
s.expired = false
return s.creds, s.err
}
Expand All @@ -47,7 +47,7 @@ func TestCredentialsGet(t *testing.T) {
expired: true,
})

creds, err := c.Get()
creds, err := c.GetWithContext(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -65,7 +65,7 @@ func TestCredentialsGet(t *testing.T) {
func TestCredentialsGetWithError(t *testing.T) {
c := New(&credProvider{err: errors.New("Custom error")})

_, err := c.Get()
_, err := c.GetWithContext(defaultCredContext)
if err != nil {
if err.Error() != "Custom error" {
t.Errorf("Expected \"Custom error\", got %s", err.Error())
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/env_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func NewEnvAWS() *Credentials {
}

// Retrieve retrieves the keys from the environment.
func (e *EnvAWS) Retrieve() (Value, error) {
func (e *EnvAWS) Retrieve(_ *CredContext) (Value, error) {
e.retrieved = false

id := os.Getenv("AWS_ACCESS_KEY_ID")
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/env_minio.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func NewEnvMinio() *Credentials {
}

// Retrieve retrieves the keys from the environment.
func (e *EnvMinio) Retrieve() (Value, error) {
func (e *EnvMinio) Retrieve(_ *CredContext) (Value, error) {
e.retrieved = false

id := os.Getenv("MINIO_ROOT_USER")
Expand Down
6 changes: 3 additions & 3 deletions pkg/credentials/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
t.Error("Expect creds to be expired before retrieve.")
}

creds, err := e.Retrieve()
creds, err := e.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestEnvAWSRetrieve(t *testing.T) {
SignerType: SignatureV4,
}

creds, err = e.Retrieve()
creds, err = e.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand All @@ -84,7 +84,7 @@ func TestEnvMinioRetrieve(t *testing.T) {
t.Error("Expect creds to be expired before retrieve.")
}

creds, err := e.Retrieve()
creds, err := e.Retrieve(defaultCredContext)
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/file_aws_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewFileAWSCredentials(filename, profile string) *Credentials {

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileAWSCredentials) Retrieve() (Value, error) {
func (p *FileAWSCredentials) Retrieve(_ *CredContext) (Value, error) {
if p.Filename == "" {
p.Filename = os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
if p.Filename == "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/file_minio_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func NewFileMinioClient(filename, alias string) *Credentials {

// Retrieve reads and extracts the shared credentials from the current
// users home directory.
func (p *FileMinioClient) Retrieve() (Value, error) {
func (p *FileMinioClient) Retrieve(_ *CredContext) (Value, error) {
if p.Filename == "" {
if value, ok := os.LookupEnv("MINIO_SHARED_CREDENTIALS_FILE"); ok {
p.Filename = value
Expand Down
Loading
Loading