Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate AWS OpenSearch requests signing credential generation to AWS SDK V2 #51149

Merged
merged 8 commits into from
Jan 28, 2025
1 change: 0 additions & 1 deletion lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ func TestMain(m *testing.M) {
modules.SetInsecureTestMode(true)
registerTestSnowflakeEngine()
registerTestElasticsearchEngine()
registerTestOpenSearchEngine()
registerTestSQLServerEngine()
registerTestDynamoDBEngine()
os.Exit(m.Run())
Expand Down
28 changes: 8 additions & 20 deletions lib/srv/db/opensearch/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ import (
"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/srv/db/common"
Expand All @@ -61,8 +60,6 @@ type Engine struct {
clientConn net.Conn
// sessionCtx is current session context.
sessionCtx *common.Session
// CredentialsGetter is used to obtain STS credentials.
CredentialsGetter libaws.CredentialsGetter
}

// InitializeConnection initializes the engine with the client connection.
Expand Down Expand Up @@ -137,19 +134,9 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error
e.Audit.OnSessionStart(e.Context, e.sessionCtx, err)
return trace.Wrap(err)
}

meta := e.sessionCtx.Database.GetAWS()
awsSession, err := e.CloudClients.GetAWSSession(ctx, meta.Region,
cloud.WithAssumeRoleFromAWSMeta(meta),
cloud.WithAmbientCredentials(),
)
if err != nil {
return trace.Wrap(err)
}
signer, err := libaws.NewSigningService(libaws.SigningServiceConfig{
Clock: e.Clock,
SessionProvider: libaws.StaticAWSSessionProvider(awsSession),
CredentialsGetter: e.CredentialsGetter,
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
AWSConfigProvider: e.AWSConfigProvider,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -246,12 +233,13 @@ func (e *Engine) getSignedRequest(signer *libaws.SigningService, reqCopy *http.R
}

signCtx := &libaws.SigningCtx{
SigningName: opensearchservice.EndpointsID,
SigningRegion: e.sessionCtx.Database.GetAWS().Region,
Expiry: e.sessionCtx.Identity.Expires,
SessionName: e.sessionCtx.Identity.Username,
AWSRoleArn: roleArn,
AWSExternalID: e.sessionCtx.Database.GetAWS().ExternalID,
SigningName: opensearchservice.EndpointsID,
SigningRegion: e.sessionCtx.Database.GetAWS().Region,
Expiry: e.sessionCtx.Identity.Expires,
SessionName: e.sessionCtx.Identity.Username,
BaseAWSRoleARN: e.sessionCtx.Database.GetAWS().AssumeRoleARN,
AWSRoleArn: roleArn,
AWSExternalID: e.sessionCtx.Database.GetAWS().ExternalID,
}

signedReq, err := signer.SignRequest(e.Context, reqCopy, signCtx)
Expand Down
29 changes: 11 additions & 18 deletions lib/srv/db/opensearch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,24 @@ import (
"net"
"testing"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/cloud/mocks"
"github.com/gravitational/teleport/lib/defaults"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/opensearch"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)

func registerTestOpenSearchEngine() {
common.RegisterEngine(newTestOpenSearchEngine, defaults.ProtocolOpenSearch)
}

func newTestOpenSearchEngine(ec common.EngineConfig) common.Engine {
return &opensearch.Engine{
EngineConfig: ec,
// inject mock AWS credentials.
CredentialsGetter: awsutils.NewStaticCredentialsGetter(
credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION"),
),
}
}

func TestAccessOpenSearch(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t, withOpenSearch("OpenSearch"))
testCtx := setupTestContext(ctx, t)
testCtx.server = testCtx.setupDatabaseServer(ctx, t, agentParams{
Databases: []types.Database{withOpenSearch("OpenSearch")(t, ctx, testCtx)},
AWSConfigProvider: &mocks.AWSConfigProvider{},
})
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
go testCtx.startHandlingConnections()

tests := []struct {
Expand Down Expand Up @@ -151,7 +140,11 @@ func TestAccessOpenSearch(t *testing.T) {

func TestAuditOpenSearch(t *testing.T) {
ctx := context.Background()
testCtx := setupTestContext(ctx, t, withOpenSearch("OpenSearch"))
testCtx := setupTestContext(ctx, t)
testCtx.server = testCtx.setupDatabaseServer(ctx, t, agentParams{
Databases: []types.Database{withOpenSearch("OpenSearch")(t, ctx, testCtx)},
AWSConfigProvider: &mocks.AWSConfigProvider{},
})
go testCtx.startHandlingConnections()

testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{"admin"}, []string{types.Wildcard})
Expand Down
77 changes: 54 additions & 23 deletions lib/utils/aws/signing.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import (
"net/http"
"time"

v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -56,27 +58,32 @@ type SigningServiceConfig struct {
Clock clockwork.Clock
// CredentialsGetter is used to obtain STS credentials.
CredentialsGetter CredentialsGetter
// AWSConfigProvider is a provider for AWS configs.
AWSConfigProvider awsconfig.Provider
}

// CheckAndSetDefaults validates the SigningServiceConfig config.
func (s *SigningServiceConfig) CheckAndSetDefaults() error {
if s.Clock == nil {
s.Clock = clockwork.NewRealClock()
}
if s.SessionProvider == nil {
return trace.BadParameter("session provider is required")
}
if s.CredentialsGetter == nil {
// Use cachedCredentialsGetter by default. cachedCredentialsGetter
// caches the credentials for one minute.
cachedGetter, err := NewCachedCredentialsGetter(CachedCredentialsGetterConfig{
Clock: s.Clock,
})
if err != nil {
return trace.Wrap(err)
}

s.CredentialsGetter = cachedGetter
if s.AWSConfigProvider == nil {
if s.SessionProvider == nil {
return trace.BadParameter("session provider or config provider is required")
}
if s.CredentialsGetter == nil {
// Use cachedCredentialsGetter by default. cachedCredentialsGetter
// caches the credentials for one minute.
cachedGetter, err := NewCachedCredentialsGetter(CachedCredentialsGetterConfig{
Clock: s.Clock,
})
if err != nil {
return trace.Wrap(err)
}

s.CredentialsGetter = cachedGetter
}
}
return nil
}
Expand All @@ -91,7 +98,10 @@ type SigningCtx struct {
Expiry time.Time
// SessionName is role session name of AWS credentials used to sign requests.
SessionName string
// AWSRoleArn is the AWS ARN of the role to assume for signing requests.
// BaseAWSRoleARN is the AWS ARN of the role as a base to the assumed roles.
BaseAWSRoleARN string
// AWSRoleArn is the AWS ARN of the role to assume for signing requests,
// chained with BaseAWSRoleARN.
AWSRoleArn string
// AWSExternalID is an optional external ID used when getting sts credentials.
AWSExternalID string
Expand Down Expand Up @@ -161,6 +171,35 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig
// 100-continue" headers without being signed, otherwise the Athena service
// would reject the requests.
unsignedHeaders := removeUnsignedHeaders(reqCopy)
signer, err := s.newSigner(ctx, signCtx)
if err != nil {
return nil, trace.Wrap(err)
}
_, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, s.Clock.Now())
if err != nil {
return nil, trace.Wrap(err)
}

// copy removed headers back to the request after signing it, but don't copy the old Authorization header.
copyHeaders(reqCopy, req, utils.RemoveFromSlice(unsignedHeaders, "Authorization"))
return reqCopy, nil
}

// TODO(gabrielcorado): once all service callers are updated to use
// AWSConfigProvider, make it required and remove session provider and
// credentials getter fallback.
func (s *SigningService) newSigner(ctx context.Context, signCtx *SigningCtx) (*v4.Signer, error) {
Comment on lines +190 to +193
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I like how this makes it easy to migrate each dependent service in small PRs

if s.AWSConfigProvider != nil {
awsCfg, err := s.AWSConfigProvider.GetConfig(ctx, signCtx.SigningRegion,
awsconfig.WithAssumeRole(signCtx.BaseAWSRoleARN, signCtx.AWSExternalID),
awsconfig.WithAssumeRole(signCtx.AWSRoleArn, signCtx.AWSExternalID),
awsconfig.WithCredentialsMaybeIntegration(signCtx.Integration),
)
if err != nil {
return nil, trace.Wrap(err)
}
return NewSignerV2(awsCfg.Credentials, signCtx.SigningName), nil
}

session, err := s.SessionProvider(ctx, signCtx.SigningRegion, signCtx.Integration)
if err != nil {
Expand All @@ -177,15 +216,7 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig
if err != nil {
return nil, trace.Wrap(err)
}
signer := NewSigner(credentials, signCtx.SigningName)
_, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, s.Clock.Now())
if err != nil {
return nil, trace.Wrap(err)
}

// copy removed headers back to the request after signing it, but don't copy the old Authorization header.
copyHeaders(reqCopy, req, utils.RemoveFromSlice(unsignedHeaders, "Authorization"))
return reqCopy, nil
return NewSigner(credentials, signCtx.SigningName), nil
}

// removeUnsignedHeaders removes and returns header keys that are not included in SigV4 SignedHeaders.
Expand Down
Loading