diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index cf85bd913be74..73a9aa2405467 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -107,7 +107,6 @@ func TestMain(m *testing.M) { modules.SetInsecureTestMode(true) registerTestSnowflakeEngine() registerTestElasticsearchEngine() - registerTestOpenSearchEngine() registerTestSQLServerEngine() registerTestDynamoDBEngine() os.Exit(m.Run()) diff --git a/lib/srv/db/opensearch/engine.go b/lib/srv/db/opensearch/engine.go index cfb016230fc31..354610bc6da36 100644 --- a/lib/srv/db/opensearch/engine.go +++ b/lib/srv/db/opensearch/engine.go @@ -35,7 +35,6 @@ import ( "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv/db/common" @@ -61,8 +60,6 @@ type Engine struct { clientConn net.Conn // sessionCtx is current session context. sessionCtx *common.Session - // CredentialsGetter is used to obtain STS credentials. - CredentialsGetter libaws.CredentialsGetter } // InitializeConnection initializes the engine with the client connection. @@ -137,19 +134,9 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error e.Audit.OnSessionStart(e.Context, e.sessionCtx, err) return trace.Wrap(err) } - - meta := e.sessionCtx.Database.GetAWS() - awsSession, err := e.CloudClients.GetAWSSession(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), - ) - if err != nil { - return trace.Wrap(err) - } signer, err := libaws.NewSigningService(libaws.SigningServiceConfig{ Clock: e.Clock, - SessionProvider: libaws.StaticAWSSessionProvider(awsSession), - CredentialsGetter: e.CredentialsGetter, + AWSConfigProvider: e.AWSConfigProvider, }) if err != nil { return trace.Wrap(err) @@ -245,13 +232,22 @@ func (e *Engine) getSignedRequest(signer *libaws.SigningService, reqCopy *http.R return nil, trace.Wrap(err) } + meta := e.sessionCtx.Database.GetAWS() signCtx := &libaws.SigningCtx{ - SigningName: opensearchservice.EndpointsID, - SigningRegion: e.sessionCtx.Database.GetAWS().Region, - Expiry: e.sessionCtx.Identity.Expires, - SessionName: e.sessionCtx.Identity.Username, - AWSRoleArn: roleArn, - AWSExternalID: e.sessionCtx.Database.GetAWS().ExternalID, + SigningName: opensearchservice.EndpointsID, + SigningRegion: meta.Region, + Expiry: e.sessionCtx.Identity.Expires, + SessionName: e.sessionCtx.Identity.Username, + BaseAWSRoleARN: meta.AssumeRoleARN, + BaseAWSExternalID: meta.ExternalID, + AWSRoleArn: roleArn, + // OpenSearch uses meta.ExternalID for both the base role and the + // assumed role. + AWSExternalID: meta.ExternalID, + } + + if meta.AssumeRoleARN == "" { + signCtx.AWSExternalID = meta.ExternalID } signedReq, err := signer.SignRequest(e.Context, reqCopy, signCtx) diff --git a/lib/srv/db/opensearch_test.go b/lib/srv/db/opensearch_test.go index 96d2c3237ef66..190256dc827da 100644 --- a/lib/srv/db/opensearch_test.go +++ b/lib/srv/db/opensearch_test.go @@ -25,35 +25,24 @@ import ( "net" "testing" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/defaults" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/opensearch" - awsutils "github.com/gravitational/teleport/lib/utils/aws" ) -func registerTestOpenSearchEngine() { - common.RegisterEngine(newTestOpenSearchEngine, defaults.ProtocolOpenSearch) -} - -func newTestOpenSearchEngine(ec common.EngineConfig) common.Engine { - return &opensearch.Engine{ - EngineConfig: ec, - // inject mock AWS credentials. - CredentialsGetter: awsutils.NewStaticCredentialsGetter( - credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION"), - ), - } -} - func TestAccessOpenSearch(t *testing.T) { ctx := context.Background() - testCtx := setupTestContext(ctx, t, withOpenSearch("OpenSearch")) + testCtx := setupTestContext(ctx, t) + testCtx.server = testCtx.setupDatabaseServer(ctx, t, agentParams{ + Databases: []types.Database{withOpenSearch("OpenSearch")(t, ctx, testCtx)}, + AWSConfigProvider: &mocks.AWSConfigProvider{}, + }) go testCtx.startHandlingConnections() tests := []struct { @@ -151,7 +140,11 @@ func TestAccessOpenSearch(t *testing.T) { func TestAuditOpenSearch(t *testing.T) { ctx := context.Background() - testCtx := setupTestContext(ctx, t, withOpenSearch("OpenSearch")) + testCtx := setupTestContext(ctx, t) + testCtx.server = testCtx.setupDatabaseServer(ctx, t, agentParams{ + Databases: []types.Database{withOpenSearch("OpenSearch")(t, ctx, testCtx)}, + AWSConfigProvider: &mocks.AWSConfigProvider{}, + }) go testCtx.startHandlingConnections() testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{"admin"}, []string{types.Wildcard}) diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go index 32a1d55f046f1..6549265aed676 100644 --- a/lib/utils/aws/signing.go +++ b/lib/utils/aws/signing.go @@ -25,9 +25,11 @@ import ( "net/http" "time" + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/utils" ) @@ -56,6 +58,8 @@ type SigningServiceConfig struct { Clock clockwork.Clock // CredentialsGetter is used to obtain STS credentials. CredentialsGetter CredentialsGetter + // AWSConfigProvider is a provider for AWS configs. + AWSConfigProvider awsconfig.Provider } // CheckAndSetDefaults validates the SigningServiceConfig config. @@ -63,20 +67,23 @@ func (s *SigningServiceConfig) CheckAndSetDefaults() error { if s.Clock == nil { s.Clock = clockwork.NewRealClock() } - if s.SessionProvider == nil { - return trace.BadParameter("session provider is required") - } - if s.CredentialsGetter == nil { - // Use cachedCredentialsGetter by default. cachedCredentialsGetter - // caches the credentials for one minute. - cachedGetter, err := NewCachedCredentialsGetter(CachedCredentialsGetterConfig{ - Clock: s.Clock, - }) - if err != nil { - return trace.Wrap(err) - } - s.CredentialsGetter = cachedGetter + if s.AWSConfigProvider == nil { + if s.SessionProvider == nil { + return trace.BadParameter("session provider or config provider is required") + } + if s.CredentialsGetter == nil { + // Use cachedCredentialsGetter by default. cachedCredentialsGetter + // caches the credentials for one minute. + cachedGetter, err := NewCachedCredentialsGetter(CachedCredentialsGetterConfig{ + Clock: s.Clock, + }) + if err != nil { + return trace.Wrap(err) + } + + s.CredentialsGetter = cachedGetter + } } return nil } @@ -91,7 +98,12 @@ type SigningCtx struct { Expiry time.Time // SessionName is role session name of AWS credentials used to sign requests. SessionName string - // AWSRoleArn is the AWS ARN of the role to assume for signing requests. + // BaseAWSRoleARN is the AWS ARN of the role as a base to the assumed roles. + BaseAWSRoleARN string + // BaseAWSRoleARN is an optional external ID used on base assumed role. + BaseAWSExternalID string + // AWSRoleArn is the AWS ARN of the role to assume for signing requests, + // chained with BaseAWSRoleARN. AWSRoleArn string // AWSExternalID is an optional external ID used when getting sts credentials. AWSExternalID string @@ -161,6 +173,35 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig // 100-continue" headers without being signed, otherwise the Athena service // would reject the requests. unsignedHeaders := removeUnsignedHeaders(reqCopy) + signer, err := s.newSigner(ctx, signCtx) + if err != nil { + return nil, trace.Wrap(err) + } + _, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, s.Clock.Now()) + if err != nil { + return nil, trace.Wrap(err) + } + + // copy removed headers back to the request after signing it, but don't copy the old Authorization header. + copyHeaders(reqCopy, req, utils.RemoveFromSlice(unsignedHeaders, "Authorization")) + return reqCopy, nil +} + +// TODO(gabrielcorado): once all service callers are updated to use +// AWSConfigProvider, make it required and remove session provider and +// credentials getter fallback. +func (s *SigningService) newSigner(ctx context.Context, signCtx *SigningCtx) (*v4.Signer, error) { + if s.AWSConfigProvider != nil { + awsCfg, err := s.AWSConfigProvider.GetConfig(ctx, signCtx.SigningRegion, + awsconfig.WithAssumeRole(signCtx.BaseAWSRoleARN, signCtx.BaseAWSExternalID), + awsconfig.WithAssumeRole(signCtx.AWSRoleArn, signCtx.AWSExternalID), + awsconfig.WithCredentialsMaybeIntegration(signCtx.Integration), + ) + if err != nil { + return nil, trace.Wrap(err) + } + return NewSignerV2(awsCfg.Credentials, signCtx.SigningName), nil + } session, err := s.SessionProvider(ctx, signCtx.SigningRegion, signCtx.Integration) if err != nil { @@ -177,15 +218,7 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig if err != nil { return nil, trace.Wrap(err) } - signer := NewSigner(credentials, signCtx.SigningName) - _, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, s.Clock.Now()) - if err != nil { - return nil, trace.Wrap(err) - } - - // copy removed headers back to the request after signing it, but don't copy the old Authorization header. - copyHeaders(reqCopy, req, utils.RemoveFromSlice(unsignedHeaders, "Authorization")) - return reqCopy, nil + return NewSigner(credentials, signCtx.SigningName), nil } // removeUnsignedHeaders removes and returns header keys that are not included in SigV4 SignedHeaders.