Skip to content

Commit

Permalink
Use default STS endpoint (#2044)
Browse files Browse the repository at this point in the history
  • Loading branch information
ramondeklein authored Jan 3, 2025
1 parent 5757f2c commit 0abbe92
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 76 deletions.
3 changes: 2 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ func (c *Client) CredContext() *credentials.CredContext {
httpClient = http.DefaultClient
}
return &credentials.CredContext{
Client: httpClient,
Client: httpClient,
Endpoint: c.endpointURL.String(),
}
}
33 changes: 21 additions & 12 deletions pkg/credentials/assume_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ type STSAssumeRoleOptions struct {
// NewSTSAssumeRole returns a pointer to a new
// Credentials object wrapping the STSAssumeRole.
func NewSTSAssumeRole(stsEndpoint string, opts STSAssumeRoleOptions) (*Credentials, error) {
if stsEndpoint == "" {
return nil, errors.New("STS endpoint cannot be empty")
}
if opts.AccessKey == "" || opts.SecretKey == "" {
return nil, errors.New("AssumeRole credentials access/secretkey is mandatory")
}
Expand Down Expand Up @@ -220,12 +217,30 @@ func getAssumeRoleCredentials(clnt *http.Client, endpoint string, opts STSAssume
return a, nil
}

func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) {
// RetrieveWithCredContext retrieves credentials from the MinIO service.
// Error will be returned if the request fails, optional cred context.
func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

client := m.Client
if client == nil {
client = cc.Client
}
a, err := getAssumeRoleCredentials(client, m.STSEndpoint, m.Options)
if client == nil {
client = defaultCredContext.Client
}

stsEndpoint := m.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

a, err := getAssumeRoleCredentials(client, stsEndpoint, m.Options)
if err != nil {
return Value{}, err
}
Expand All @@ -242,14 +257,8 @@ func (m *STSAssumeRole) retrieve(cc *CredContext) (Value, error) {
}, nil
}

// RetrieveWithCredContext retrieves credentials from the MinIO service.
// Error will be returned if the request fails, optional cred context.
func (m *STSAssumeRole) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
}

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSAssumeRole) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
return m.RetrieveWithCredContext(nil)
}
2 changes: 1 addition & 1 deletion pkg/credentials/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *Chain) RetrieveWithCredContext(cc *CredContext) (Value, error) {
// to IsExpired() will return the expired state of the cached provider.
func (c *Chain) Retrieve() (Value, error) {
for _, p := range c.Providers {
creds, _ := p.RetrieveWithCredContext(defaultCredContext)
creds, _ := p.Retrieve()
// Always prioritize non-anonymous providers, if any.
if creds.AccessKeyID == "" && creds.SecretAccessKey == "" {
continue
Expand Down
13 changes: 12 additions & 1 deletion pkg/credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ type Provider interface {

// Retrieve returns nil if it successfully retrieved the value.
// Error is returned if the value were not obtainable, or empty.
//
// Deprecated: Retrieve() exists for historical compatibility and should not
// be used. To get new credentials use the RetrieveWithCredContext function
// to ensure the proper context (i.e. HTTP client) will be used.
Retrieve() (Value, error)

// IsExpired returns if the credentials are no longer valid, and need
Expand All @@ -77,6 +81,10 @@ type CredContext struct {
// Client specifies the HTTP client that should be used if an HTTP
// request is to be made to fetch the credentials.
Client *http.Client

// Endpoint specifies the MinIO endpoint that will be used if no
// explicit endpoint is provided.
Endpoint string
}

// A Expiry provides shared expiration logic to be used by credentials
Expand Down Expand Up @@ -169,7 +177,7 @@ func New(provider Provider) *Credentials {
// used. To get new credentials use the Credentials.GetWithContext function
// to ensure the proper context (i.e. HTTP client) will be used.
func (c *Credentials) Get() (Value, error) {
return c.GetWithContext(defaultCredContext)
return c.GetWithContext(nil)
}

// GetWithContext returns the credentials value, or error if the
Expand All @@ -185,6 +193,9 @@ func (c *Credentials) GetWithContext(cc *CredContext) (Value, error) {
if c == nil {
return Value{}, nil
}
if cc == nil {
cc = defaultCredContext
}

c.Lock()
defer c.Unlock()
Expand Down
2 changes: 1 addition & 1 deletion pkg/credentials/env_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (e *EnvAWS) Retrieve() (Value, error) {
return e.retrieve()
}

// RetrieveWithContext is like Retrieve (no-op input of Cred Context)
// RetrieveWithCredContext is like Retrieve (no-op input of Cred Context)
func (e *EnvAWS) RetrieveWithCredContext(_ *CredContext) (Value, error) {
return e.retrieve()
}
Expand Down
21 changes: 14 additions & 7 deletions pkg/credentials/iam_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ func NewIAM(endpoint string) *Credentials {
})
}

func (m *IAM) retrieve(cc *CredContext) (Value, error) {
// RetrieveWithCredContext is like Retrieve with Cred Context
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN")
if token == "" {
token = m.Container.AuthorizationToken
Expand Down Expand Up @@ -143,8 +148,15 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) {
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

endpoint := m.Endpoint
if endpoint == "" {
endpoint = cc.Endpoint
}

switch {
case identityFile != "":
if len(endpoint) == 0 {
Expand Down Expand Up @@ -228,12 +240,7 @@ func (m *IAM) retrieve(cc *CredContext) (Value, error) {
// Error will be returned if the request fails, or unable to extract
// the desired
func (m *IAM) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
}

// RetrieveWithCredContext is like Retrieve with Cred Context
func (m *IAM) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
return m.RetrieveWithCredContext(nil)
}

// A ec2RoleCredRespBody provides the shape for unmarshaling credential
Expand Down
31 changes: 20 additions & 11 deletions pkg/credentials/sts_client_grants.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ type STSClientGrants struct {
// NewSTSClientGrants returns a pointer to a new
// Credentials object wrapping the STSClientGrants.
func NewSTSClientGrants(stsEndpoint string, getClientGrantsTokenExpiry func() (*ClientGrantsToken, error)) (*Credentials, error) {
if stsEndpoint == "" {
return nil, errors.New("STS endpoint cannot be empty")
}
if getClientGrantsTokenExpiry == nil {
return nil, errors.New("Client grants access token and expiry retrieval function should be defined")
}
Expand Down Expand Up @@ -160,12 +157,29 @@ func getClientGrantsCredentials(clnt *http.Client, endpoint string,
return a, nil
}

func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) {
// RetrieveWithCredContext is like Retrieve() with cred context
func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

client := m.Client
if client == nil {
client = cc.Client
}
a, err := getClientGrantsCredentials(client, m.STSEndpoint, m.GetClientGrantsTokenExpiry)
if client == nil {
client = defaultCredContext.Client
}

stsEndpoint := m.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

a, err := getClientGrantsCredentials(client, stsEndpoint, m.GetClientGrantsTokenExpiry)
if err != nil {
return Value{}, err
}
Expand All @@ -182,13 +196,8 @@ func (m *STSClientGrants) retrieve(cc *CredContext) (Value, error) {
}, nil
}

// RetrieveWithCredContext is like Retrieve() with cred context
func (m *STSClientGrants) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return m.retrieve(cc)
}

// Retrieve retrieves credentials from the MinIO service.
// Error will be returned if the request fails.
func (m *STSClientGrants) Retrieve() (Value, error) {
return m.retrieve(defaultCredContext)
return m.RetrieveWithCredContext(nil)
}
27 changes: 19 additions & 8 deletions pkg/credentials/sts_custom_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,21 @@ type CustomTokenIdentity struct {
RequestedExpiry time.Duration
}

func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error) {
u, err := url.Parse(c.STSEndpoint)
// RetrieveWithCredContext with Retrieve optionally cred context
func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
if cc == nil {
cc = defaultCredContext
}

stsEndpoint := c.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

u, err := url.Parse(stsEndpoint)
if err != nil {
return value, err
}
Expand All @@ -97,6 +110,9 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error)
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

resp, err := client.Do(req)
if err != nil {
Expand Down Expand Up @@ -126,12 +142,7 @@ func (c *CustomTokenIdentity) retrieve(cc *CredContext) (value Value, err error)

// Retrieve - to satisfy Provider interface; fetches credentials from MinIO.
func (c *CustomTokenIdentity) Retrieve() (value Value, err error) {
return c.retrieve(defaultCredContext)
}

// RetrieveWithCredContext with Retrieve optionally cred context
func (c *CustomTokenIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
return c.retrieve(cc)
return c.RetrieveWithCredContext(nil)
}

// NewCustomTokenCredentials - returns credentials using the
Expand Down
30 changes: 21 additions & 9 deletions pkg/credentials/sts_ldap_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package credentials
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -120,8 +121,22 @@ func NewLDAPIdentityWithSessionPolicy(stsEndpoint, ldapUsername, ldapPassword, p
}), nil
}

func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
u, err := url.Parse(k.STSEndpoint)
// RetrieveWithCredContext gets the credential by calling the MinIO STS API for
// LDAP on the configured stsEndpoint.
func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
if cc == nil {
cc = defaultCredContext
}

stsEndpoint := k.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

u, err := url.Parse(stsEndpoint)
if err != nil {
return value, err
}
Expand Down Expand Up @@ -149,6 +164,9 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

resp, err := client.Do(req)
if err != nil {
Expand Down Expand Up @@ -194,11 +212,5 @@ func (k *LDAPIdentity) retrieve(cc *CredContext) (value Value, err error) {
// Retrieve gets the credential by calling the MinIO STS API for
// LDAP on the configured stsEndpoint.
func (k *LDAPIdentity) Retrieve() (value Value, err error) {
return k.retrieve(defaultCredContext)
}

// RetrieveWithCredContext gets the credential by calling the MinIO STS API for
// LDAP on the configured stsEndpoint.
func (k *LDAPIdentity) RetrieveWithCredContext(cc *CredContext) (value Value, err error) {
return k.retrieve(cc)
return k.RetrieveWithCredContext(defaultCredContext)
}
33 changes: 19 additions & 14 deletions pkg/credentials/sts_tls_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ type STSCertificateIdentity struct {
// to the given STS endpoint with the given TLS certificate and retrieves and
// rotates S3 credentials.
func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, options ...CertificateIdentityOption) (*Credentials, error) {
if endpoint == "" {
return nil, errors.New("STS endpoint cannot be empty")
}
if _, err := url.Parse(endpoint); err != nil {
return nil, err
}
identity := &STSCertificateIdentity{
STSEndpoint: endpoint,
Certificate: certificate,
Expand All @@ -102,8 +96,21 @@ func NewSTSCertificateIdentity(endpoint string, certificate tls.Certificate, opt
return New(identity), nil
}

func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
endpointURL, err := url.Parse(i.STSEndpoint)
// RetrieveWithCredContext is Retrieve with cred context
func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) {
if cc == nil {
cc = defaultCredContext
}

stsEndpoint := i.STSEndpoint
if stsEndpoint == "" {
stsEndpoint = cc.Endpoint
}
if stsEndpoint == "" {
return Value{}, errors.New("STS endpoint unknown")
}

endpointURL, err := url.Parse(stsEndpoint)
if err != nil {
return Value{}, err
}
Expand All @@ -130,6 +137,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
if client == nil {
client = cc.Client
}
if client == nil {
client = defaultCredContext.Client
}

tr, ok := client.Transport.(*http.Transport)
if !ok {
Expand Down Expand Up @@ -192,14 +202,9 @@ func (i *STSCertificateIdentity) retrieve(cc *CredContext) (Value, error) {
}, nil
}

// RetrieveWithCredContext is Retrieve with cred context
func (i *STSCertificateIdentity) RetrieveWithCredContext(cc *CredContext) (Value, error) {
return i.retrieve(cc)
}

// Retrieve fetches a new set of S3 credentials from the configured STS API endpoint.
func (i *STSCertificateIdentity) Retrieve() (Value, error) {
return i.retrieve(defaultCredContext)
return i.RetrieveWithCredContext(defaultCredContext)
}

// Expiration returns the expiration time of the current S3 credentials.
Expand Down
Loading

0 comments on commit 0abbe92

Please sign in to comment.