diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index 3ce8a4f17657e..f41a6f4513b4e 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -33,7 +33,7 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jackc/pgconn" @@ -605,7 +605,7 @@ func mustParseURL(t *testing.T, rawURL string) *url.URL { type fakeSTSClient struct { accountID string arn string - credentials *credentials.Credentials + credentials aws.CredentialsProvider } func (f fakeSTSClient) Do(req *http.Request) (*http.Response, error) { @@ -640,10 +640,10 @@ func mustCreateIAMJoinProvisionToken(t *testing.T, name, awsAccountID, allowedAR return provisionToken } -func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token string, credentials *credentials.Credentials) { +func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token string, credentials aws.CredentialsProvider) { t.Helper() - cred, err := credentials.Get() + cred, err := credentials.Retrieve(context.Background()) require.NoError(t, err) t.Setenv("AWS_ACCESS_KEY_ID", cred.AccessKeyID) diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index 262b8c1046726..b230a04c0e549 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -31,7 +31,7 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -1739,7 +1739,7 @@ func TestALPNSNIProxyGRPCInsecure(t *testing.T) { nodeAccount := "123456789012" nodeRoleARN := "arn:aws:iam::123456789012:role/test" - nodeCredentials := credentials.NewStaticCredentials("FAKE_ID", "FAKE_KEY", "FAKE_TOKEN") + nodeCredentials := credentials.NewStaticCredentialsProvider("FAKE_ID", "FAKE_KEY", "FAKE_TOKEN") provisionToken := mustCreateIAMJoinProvisionToken(t, "iam-join-token", nodeAccount, nodeRoleARN) suite := newSuite(t, diff --git a/lib/cloud/awsconfig/awsconfig.go b/lib/cloud/awsconfig/awsconfig.go index 245fe8a9a6b23..5e6967d2d4909 100644 --- a/lib/cloud/awsconfig/awsconfig.go +++ b/lib/cloud/awsconfig/awsconfig.go @@ -18,12 +18,16 @@ package awsconfig import ( "context" + "crypto/sha1" + "encoding/hex" + "fmt" "log/slog" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/aws/smithy-go/tracing/smithyoteltracing" "github.com/gravitational/trace" "go.opentelemetry.io/otel" @@ -69,7 +73,12 @@ type AssumeRole struct { // RoleARN is the ARN of the role to assume. RoleARN string `json:"role_arn"` // ExternalID is an optional ID to include when assuming the role. - ExternalID string `json:"external_id"` + ExternalID string `json:"external_id,omitempty"` + // SessionName is an optional session name to use when assuming the role. + SessionName string `json:"session_name,omitempty"` + // Tags is a list of STS session tags to pass when assuming the role. + // https://docs.aws.amazon.com/IAM/latest/UserGuide/id_session-tags.html + Tags map[string]string `json:"tags,omitempty"` } // options is a struct of additional options for assuming an AWS role @@ -153,6 +162,18 @@ func WithAssumeRole(roleARN, externalID string) OptionsFn { } } +// WithDetailedAssumeRole configures options needed for assuming an AWS role, +// including optional details like session name, duration, and tags. +func WithDetailedAssumeRole(ar AssumeRole) OptionsFn { + return func(options *options) { + if ar.RoleARN == "" { + // ignore empty role ARN for caller convenience. + return + } + options.assumeRoles = append(options.assumeRoles, ar) + } +} + // WithRetryer sets a custom retryer for the config. func WithRetryer(retryer func() aws.Retryer) OptionsFn { return func(options *options) { @@ -302,6 +323,13 @@ func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient if role.ExternalID != "" { aro.ExternalID = aws.String(role.ExternalID) } + aro.RoleSessionName = maybeHashRoleSessionName(role.SessionName) + for k, v := range role.Tags { + aro.Tags = append(aro.Tags, ststypes.Tag{ + Key: aws.String(k), + Value: aws.String(v), + }) + } }) } @@ -342,3 +370,38 @@ func (p *integrationCredentialsProvider) Retrieve(ctx context.Context) (aws.Cred ).Retrieve(ctx) return cred, trace.Wrap(err) } + +// maybeHashRoleSessionName truncates the role session name and adds a hash +// when the original role session name is greater than AWS character limit +// (64). +func maybeHashRoleSessionName(roleSessionName string) (ret string) { + // maxRoleSessionNameLength is the maximum length of the role session name + // used by the AssumeRole call. + // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_iam-quotas.html + const maxRoleSessionNameLength = 64 + if len(roleSessionName) <= maxRoleSessionNameLength { + return roleSessionName + } + + const hashLen = 16 + hash := sha1.New() + hash.Write([]byte(roleSessionName)) + hex := hex.EncodeToString(hash.Sum(nil))[:hashLen] + + // "1" for the delimiter. + keepPrefixIndex := maxRoleSessionNameLength - len(hex) - 1 + + // Sanity check. This should never happen since hash length and + // MaxRoleSessionNameLength are both constant. + if keepPrefixIndex < 0 { + keepPrefixIndex = 0 + } + + ret = fmt.Sprintf("%s-%s", roleSessionName[:keepPrefixIndex], hex) + slog.DebugContext(context.Background(), + "AWS role session name is too long. Using a hash instead.", + "hashed", ret, + "original", roleSessionName, + ) + return ret +} diff --git a/lib/cloud/awsconfig/awsconfig_test.go b/lib/cloud/awsconfig/awsconfig_test.go index 2de624fe86c54..ca6d5a1576976 100644 --- a/lib/cloud/awsconfig/awsconfig_test.go +++ b/lib/cloud/awsconfig/awsconfig_test.go @@ -251,12 +251,12 @@ func testGetConfigIntegration(t *testing.T, provider Provider) { func TestNewCacheKey(t *testing.T) { roleChain := []AssumeRole{ {RoleARN: "roleA"}, - {RoleARN: "roleB", ExternalID: "abc123"}, + {RoleARN: "roleB", ExternalID: "abc123", SessionName: "alice", Tags: map[string]string{"AKey": "AValue"}}, } got, err := newCacheKey("integration-name", roleChain...) require.NoError(t, err) want := strings.TrimSpace(` -{"integration":"integration-name","role_chain":[{"role_arn":"roleA","external_id":""},{"role_arn":"roleB","external_id":"abc123"}]} +{"integration":"integration-name","role_chain":[{"role_arn":"roleA"},{"role_arn":"roleB","external_id":"abc123","session_name":"alice","tags":{"AKey":"AValue"}}]} `) require.Equal(t, want, got) } diff --git a/lib/service/service.go b/lib/service/service.go index 7003d108b9843..4d43026364c50 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -112,6 +112,7 @@ import ( pgrepl "github.com/gravitational/teleport/lib/client/db/postgres/repl" dbrepl "github.com/gravitational/teleport/lib/client/db/repl" "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/cloud/imds" awsimds "github.com/gravitational/teleport/lib/cloud/imds/aws" @@ -4889,6 +4890,12 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return awsoidc.NewSessionV1(ctx, conn.Client, region, integration) } + awsConfigProvider, err := awsconfig.NewCache(awsconfig.WithDefaults( + awsconfig.WithOIDCIntegrationClient(conn.Client), + )) + if err != nil { + return trace.Wrap(err, "unable to create AWS config provider cache") + } connectionsHandler, err := app.NewConnectionsHandler(process.GracefulExitContext(), &app.ConnectionsHandlerConfig{ Clock: process.Clock, @@ -4903,6 +4910,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { CipherSuites: cfg.CipherSuites, ServiceComponent: teleport.ComponentWebProxy, AWSSessionProvider: awsSessionGetter, + AWSConfigProvider: awsConfigProvider, }) if err != nil { return trace.Wrap(err) @@ -6218,6 +6226,10 @@ func (process *TeleportProcess) initApps() { return trace.Wrap(err) } + awsConfigProvider, err := awsconfig.NewCache() + if err != nil { + return trace.Wrap(err, "unable to create AWS config provider cache") + } connectionsHandler, err := app.NewConnectionsHandler(process.ExitContext(), &app.ConnectionsHandlerConfig{ Clock: process.Config.Clock, DataDir: process.Config.DataDir, @@ -6232,6 +6244,7 @@ func (process *TeleportProcess) initApps() { ServiceComponent: teleport.ComponentApp, Logger: logger, AWSSessionProvider: awsutils.SessionProviderUsingAmbientCredentials(), + AWSConfigProvider: awsConfigProvider, }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/alpnproxy/aws_local_proxy.go b/lib/srv/alpnproxy/aws_local_proxy.go index 77980930b2f37..2498058e927d1 100644 --- a/lib/srv/alpnproxy/aws_local_proxy.go +++ b/lib/srv/alpnproxy/aws_local_proxy.go @@ -134,7 +134,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re } func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool { - if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil { + if err := awsutils.VerifyAWSSignature(req, m.AWSCredentialsProvider); err != nil { m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err) rw.WriteHeader(http.StatusForbidden) return true @@ -149,7 +149,7 @@ func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter, aws.ToString(assumedRole.Credentials.SessionToken), ) - if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil { + if err := awsutils.VerifyAWSSignature(req, credentials); err != nil { m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err) rw.WriteHeader(http.StatusForbidden) return true diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 8bd04ea356e89..f7d95e0badf74 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -33,6 +33,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/httplib/reverseproxy" @@ -53,12 +54,12 @@ type signerHandler struct { // SignerHandlerConfig is the awsSignerHandler configuration. type SignerHandlerConfig struct { + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider // Log is a logger for the handler. Log *slog.Logger // RoundTripper is an http.RoundTripper instance used for requests. RoundTripper http.RoundTripper - // SigningService is used to sign requests before forwarding them. - *awsutils.SigningService // Clock is used to override time in tests. Clock clockwork.Clock // MaxHTTPRequestBodySize is the limit on how big a request body can be. @@ -67,8 +68,8 @@ type SignerHandlerConfig struct { // CheckAndSetDefaults validates the AwsSignerHandlerConfig. func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error { - if cfg.SigningService == nil { - return trace.BadParameter("missing SigningService") + if cfg.AWSConfigProvider == nil { + return trace.BadParameter("aws config provider missing") } if cfg.RoundTripper == nil { tr, err := defaults.Transport() @@ -165,15 +166,24 @@ func (s *signerHandler) serveCommonRequest(sessCtx *common.SessionContext, w htt return trace.Wrap(err) } - signedReq, err := s.SignRequest(s.closeContext, unsignedReq, + awsCfg, err := s.AWSConfigProvider.GetConfig(s.closeContext, re.SigningRegion, + awsconfig.WithDetailedAssumeRole(awsconfig.AssumeRole{ + RoleARN: sessCtx.Identity.RouteToApp.AWSRoleARN, + ExternalID: sessCtx.App.GetAWSExternalID(), + SessionName: sessCtx.Identity.Username, + }), + awsconfig.WithCredentialsMaybeIntegration(sessCtx.App.GetIntegration()), + ) + if err != nil { + return trace.Wrap(err) + } + + signedReq, err := awsutils.SignRequest(s.closeContext, unsignedReq, &awsutils.SigningCtx{ + Clock: s.Clock, + Credentials: awsCfg.Credentials, SigningName: re.SigningName, SigningRegion: re.SigningRegion, - Expiry: sessCtx.Identity.Expires, - SessionName: sessCtx.Identity.Username, - AWSRoleArn: sessCtx.Identity.RouteToApp.AWSRoleARN, - AWSExternalID: sessCtx.App.GetAWSExternalID(), - Integration: sessCtx.App.GetIntegration(), }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index 90dbd4ed46d21..4f19b9fb18a95 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -31,6 +31,7 @@ import ( "testing" "time" + credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/client" @@ -50,6 +51,8 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/cloud/awsconfig" + "github.com/gravitational/teleport/lib/cloud/mocks" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/srv/app/common" @@ -191,12 +194,19 @@ func TestAWSSignerHandler(t *testing.T) { }) require.NoError(t, err) + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC( + types.Metadata{Name: "my-integration"}, + &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:sts::123456789012:role/TestRole", + }, + ) + require.NoError(t, err) consoleAppWithIntegration, err := types.NewAppV3(types.Metadata{ Name: "awsconsole", }, types.AppSpecV3{ URI: constants.AWSConsoleURL, PublicAddr: "test.local", - Integration: "my-integration", + Integration: awsOIDCIntegration.GetName(), }) require.NoError(t, err) @@ -204,7 +214,7 @@ func TestAWSSignerHandler(t *testing.T) { name string app types.Application awsClientSession *session.Session - awsSessionProvider awsutils.AWSSessionProvider + awsConfigProvider awsconfig.Provider request makeRequest advanceClock time.Duration wantHost string @@ -226,7 +236,7 @@ func TestAWSSignerHandler(t *testing.T) { })), request: s3Request, wantHost: "s3.us-west-2.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "s3", wantAuthCredRegion: "us-west-2", wantEventType: &events.AppSessionRequest{}, @@ -242,14 +252,14 @@ func TestAWSSignerHandler(t *testing.T) { Region: aws.String("us-west-2"), })), request: s3Request, - awsSessionProvider: func(ctx context.Context, region, integration string) (*session.Session, error) { - if integration != "my-integration" { - return nil, trace.BadParameter("") - } - return nil, nil + awsConfigProvider: &mocks.AWSConfigProvider{ + OIDCIntegrationClient: &mocks.FakeOIDCIntegrationClient{ + Integration: awsOIDCIntegration, + Token: "fake-oidc-token", + }, }, wantHost: "s3.us-west-2.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "s3", wantAuthCredRegion: "us-west-2", wantEventType: &events.AppSessionRequest{}, @@ -266,7 +276,7 @@ func TestAWSSignerHandler(t *testing.T) { })), request: s3Request, wantHost: "s3.us-west-1.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "s3", wantAuthCredRegion: "us-west-1", wantEventType: &events.AppSessionRequest{}, @@ -314,7 +324,7 @@ func TestAWSSignerHandler(t *testing.T) { })), request: dynamoRequest, wantHost: "dynamodb.us-east-1.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "dynamodb", wantAuthCredRegion: "us-east-1", wantEventType: &events.AppSessionDynamoDBRequest{}, @@ -331,7 +341,7 @@ func TestAWSSignerHandler(t *testing.T) { })), request: dynamoRequest, wantHost: "dynamodb.us-west-1.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "dynamodb", wantAuthCredRegion: "us-west-1", wantEventType: &events.AppSessionDynamoDBRequest{}, @@ -379,7 +389,7 @@ func TestAWSSignerHandler(t *testing.T) { })), request: lambdaRequest, wantHost: "lambda.us-east-1.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "lambda", wantAuthCredRegion: "us-east-1", wantEventType: &events.AppSessionRequest{}, @@ -411,7 +421,7 @@ func TestAWSSignerHandler(t *testing.T) { request: assumeRoleRequest(2 * time.Hour), advanceClock: 10 * time.Minute, wantHost: "sts.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "sts", wantAuthCredRegion: "us-east-1", wantEventType: &events.AppSessionRequest{}, @@ -429,7 +439,7 @@ func TestAWSSignerHandler(t *testing.T) { })), request: assumeRoleRequest(32 * time.Minute), wantHost: "sts.amazonaws.com", - wantAuthCredKeyID: "AKIDl", + wantAuthCredKeyID: "FAKEACCESSKEYID", wantAuthCredService: "sts", wantAuthCredRegion: "us-east-1", wantEventType: &events.AppSessionRequest{}, @@ -445,14 +455,10 @@ func TestAWSSignerHandler(t *testing.T) { Credentials: staticAWSCredentialsForClient, Region: aws.String("us-east-1"), })), - request: assumeRoleRequest(2 * time.Hour), - advanceClock: 50 * time.Minute, // identity is expiring in 10m which is less than minimum - wantHost: "sts.amazonaws.com", - wantAuthCredKeyID: "AKIDl", - wantAuthCredService: "sts", - wantAuthCredRegion: "us-east-1", - wantEventType: &events.AppSessionRequest{}, + request: assumeRoleRequest(2 * time.Hour), + advanceClock: 50 * time.Minute, // identity is expiring in 10m which is less than minimum errAssertionFns: []require.ErrorAssertionFunc{ + // the request is 403 forbidden by Teleport, so the mock AWS handler won't be sent anything. hasStatusCode(http.StatusForbidden), }, }, @@ -476,7 +482,9 @@ func TestAWSSignerHandler(t *testing.T) { // check that the signature is valid. if !tc.skipVerifySignature { - err = awsutils.VerifyAWSSignature(r, staticAWSCredentials) + err := awsutils.VerifyAWSSignature(r, + credentialsv2.NewStaticCredentialsProvider(tc.wantAuthCredKeyID, "secret", "token"), + ) if !assert.NoError(t, err) { http.Error(w, err.Error(), trace.ErrorToCode(err)) return @@ -490,12 +498,12 @@ func TestAWSSignerHandler(t *testing.T) { w.WriteHeader(http.StatusOK) } - sessionProvider := awsutils.SessionProviderUsingAmbientCredentials() - if tc.awsSessionProvider != nil { - sessionProvider = tc.awsSessionProvider + awsCfgProvider := tc.awsConfigProvider + if awsCfgProvider == nil { + awsCfgProvider = &mocks.AWSConfigProvider{} } - suite := createSuite(t, mockAwsHandler, tc.app, fakeClock, sessionProvider) + suite := createSuite(t, mockAwsHandler, tc.app, fakeClock, awsCfgProvider) fakeClock.Advance(tc.advanceClock) err := tc.request(suite.URL, tc.awsClientSession, tc.wantHost) @@ -603,7 +611,6 @@ const assumedRoleKeyID = "assumedRoleKeyID" var ( staticAWSCredentialsForAssumedRole = credentials.NewStaticCredentials(assumedRoleKeyID, "assumedRoleKeySecret", "") - staticAWSCredentials = credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION") staticAWSCredentialsForClient = credentials.NewStaticCredentials("fakeClientKeyID", "fakeClientSecret", "") ) @@ -614,7 +621,7 @@ type suite struct { recorder *eventstest.ChannelRecorder } -func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Application, clock clockwork.Clock, awsSessionProvider awsutils.AWSSessionProvider) *suite { +func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Application, clock clockwork.Clock, acp awsconfig.Provider) *suite { recorder := eventstest.NewChannelRecorder(1) identity := tlsca.Identity{ Username: "user", @@ -630,13 +637,6 @@ func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Applic awsAPIMock.Close() }) - svc, err := awsutils.NewSigningService(awsutils.SigningServiceConfig{ - SessionProvider: awsSessionProvider, - CredentialsGetter: awsutils.NewStaticCredentialsGetter(staticAWSCredentials), - Clock: clock, - }) - require.NoError(t, err) - audit, err := common.NewAudit(common.AuditConfig{ Emitter: libevents.NewDiscardEmitter(), Recorder: libevents.WithNoOpPreparer(recorder), @@ -644,7 +644,7 @@ func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Applic require.NoError(t, err) signerHandler, err := NewAWSSignerHandler(context.Background(), SignerHandlerConfig{ - SigningService: svc, + AWSConfigProvider: acp, RoundTripper: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, diff --git a/lib/srv/app/connections_handler.go b/lib/srv/app/connections_handler.go index 3fad12e54eaaa..f5a583fc05c77 100644 --- a/lib/srv/app/connections_handler.go +++ b/lib/srv/app/connections_handler.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/services" @@ -93,6 +94,9 @@ type ConnectionsHandlerConfig struct { // AWSSessionProvider is used to provide AWS Sessions. AWSSessionProvider awsutils.AWSSessionProvider + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider + // TLSConfig is the *tls.Config for this server. TLSConfig *tls.Config @@ -142,6 +146,9 @@ func (c *ConnectionsHandlerConfig) CheckAndSetDefaults() error { if c.AWSSessionProvider == nil { return trace.BadParameter("aws session provider missing") } + if c.AWSConfigProvider == nil { + return trace.BadParameter("aws config provider missing") + } if c.Cloud == nil { cloud, err := NewCloud(CloudConfig{ Clock: c.Clock, @@ -206,16 +213,9 @@ func NewConnectionsHandler(closeContext context.Context, cfg *ConnectionsHandler return nil, trace.Wrap(err) } - awsSigner, err := awsutils.NewSigningService(awsutils.SigningServiceConfig{ - Clock: cfg.Clock, - SessionProvider: cfg.AWSSessionProvider, - }) - if err != nil { - return nil, trace.Wrap(err) - } awsHandler, err := appaws.NewAWSSignerHandler(closeContext, appaws.SignerHandlerConfig{ - SigningService: awsSigner, - Clock: cfg.Clock, + AWSConfigProvider: cfg.AWSConfigProvider, + Clock: cfg.Clock, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/app/server_test.go b/lib/srv/app/server_test.go index 1c0c8b4f322b3..3f0da8abda9b6 100644 --- a/lib/srv/app/server_test.go +++ b/lib/srv/app/server_test.go @@ -56,6 +56,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/httplib/reverseproxy" @@ -365,6 +366,7 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite { CipherSuites: utils.DefaultCipherSuites(), ServiceComponent: teleport.ComponentApp, AWSSessionProvider: aws.SessionProviderUsingAmbientCredentials(), + AWSConfigProvider: &mocks.AWSConfigProvider{}, }) require.NoError(t, err) diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index 734ff51568f2c..165d856677f9d 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -40,6 +40,7 @@ import ( "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" apiaws "github.com/gravitational/teleport/api/utils/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/modules" @@ -138,14 +139,6 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error } defer e.Audit.OnSessionEnd(e.Context, e.sessionCtx) - signer, err := libaws.NewSigningService(libaws.SigningServiceConfig{ - Clock: e.Clock, - AWSConfigProvider: e.AWSConfigProvider, - }) - if err != nil { - return trace.Wrap(err) - } - clientConnReader := bufio.NewReader(e.clientConn) observe() @@ -159,7 +152,7 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error return trace.Wrap(err) } - if err := e.process(ctx, req, signer, msgFromClient, msgFromServer); err != nil { + if err := e.process(ctx, req, msgFromClient, msgFromServer); err != nil { return trace.Wrap(err) } } @@ -167,7 +160,7 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error // process reads request from connected dynamodb client, processes the requests/responses and sends data back // to the client. -func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws.SigningService, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) (err error) { +func (e *Engine) process(ctx context.Context, req *http.Request, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) (err error) { msgFromClient.Inc() if req.Body != nil { @@ -210,20 +203,32 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws. if err != nil { return trace.Wrap(err) } - signingCtx := &libaws.SigningCtx{ - SigningName: re.SigningName, - SigningRegion: re.SigningRegion, - Expiry: e.sessionCtx.Identity.Expires, - SessionName: e.sessionCtx.Identity.Username, - BaseAWSRoleARN: meta.AssumeRoleARN, - BaseAWSExternalID: meta.ExternalID, - AWSRoleArn: roleArn, - SessionTags: e.sessionCtx.Database.GetAWS().SessionTags, + + ar := awsconfig.AssumeRole{ + RoleARN: roleArn, + SessionName: e.sessionCtx.Identity.Username, + Tags: meta.SessionTags, } if meta.AssumeRoleARN == "" { - signingCtx.AWSExternalID = meta.ExternalID + ar.ExternalID = meta.ExternalID } - signedReq, err := signer.SignRequest(e.Context, outReq, signingCtx) + awsCfg, err := e.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithDetailedAssumeRole(ar), + awsconfig.WithAmbientCredentials(), + ) + if err != nil { + return trace.Wrap(err) + } + + signingCtx := &libaws.SigningCtx{ + Clock: e.Clock, + Credentials: awsCfg.Credentials, + SigningName: re.SigningName, + SigningRegion: re.SigningRegion, + } + + signedReq, err := libaws.SignRequest(e.Context, outReq, signingCtx) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/dynamodb/test.go b/lib/srv/db/dynamodb/test.go index b12d0493ea75e..e91125dff69f8 100644 --- a/lib/srv/db/dynamodb/test.go +++ b/lib/srv/db/dynamodb/test.go @@ -101,7 +101,7 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (*T mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - err := awsutils.VerifyAWSSignatureV2(r, + err := awsutils.VerifyAWSSignature(r, credentials.NewStaticCredentialsProvider("FAKEACCESSKEYID", "secret", "token"), ) if err != nil { diff --git a/lib/srv/db/opensearch/engine.go b/lib/srv/db/opensearch/engine.go index 4f54f4b5a282d..f017f5c7a8369 100644 --- a/lib/srv/db/opensearch/engine.go +++ b/lib/srv/db/opensearch/engine.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv/db/common" @@ -133,14 +134,6 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error e.Audit.OnSessionStart(e.Context, e.sessionCtx, err) return trace.Wrap(err) } - signer, err := libaws.NewSigningService(libaws.SigningServiceConfig{ - Clock: e.Clock, - AWSConfigProvider: e.AWSConfigProvider, - }) - if err != nil { - return trace.Wrap(err) - } - // TODO(Tener): // Consider rewriting to support HTTP2 clients. // Ideally we should have shared middleware for DB clients using HTTP, instead of separate per-engine implementations. @@ -166,7 +159,7 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error return trace.Wrap(err) } - if err := e.process(ctx, tr, signer, req, msgFromClient, msgFromServer); err != nil { + if err := e.process(ctx, tr, req, msgFromClient, msgFromServer); err != nil { return trace.Wrap(err) } } @@ -174,7 +167,7 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error // process reads request from connected OpenSearch client, processes the requests/responses and send data back // to the client. -func (e *Engine) process(ctx context.Context, tr *http.Transport, signer *libaws.SigningService, req *http.Request, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) error { +func (e *Engine) process(ctx context.Context, tr *http.Transport, req *http.Request, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) error { msgFromClient.Inc() if req.Body != nil { @@ -193,7 +186,7 @@ func (e *Engine) process(ctx context.Context, tr *http.Transport, signer *libaws e.emitAuditEvent(reqCopy, payload, responseStatusCode, err == nil) }() - signedReq, err := e.getSignedRequest(signer, reqCopy) + signedReq, err := e.getSignedRequest(reqCopy) if err != nil { return trace.Wrap(err) } @@ -225,31 +218,33 @@ func (e *Engine) getTransport(ctx context.Context) (*http.Transport, error) { return tr, nil } -func (e *Engine) getSignedRequest(signer *libaws.SigningService, reqCopy *http.Request) (*http.Request, error) { +func (e *Engine) getSignedRequest(reqCopy *http.Request) (*http.Request, error) { roleArn, err := libaws.BuildRoleARN(e.sessionCtx.DatabaseUser, e.sessionCtx.Database.GetAWS().Region, e.sessionCtx.Database.GetAWS().AccountID) if err != nil { return nil, trace.Wrap(err) } meta := e.sessionCtx.Database.GetAWS() - signCtx := &libaws.SigningCtx{ - SigningName: "es", - SigningRegion: meta.Region, - Expiry: e.sessionCtx.Identity.Expires, - SessionName: e.sessionCtx.Identity.Username, - BaseAWSRoleARN: meta.AssumeRoleARN, - BaseAWSExternalID: meta.ExternalID, - AWSRoleArn: roleArn, - // OpenSearch uses meta.ExternalID for both the base role and the - // assumed role. - AWSExternalID: meta.ExternalID, + awsCfg, err := e.AWSConfigProvider.GetConfig(e.Context, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithDetailedAssumeRole(awsconfig.AssumeRole{ + RoleARN: roleArn, + ExternalID: meta.ExternalID, + SessionName: e.sessionCtx.Identity.Username, + }), + awsconfig.WithAmbientCredentials(), + ) + if err != nil { + return nil, trace.Wrap(err) } - - if meta.AssumeRoleARN == "" { - signCtx.AWSExternalID = meta.ExternalID + signCtx := &libaws.SigningCtx{ + Clock: e.Clock, + Credentials: awsCfg.Credentials, + SigningName: "es", + SigningRegion: e.sessionCtx.Database.GetAWS().Region, } - signedReq, err := signer.SignRequest(e.Context, reqCopy, signCtx) + signedReq, err := libaws.SignRequest(e.Context, reqCopy, signCtx) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index c06636c109150..084a6383e80f7 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -159,22 +159,17 @@ func IsSignedByAWSSigV4(r *http.Request) bool { return strings.HasPrefix(r.Header.Get(AuthorizationHeader), AmazonSigV4AuthorizationPrefix) } -// VerifyAWSSignatureV2 is a temporary AWS SDK migration helper. -func VerifyAWSSignatureV2(req *http.Request, provider aws.CredentialsProvider) error { - return VerifyAWSSignature(req, migration.NewCredentialsAdapter(provider)) -} - // VerifyAWSSignature verifies the request signature ensuring that the request originates from tsh aws command execution // AWS CLI signs the request with random generated credentials that are passed to LocalProxy by // the AWSCredentials LocalProxyConfig configuration. -func VerifyAWSSignature(req *http.Request, credentials *credentials.Credentials) error { +func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider) error { sigV4, err := ParseSigV4(req.Header.Get("Authorization")) if err != nil { return trace.BadParameter(err.Error()) } // Verifies the request is signed by the expected access key ID. - credValue, err := credentials.Get() + credValue, err := credProvider.Retrieve(req.Context()) if err != nil { return trace.Wrap(err) } @@ -212,7 +207,7 @@ func VerifyAWSSignature(req *http.Request, credentials *credentials.Credentials) return trace.BadParameter(err.Error()) } - signer := NewSigner(credentials, sigV4.Service) + signer := NewSignerV2(credProvider, sigV4.Service) _, err = signer.Sign(reqCopy, bytes.NewReader(payload), sigV4.Service, sigV4.Region, t) if err != nil { return trace.Wrap(err) diff --git a/lib/utils/aws/credentials.go b/lib/utils/aws/credentials.go index 47c3105174943..8be99d898e5f7 100644 --- a/lib/utils/aws/credentials.go +++ b/lib/utils/aws/credentials.go @@ -20,187 +20,15 @@ package aws import ( "context" - "log/slog" - "sort" - "strings" - "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/lib/modules" - "github.com/gravitational/teleport/lib/utils" ) -// GetCredentialsRequest is the request for obtaining STS credentials. -type GetCredentialsRequest struct { - // Provider is the user session used to create the STS client. - Provider client.ConfigProvider - // Expiry is session expiry to be requested. - Expiry time.Time - // SessionName is the session name to be requested. - SessionName string - // RoleARN is the role ARN to be requested. - RoleARN string - // ExternalID is the external ID to be requested, if not empty. - ExternalID string - // Tags is a list of AWS STS session tags. - Tags map[string]string -} - -// CredentialsGetter defines an interface for obtaining STS credentials. -type CredentialsGetter interface { - // Get obtains STS credentials. - Get(ctx context.Context, request GetCredentialsRequest) (*credentials.Credentials, error) -} - -type credentialsGetter struct { -} - -// NewCredentialsGetter returns a new CredentialsGetter. -func NewCredentialsGetter() CredentialsGetter { - return &credentialsGetter{} -} - -// Get obtains STS credentials. -func (g *credentialsGetter) Get(ctx context.Context, request GetCredentialsRequest) (*credentials.Credentials, error) { - slog.DebugContext(ctx, "Creating STS session", - "session_name", request.SessionName, - "role_arn", request.RoleARN, - ) - return stscreds.NewCredentials(request.Provider, request.RoleARN, - func(cred *stscreds.AssumeRoleProvider) { - cred.RoleSessionName = MaybeHashRoleSessionName(request.SessionName) - cred.Expiry.SetExpiration(request.Expiry, 0) - - if request.ExternalID != "" { - cred.ExternalID = aws.String(request.ExternalID) - } - - cred.Tags = make([]*sts.Tag, 0, len(request.Tags)) - for key, value := range request.Tags { - cred.Tags = append(cred.Tags, &sts.Tag{Key: aws.String(key), Value: aws.String(value)}) - } - }, - ), nil -} - -// CachedCredentialsGetterConfig is the config for creating a CredentialsGetter that caches credentials. -type CachedCredentialsGetterConfig struct { - // Getter is the CredentialsGetter for obtaining the STS credentials. - Getter CredentialsGetter - // CacheTTL is the cache TTL. - CacheTTL time.Duration - // Clock is used to control time. - Clock clockwork.Clock -} - -// SetDefaults sets default values for CachedCredentialsGetterConfig. -func (c *CachedCredentialsGetterConfig) SetDefaults() { - if c.Getter == nil { - c.Getter = NewCredentialsGetter() - } - if c.CacheTTL <= 0 { - c.CacheTTL = time.Minute - } - if c.Clock == nil { - c.Clock = clockwork.NewRealClock() - } -} - -// credentialRequestCacheKey credentials request cache key. -type credentialRequestCacheKey struct { - provider client.ConfigProvider - expiry time.Time - sessionName string - roleARN string - externalID string - tags string -} - -// newCredentialRequestCacheKey creates a new cache key for the credentials -// request. -func newCredentialRequestCacheKey(req GetCredentialsRequest) credentialRequestCacheKey { - k := credentialRequestCacheKey{ - provider: req.Provider, - expiry: req.Expiry, - sessionName: req.SessionName, - roleARN: req.RoleARN, - externalID: req.ExternalID, - } - - tags := make([]string, 0, len(req.Tags)) - for key, value := range req.Tags { - tags = append(tags, key+"="+value+",") - } - sort.Strings(tags) - k.tags = strings.Join(tags, ",") - - return k -} - -type cachedCredentialsGetter struct { - config CachedCredentialsGetterConfig - cache *utils.FnCache -} - -// NewCachedCredentialsGetter returns a CredentialsGetter that caches credentials. -func NewCachedCredentialsGetter(config CachedCredentialsGetterConfig) (CredentialsGetter, error) { - config.SetDefaults() - - cache, err := utils.NewFnCache(utils.FnCacheConfig{ - TTL: config.CacheTTL, - Clock: config.Clock, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return &cachedCredentialsGetter{ - config: config, - cache: cache, - }, nil -} - -// Get returns cached credentials if found, or fetch it from the configured -// getter. -func (g *cachedCredentialsGetter) Get(ctx context.Context, request GetCredentialsRequest) (*credentials.Credentials, error) { - credentials, err := utils.FnCacheGet(ctx, g.cache, newCredentialRequestCacheKey(request), func(ctx context.Context) (*credentials.Credentials, error) { - credentials, err := g.config.Getter.Get(ctx, request) - return credentials, trace.Wrap(err) - }) - return credentials, trace.Wrap(err) -} - -type staticCredentialsGetter struct { - credentials *credentials.Credentials -} - -// NewStaticCredentialsGetter returns a CredentialsGetter that always returns -// the same provided credentials. -// -// Used in testing to mock CredentialsGetter. -func NewStaticCredentialsGetter(credentials *credentials.Credentials) CredentialsGetter { - return &staticCredentialsGetter{ - credentials: credentials, - } -} - -// Get returns the credentials provided to NewStaticCredentialsGetter. -func (g *staticCredentialsGetter) Get(_ context.Context, _ GetCredentialsRequest) (*credentials.Credentials, error) { - if g.credentials == nil { - return nil, trace.NotFound("no credentials found") - } - return g.credentials, nil -} - // AWSSessionProvider defines a function that creates an AWS Session. // It must use ambient credentials if Integration is empty. // It must use Integration credentials otherwise. diff --git a/lib/utils/aws/credentials_test.go b/lib/utils/aws/credentials_test.go deleted file mode 100644 index 3f682f3462cc5..0000000000000 --- a/lib/utils/aws/credentials_test.go +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package aws - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/google/uuid" - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" -) - -type fakeCredentialsGetter struct { -} - -func (f *fakeCredentialsGetter) Get(ctx context.Context, request GetCredentialsRequest) (*credentials.Credentials, error) { - return credentials.NewStaticCredentials( - fmt.Sprintf("%s-%s-%s", request.SessionName, request.RoleARN, request.ExternalID), - uuid.NewString(), - uuid.NewString(), - ), nil -} - -func TestCachedCredentialsGetter(t *testing.T) { - t.Parallel() - - hostSession := session.Must(session.NewSession(&aws.Config{ - Credentials: credentials.AnonymousCredentials, - Region: aws.String("us-west-2"), - })) - fakeClock := clockwork.NewFakeClock() - - getter, err := NewCachedCredentialsGetter(CachedCredentialsGetterConfig{ - Getter: &fakeCredentialsGetter{}, - CacheTTL: time.Minute, - Clock: fakeClock, - }) - require.NoError(t, err) - - cred1, err := getter.Get(context.Background(), GetCredentialsRequest{ - Provider: hostSession, - Expiry: fakeClock.Now().Add(time.Hour), - SessionName: "test-session", - RoleARN: "test-role", - Tags: map[string]string{ - "one": "1", - "two": "2", - "three": "3", - }, - }) - require.NoError(t, err) - checkCredentialsAccessKeyID(t, cred1, "test-session-test-role-") - - tests := []struct { - name string - sessionName string - roleARN string - externalID string - tags map[string]string - advanceClock time.Duration - compareCred1 require.ComparisonAssertionFunc - }{ - { - name: "cached", - sessionName: "test-session", - roleARN: "test-role", - tags: map[string]string{ - "one": "1", - "two": "2", - "three": "3", - }, - compareCred1: require.Same, - }, - { - name: "cached different tags order", - sessionName: "test-session", - roleARN: "test-role", - tags: map[string]string{ - "three": "3", - "two": "2", - "one": "1", - }, - compareCred1: require.Same, - }, - { - name: "different session name", - sessionName: "test-session-2", - roleARN: "test-role", - tags: map[string]string{ - "one": "1", - "two": "2", - "three": "3", - }, - compareCred1: require.NotSame, - }, - { - name: "different role ARN", - sessionName: "test-session", - roleARN: "test-role-2", - tags: map[string]string{ - "one": "1", - "two": "2", - "three": "3", - }, - compareCred1: require.NotSame, - }, - { - name: "different external ID", - sessionName: "test-session", - roleARN: "test-role", - externalID: "test-id", - tags: map[string]string{ - "one": "1", - "two": "2", - "three": "3", - }, - compareCred1: require.NotSame, - }, - { - name: "different tags", - sessionName: "test-session", - roleARN: "test-role", - tags: map[string]string{ - "four": "4", - "five": "5", - }, - compareCred1: require.NotSame, - }, - { - name: "cache expired", - sessionName: "test-session", - roleARN: "test-role", - tags: map[string]string{ - "one": "1", - "two": "2", - "three": "3", - }, - advanceClock: time.Hour, - compareCred1: require.NotSame, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - fakeClock.Advance(test.advanceClock) - - cred, err := getter.Get(context.Background(), GetCredentialsRequest{ - Provider: hostSession, - Expiry: fakeClock.Now().Add(time.Hour), - SessionName: test.sessionName, - RoleARN: test.roleARN, - ExternalID: test.externalID, - Tags: test.tags, - }) - require.NoError(t, err) - checkCredentialsAccessKeyID(t, cred, fmt.Sprintf("%s-%s-%s", test.sessionName, test.roleARN, test.externalID)) - test.compareCred1(t, cred1, cred) - }) - } -} - -func checkCredentialsAccessKeyID(t *testing.T, cred *credentials.Credentials, wantAccessKeyID string) { - t.Helper() - value, err := cred.Get() - require.NoError(t, err) - require.Equal(t, wantAccessKeyID, value.AccessKeyID) -} diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go index 6549265aed676..31d29532c20d8 100644 --- a/lib/utils/aws/signing.go +++ b/lib/utils/aws/signing.go @@ -23,114 +23,35 @@ import ( "context" "io" "net/http" - "time" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/utils" ) -// NewSigningService creates a new instance of SigningService. -func NewSigningService(config SigningServiceConfig) (*SigningService, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - return &SigningService{ - SigningServiceConfig: config, - }, nil -} - -// SigningService is an AWS CLI proxy service that signs AWS requests -// based on user identity. -type SigningService struct { - // SigningServiceConfig is the SigningService configuration. - SigningServiceConfig -} - -// SigningServiceConfig is the SigningService configuration. -type SigningServiceConfig struct { - // SessionProvider is a provider for AWS Sessions. - SessionProvider AWSSessionProvider - // Clock is used to override time in tests. - Clock clockwork.Clock - // CredentialsGetter is used to obtain STS credentials. - CredentialsGetter CredentialsGetter - // AWSConfigProvider is a provider for AWS configs. - AWSConfigProvider awsconfig.Provider -} - -// CheckAndSetDefaults validates the SigningServiceConfig config. -func (s *SigningServiceConfig) CheckAndSetDefaults() error { - if s.Clock == nil { - s.Clock = clockwork.NewRealClock() - } - - if s.AWSConfigProvider == nil { - if s.SessionProvider == nil { - return trace.BadParameter("session provider or config provider is required") - } - if s.CredentialsGetter == nil { - // Use cachedCredentialsGetter by default. cachedCredentialsGetter - // caches the credentials for one minute. - cachedGetter, err := NewCachedCredentialsGetter(CachedCredentialsGetterConfig{ - Clock: s.Clock, - }) - if err != nil { - return trace.Wrap(err) - } - - s.CredentialsGetter = cachedGetter - } - } - return nil -} - // SigningCtx contains AWS SigV4 signing context parameters. type SigningCtx struct { + // Clock is used to override time in tests. + Clock clockwork.Clock + // Credentials provides AWS credentials. + Credentials aws.CredentialsProvider // SigningName is the AWS signing service name. SigningName string // SigningRegion is the AWS region to sign a request for. SigningRegion string - // Expiry is the expiration of the AWS credentials used to sign requests. - Expiry time.Time - // SessionName is role session name of AWS credentials used to sign requests. - SessionName string - // BaseAWSRoleARN is the AWS ARN of the role as a base to the assumed roles. - BaseAWSRoleARN string - // BaseAWSRoleARN is an optional external ID used on base assumed role. - BaseAWSExternalID string - // AWSRoleArn is the AWS ARN of the role to assume for signing requests, - // chained with BaseAWSRoleARN. - AWSRoleArn string - // AWSExternalID is an optional external ID used when getting sts credentials. - AWSExternalID string - // SessionTags is a list of AWS STS session tags. - SessionTags map[string]string - // Integration is the Integration name to use to generate credentials. - // If empty, it will use ambient credentials - Integration string } // Check checks signing context parameters. -func (sc *SigningCtx) Check(clock clockwork.Clock) error { +func (sc *SigningCtx) Check() error { switch { + case sc.Credentials == nil: + return trace.BadParameter("missing AWS credentials") case sc.SigningName == "": return trace.BadParameter("missing AWS signing name") case sc.SigningRegion == "": return trace.BadParameter("missing AWS signing region") - case sc.SessionName == "": - return trace.BadParameter("missing AWS session name") - case sc.AWSRoleArn == "": - return trace.BadParameter("missing AWS Role ARN") - case sc.Expiry.Before(clock.Now()): - return trace.BadParameter("AWS SigV4 expiry has already expired") - } - _, err := ParseRoleARN(sc.AWSRoleArn) - if err != nil { - return trace.Wrap(err) } return nil } @@ -151,11 +72,11 @@ func (sc *SigningCtx) Check(clock clockwork.Clock) error { // Not that for endpoint resolving the https://github.com/aws/aws-sdk-go/aws/endpoints/endpoints.go // package is used and when Amazon releases a new API the dependency update is needed. // 5. Sign HTTP request. -func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (*http.Request, error) { +func SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (*http.Request, error) { if signCtx == nil { return nil, trace.BadParameter("missing signing context") } - if err := signCtx.Check(s.Clock); err != nil { + if err := signCtx.Check(); err != nil { return nil, trace.Wrap(err) } payload, err := utils.GetAndReplaceRequestBody(req) @@ -173,11 +94,8 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig // 100-continue" headers without being signed, otherwise the Athena service // would reject the requests. unsignedHeaders := removeUnsignedHeaders(reqCopy) - signer, err := s.newSigner(ctx, signCtx) - if err != nil { - return nil, trace.Wrap(err) - } - _, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, s.Clock.Now()) + signer := NewSignerV2(signCtx.Credentials, signCtx.SigningName) + _, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, signCtx.Clock.Now()) if err != nil { return nil, trace.Wrap(err) } @@ -187,40 +105,6 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig return reqCopy, nil } -// TODO(gabrielcorado): once all service callers are updated to use -// AWSConfigProvider, make it required and remove session provider and -// credentials getter fallback. -func (s *SigningService) newSigner(ctx context.Context, signCtx *SigningCtx) (*v4.Signer, error) { - if s.AWSConfigProvider != nil { - awsCfg, err := s.AWSConfigProvider.GetConfig(ctx, signCtx.SigningRegion, - awsconfig.WithAssumeRole(signCtx.BaseAWSRoleARN, signCtx.BaseAWSExternalID), - awsconfig.WithAssumeRole(signCtx.AWSRoleArn, signCtx.AWSExternalID), - awsconfig.WithCredentialsMaybeIntegration(signCtx.Integration), - ) - if err != nil { - return nil, trace.Wrap(err) - } - return NewSignerV2(awsCfg.Credentials, signCtx.SigningName), nil - } - - session, err := s.SessionProvider(ctx, signCtx.SigningRegion, signCtx.Integration) - if err != nil { - return nil, trace.Wrap(err) - } - credentials, err := s.CredentialsGetter.Get(ctx, GetCredentialsRequest{ - Provider: session, - Expiry: signCtx.Expiry, - SessionName: signCtx.SessionName, - RoleARN: signCtx.AWSRoleArn, - ExternalID: signCtx.AWSExternalID, - Tags: signCtx.SessionTags, - }) - if err != nil { - return nil, trace.Wrap(err) - } - return NewSigner(credentials, signCtx.SigningName), nil -} - // removeUnsignedHeaders removes and returns header keys that are not included in SigV4 SignedHeaders. // If the request is not already signed, then no headers are removed. func removeUnsignedHeaders(reqCopy *http.Request) []string {