From 0abbe92812de3c9e89054cf6ea2d33cf9899b675 Mon Sep 17 00:00:00 2001 From: Ramon de Klein Date: Fri, 3 Jan 2025 13:48:08 +0100 Subject: [PATCH] Use default STS endpoint (#2044) --- api.go | 3 ++- pkg/credentials/assume_role.go | 33 ++++++++++++++++---------- pkg/credentials/chain.go | 2 +- pkg/credentials/credentials.go | 13 +++++++++- pkg/credentials/env_aws.go | 2 +- pkg/credentials/iam_aws.go | 21 ++++++++++------ pkg/credentials/sts_client_grants.go | 31 +++++++++++++++--------- pkg/credentials/sts_custom_identity.go | 27 ++++++++++++++------- pkg/credentials/sts_ldap_identity.go | 30 ++++++++++++++++------- pkg/credentials/sts_tls_identity.go | 33 +++++++++++++++----------- pkg/credentials/sts_web_identity.go | 30 ++++++++++++++--------- 11 files changed, 149 insertions(+), 76 deletions(-) diff --git a/api.go b/api.go index 5bcd903e3..cb46816d0 100644 --- a/api.go +++ b/api.go @@ -1026,6 +1026,7 @@ func (c *Client) CredContext() *credentials.CredContext { httpClient = http.DefaultClient } return &credentials.CredContext{ - Client: httpClient, + Client: httpClient, + Endpoint: c.endpointURL.String(), } } diff --git a/pkg/credentials/assume_role.go b/pkg/credentials/assume_role.go index 3d6715731..cd0a641bd 100644 --- a/pkg/credentials/assume_role.go +++ b/pkg/credentials/assume_role.go @@ -109,9 +109,6 @@ type STSAssumeRoleOptions struct { // NewSTSAssumeRole returns a pointer to a new // Credentials object wrapping the STSAssumeRole. func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentials, error) { - if stsEndpoint == "" { - return nil, errors.New("STS endpoint cannot be empty") - } if opts.AccessKey == "" || opts.SecretKey == "" { return nil, errors.New("AssumeRole credentials access/secretkey is mandatory") } @@ -220,12 +217,30 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume return a, nil } -func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) { +// RetrieveWithCredContext retrieves credentials from the MinIO service. +// Error will be returned if the request fails, optional cred context. +func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) { + if cc == nil { + cc = defaultCredContext + } + client := m.Client if client == nil { client = cc.Client } - a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options) + if client == nil { + client = defaultCredContext.Client + } + + stsEndpoint := m.STSEndpoint + if stsEndpoint == "" { + stsEndpoint = cc.Endpoint + } + if stsEndpoint == "" { + return Value{}, errors.New("STS endpoint unknown") + } + + a, err := getAssumeRoleCredentials(client, stsEndpoint, m.Options) if err != nil { return Value{}, err } @@ -242,14 +257,8 @@ func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) { }, nil } -// RetrieveWithCredContext retrieves credentials from the MinIO service. -// Error will be returned if the request fails, optional cred context. -func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) { - return m.retrieve(cc) -} - // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. func (m *STSAssumeRole) Retrieve() (Value, error) { - return m.retrieve(defaultCredContext) + return m.RetrieveWithCredContext(nil) } diff --git a/pkg/credentials/chain.go b/pkg/credentials/chain.go index 8e70f387c..5ef3597d1 100644 --- a/pkg/credentials/chain.go +++ b/pkg/credentials/chain.go @@ -80,7 +80,7 @@ func (c *Chain) RetrieveWithCredContext(cc *CredContext) (Value, error) { // to IsExpired() will return the expired state of the cached provider. func (c *Chain) Retrieve() (Value, error) { for _, p := range c.Providers { - creds, _ := p.RetrieveWithCredContext(defaultCredContext) + creds, _ := p.Retrieve() // Always prioritize non-anonymous providers, if any. if creds.AccessKeyID == "" && creds.SecretAccessKey == "" { continue diff --git a/pkg/credentials/credentials.go b/pkg/credentials/credentials.go index 6d0fe9a7a..52aff9a57 100644 --- a/pkg/credentials/credentials.go +++ b/pkg/credentials/credentials.go @@ -64,6 +64,10 @@ type Provider interface { // Retrieve returns nil if it successfully retrieved the value. // Error is returned if the value were not obtainable, or empty. + // + // Deprecated: Retrieve() exists for historical compatibility and should not + // be used. To get new credentials use the RetrieveWithCredContext function + // to ensure the proper context (i.e. HTTP client) will be used. Retrieve() (Value, error) // IsExpired returns if the credentials are no longer valid, and need @@ -77,6 +81,10 @@ type CredContext struct { // Client specifies the HTTP client that should be used if an HTTP // request is to be made to fetch the credentials. Client *http.Client + + // Endpoint specifies the MinIO endpoint that will be used if no + // explicit endpoint is provided. + Endpoint string } // A Expiry provides shared expiration logic to be used by credentials @@ -169,7 +177,7 @@ func New(provider Provider) *Credentials { // used. To get new credentials use the Credentials.GetWithContext function // to ensure the proper context (i.e. HTTP client) will be used. func (c *Credentials) Get() (Value, error) { - return c.GetWithContext(defaultCredContext) + return c.GetWithContext(nil) } // GetWithContext returns the credentials value, or error if the @@ -185,6 +193,9 @@ func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) { if c == nil { return Value{}, nil } + if cc == nil { + cc = defaultCredContext + } c.Lock() defer c.Unlock() diff --git a/pkg/credentials/env_aws.go b/pkg/credentials/env_aws.go index dd0d5ec81..21ab0a38a 100644 --- a/pkg/credentials/env_aws.go +++ b/pkg/credentials/env_aws.go @@ -69,7 +69,7 @@ func (e *EnvAWS) Retrieve() (Value, error) { return e.retrieve() } -// RetrieveWithContext is like Retrieve (no-op input of Cred Context) +// RetrieveWithCredContext is like Retrieve (no-op input of Cred Context) func (e *EnvAWS) RetrieveWithCredContext(_ *CredContext) (Value, error) { return e.retrieve() } diff --git a/pkg/credentials/iam_aws.go b/pkg/credentials/iam_aws.go index 717f2c03b..0ba06e710 100644 --- a/pkg/credentials/iam_aws.go +++ b/pkg/credentials/iam_aws.go @@ -95,7 +95,12 @@ func NewIAM(endpoint string) *Credentials { }) } -func (m *IAM) retrieve(cc *CredContext) (Value, error) { +// RetrieveWithCredContext is like Retrieve with Cred Context +func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) { + if cc == nil { + cc = defaultCredContext + } + token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN") if token == "" { token = m.Container.AuthorizationToken @@ -143,8 +148,15 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) { if client == nil { client = cc.Client } + if client == nil { + client = defaultCredContext.Client + } endpoint := m.Endpoint + if endpoint == "" { + endpoint = cc.Endpoint + } + switch { case identityFile != "": if len(endpoint) == 0 { @@ -228,12 +240,7 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) { // Error will be returned if the request fails, or unable to extract // the desired func (m *IAM) Retrieve() (Value, error) { - return m.retrieve(defaultCredContext) -} - -// RetrieveWithCredContext is like Retrieve with Cred Context -func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) { - return m.retrieve(cc) + return m.RetrieveWithCredContext(nil) } // A ec2RoleCredRespBody provides the shape for unmarshaling credential diff --git a/pkg/credentials/sts_client_grants.go b/pkg/credentials/sts_client_grants.go index 78dd5c129..ef6f436b8 100644 --- a/pkg/credentials/sts_client_grants.go +++ b/pkg/credentials/sts_client_grants.go @@ -91,9 +91,6 @@ type STSClientGrants struct { // NewSTSClientGrants returns a pointer to a new // Credentials object wrapping the STSClientGrants. func NewSTSClientGrants(stsEndpoint string, getClientGrantsTokenExpiry func() (*ClientGrantsToken, error)) (*Credentials, error) { - if stsEndpoint == "" { - return nil, errors.New("STS endpoint cannot be empty") - } if getClientGrantsTokenExpiry == nil { return nil, errors.New("Client grants access token and expiry retrieval function should be defined") } @@ -160,12 +157,29 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string, return a, nil } -func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) { +// RetrieveWithCredContext is like Retrieve() with cred context +func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) { + if cc == nil { + cc = defaultCredContext + } + client := m.Client if client == nil { client = cc.Client } - a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry) + if client == nil { + client = defaultCredContext.Client + } + + stsEndpoint := m.STSEndpoint + if stsEndpoint == "" { + stsEndpoint = cc.Endpoint + } + if stsEndpoint == "" { + return Value{}, errors.New("STS endpoint unknown") + } + + a, err := getClientGrantsCredentials(client, stsEndpoint, m.GetClientGrantsTokenExpiry) if err != nil { return Value{}, err } @@ -182,13 +196,8 @@ func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) { }, nil } -// RetrieveWithCredContext is like Retrieve() with cred context -func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) { - return m.retrieve(cc) -} - // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. func (m *STSClientGrants) Retrieve() (Value, error) { - return m.retrieve(defaultCredContext) + return m.RetrieveWithCredContext(nil) } diff --git a/pkg/credentials/sts_custom_identity.go b/pkg/credentials/sts_custom_identity.go index 8c0ce1284..0021f9315 100644 --- a/pkg/credentials/sts_custom_identity.go +++ b/pkg/credentials/sts_custom_identity.go @@ -71,8 +71,21 @@ type CustomTokenIdentity struct { RequestedExpiry time.Duration } -func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error) { - u, err := url.Parse(c.STSEndpoint) +// RetrieveWithCredContext with Retrieve optionally cred context +func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) { + if cc == nil { + cc = defaultCredContext + } + + stsEndpoint := c.STSEndpoint + if stsEndpoint == "" { + stsEndpoint = cc.Endpoint + } + if stsEndpoint == "" { + return Value{}, errors.New("STS endpoint unknown") + } + + u, err := url.Parse(stsEndpoint) if err != nil { return value, err } @@ -97,6 +110,9 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error) if client == nil { client = cc.Client } + if client == nil { + client = defaultCredContext.Client + } resp, err := client.Do(req) if err != nil { @@ -126,12 +142,7 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error) // Retrieve - to satisfy Provider interface; fetches credentials from MinIO. func (c *CustomTokenIdentity) Retrieve() (value Value, err error) { - return c.retrieve(defaultCredContext) -} - -// RetrieveWithCredContext with Retrieve optionally cred context -func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) { - return c.retrieve(cc) + return c.RetrieveWithCredContext(nil) } // NewCustomTokenCredentials - returns credentials using the diff --git a/pkg/credentials/sts_ldap_identity.go b/pkg/credentials/sts_ldap_identity.go index 8543cc2e8..e63997e6e 100644 --- a/pkg/credentials/sts_ldap_identity.go +++ b/pkg/credentials/sts_ldap_identity.go @@ -20,6 +20,7 @@ package credentials import ( "bytes" "encoding/xml" + "errors" "fmt" "io" "net/http" @@ -120,8 +121,22 @@ func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, p }), nil } -func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) { - u, err := url.Parse(k.STSEndpoint) +// RetrieveWithCredContext gets the credential by calling the MinIO STS API for +// LDAP on the configured stsEndpoint. +func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) { + if cc == nil { + cc = defaultCredContext + } + + stsEndpoint := k.STSEndpoint + if stsEndpoint == "" { + stsEndpoint = cc.Endpoint + } + if stsEndpoint == "" { + return Value{}, errors.New("STS endpoint unknown") + } + + u, err := url.Parse(stsEndpoint) if err != nil { return value, err } @@ -149,6 +164,9 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) { if client == nil { client = cc.Client } + if client == nil { + client = defaultCredContext.Client + } resp, err := client.Do(req) if err != nil { @@ -194,11 +212,5 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) { // Retrieve gets the credential by calling the MinIO STS API for // LDAP on the configured stsEndpoint. func (k *LDAPIdentity) Retrieve() (value Value, err error) { - return k.retrieve(defaultCredContext) -} - -// RetrieveWithCredContext gets the credential by calling the MinIO STS API for -// LDAP on the configured stsEndpoint. -func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) { - return k.retrieve(cc) + return k.RetrieveWithCredContext(defaultCredContext) } diff --git a/pkg/credentials/sts_tls_identity.go b/pkg/credentials/sts_tls_identity.go index d7ab25d34..c904bbeac 100644 --- a/pkg/credentials/sts_tls_identity.go +++ b/pkg/credentials/sts_tls_identity.go @@ -86,12 +86,6 @@ type STSCertificateIdentity struct { // to the given STS endpoint with the given TLS certificate and retrieves and // rotates S3 credentials. func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, options ...CertificateIdentityOption) (*Credentials, error) { - if endpoint == "" { - return nil, errors.New("STS endpoint cannot be empty") - } - if _, err := url.Parse(endpoint); err != nil { - return nil, err - } identity := &STSCertificateIdentity{ STSEndpoint: endpoint, Certificate: certificate, @@ -102,8 +96,21 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt return New(identity), nil } -func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) { - endpointURL, err := url.Parse(i.STSEndpoint) +// RetrieveWithCredContext is Retrieve with cred context +func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) { + if cc == nil { + cc = defaultCredContext + } + + stsEndpoint := i.STSEndpoint + if stsEndpoint == "" { + stsEndpoint = cc.Endpoint + } + if stsEndpoint == "" { + return Value{}, errors.New("STS endpoint unknown") + } + + endpointURL, err := url.Parse(stsEndpoint) if err != nil { return Value{}, err } @@ -130,6 +137,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) { if client == nil { client = cc.Client } + if client == nil { + client = defaultCredContext.Client + } tr, ok := client.Transport.(*http.Transport) if !ok { @@ -192,14 +202,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) { }, nil } -// RetrieveWithCredContext is Retrieve with cred context -func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) { - return i.retrieve(cc) -} - // Retrieve fetches a new set of S3 credentials from the configured STS API endpoint. func (i *STSCertificateIdentity) Retrieve() (Value, error) { - return i.retrieve(defaultCredContext) + return i.RetrieveWithCredContext(defaultCredContext) } // Expiration returns the expiration time of the current S3 credentials. diff --git a/pkg/credentials/sts_web_identity.go b/pkg/credentials/sts_web_identity.go index 9a84837b6..235258893 100644 --- a/pkg/credentials/sts_web_identity.go +++ b/pkg/credentials/sts_web_identity.go @@ -98,9 +98,6 @@ type STSWebIdentity struct { // NewSTSWebIdentity returns a pointer to a new // Credentials object wrapping the STSWebIdentity. func NewSTSWebIdentity(stsEndpoint string, getWebIDTokenExpiry func() (*WebIdentityToken, error), opts ...func(*STSWebIdentity)) (*Credentials, error) { - if stsEndpoint == "" { - return nil, errors.New("STS endpoint cannot be empty") - } if getWebIDTokenExpiry == nil { return nil, errors.New("Web ID token and expiry retrieval function should be defined") } @@ -217,13 +214,29 @@ func getWebIdentityCredentials(clnt *http.Client, endpoint, roleARN, roleSession return a, nil } -func (m *STSWebIdentity) retrieve(cc *CredContext) (Value, error) { +// RetrieveWithCredContext is like Retrieve with optional cred context. +func (m *STSWebIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) { + if cc == nil { + cc = defaultCredContext + } + client := m.Client if client == nil { client = cc.Client } + if client == nil { + client = defaultCredContext.Client + } - a, err := getWebIdentityCredentials(client, m.STSEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) + stsEndpoint := m.STSEndpoint + if stsEndpoint == "" { + stsEndpoint = cc.Endpoint + } + if stsEndpoint == "" { + return Value{}, errors.New("STS endpoint unknown") + } + + a, err := getWebIdentityCredentials(client, stsEndpoint, m.RoleARN, m.roleSessionName, m.Policy, m.GetWebIDTokenExpiry) if err != nil { return Value{}, err } @@ -243,12 +256,7 @@ func (m *STSWebIdentity) retrieve(cc *CredContext) (Value, error) { // Retrieve retrieves credentials from the MinIO service. // Error will be returned if the request fails. func (m *STSWebIdentity) Retrieve() (Value, error) { - return m.retrieve(defaultCredContext) -} - -// RetrieveWithCredContext is like Retrieve with optional cred context. -func (m *STSWebIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) { - return m.retrieve(cc) + return m.RetrieveWithCredContext(nil) } // Expiration returns the expiration time of the credentials