Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify lib/utils/aws #51627

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions integration/proxy/proxy_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
Expand Down Expand Up @@ -605,7 +605,7 @@ func mustParseURL(t *testing.T, rawURL string) *url.URL {
type fakeSTSClient struct {
accountID string
arn string
credentials *credentials.Credentials
credentials aws.CredentialsProvider
}

func (f fakeSTSClient) Do(req *http.Request) (*http.Response, error) {
Expand Down Expand Up @@ -640,10 +640,10 @@ func mustCreateIAMJoinProvisionToken(t *testing.T, name, awsAccountID, allowedAR
return provisionToken
}

func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token string, credentials *credentials.Credentials) {
func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token string, credentials aws.CredentialsProvider) {
t.Helper()

cred, err := credentials.Get()
cred, err := credentials.Retrieve(context.Background())
require.NoError(t, err)

t.Setenv("AWS_ACCESS_KEY_ID", cred.AccessKeyID)
Expand Down
4 changes: 2 additions & 2 deletions integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
Expand Down Expand Up @@ -1739,7 +1739,7 @@ func TestALPNSNIProxyGRPCInsecure(t *testing.T) {

nodeAccount := "123456789012"
nodeRoleARN := "arn:aws:iam::123456789012:role/test"
nodeCredentials := credentials.NewStaticCredentials("FAKE_ID", "FAKE_KEY", "FAKE_TOKEN")
nodeCredentials := credentials.NewStaticCredentialsProvider("FAKE_ID", "FAKE_KEY", "FAKE_TOKEN")
provisionToken := mustCreateIAMJoinProvisionToken(t, "iam-join-token", nodeAccount, nodeRoleARN)

suite := newSuite(t,
Expand Down
65 changes: 64 additions & 1 deletion lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@ package awsconfig

import (
"context"
"crypto/sha1"
"encoding/hex"
"fmt"
"log/slog"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/aws/smithy-go/tracing/smithyoteltracing"
"github.com/gravitational/trace"
"go.opentelemetry.io/otel"
Expand Down Expand Up @@ -69,7 +73,12 @@ type AssumeRole struct {
// RoleARN is the ARN of the role to assume.
RoleARN string `json:"role_arn"`
// ExternalID is an optional ID to include when assuming the role.
ExternalID string `json:"external_id"`
ExternalID string `json:"external_id,omitempty"`
// SessionName is an optional session name to use when assuming the role.
SessionName string `json:"session_name,omitempty"`
// Tags is a list of STS session tags to pass when assuming the role.
// https://docs.aws.amazon.com/IAM/latest/UserGuide/id_session-tags.html
Tags map[string]string `json:"tags,omitempty"`
}

// options is a struct of additional options for assuming an AWS role
Expand Down Expand Up @@ -153,6 +162,18 @@ func WithAssumeRole(roleARN, externalID string) OptionsFn {
}
}

// WithDetailedAssumeRole configures options needed for assuming an AWS role,
// including optional details like session name, duration, and tags.
func WithDetailedAssumeRole(ar AssumeRole) OptionsFn {
return func(options *options) {
if ar.RoleARN == "" {
// ignore empty role ARN for caller convenience.
return
}
options.assumeRoles = append(options.assumeRoles, ar)
}
}

// WithRetryer sets a custom retryer for the config.
func WithRetryer(retryer func() aws.Retryer) OptionsFn {
return func(options *options) {
Expand Down Expand Up @@ -302,6 +323,13 @@ func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient
if role.ExternalID != "" {
aro.ExternalID = aws.String(role.ExternalID)
}
aro.RoleSessionName = maybeHashRoleSessionName(role.SessionName)
for k, v := range role.Tags {
aro.Tags = append(aro.Tags, ststypes.Tag{
Key: aws.String(k),
Value: aws.String(v),
})
}
})
}

Expand Down Expand Up @@ -342,3 +370,38 @@ func (p *integrationCredentialsProvider) Retrieve(ctx context.Context) (aws.Cred
).Retrieve(ctx)
return cred, trace.Wrap(err)
}

// maybeHashRoleSessionName truncates the role session name and adds a hash
// when the original role session name is greater than AWS character limit
// (64).
func maybeHashRoleSessionName(roleSessionName string) (ret string) {
// maxRoleSessionNameLength is the maximum length of the role session name
// used by the AssumeRole call.
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_iam-quotas.html
const maxRoleSessionNameLength = 64
if len(roleSessionName) <= maxRoleSessionNameLength {
return roleSessionName
}

const hashLen = 16
hash := sha1.New()
hash.Write([]byte(roleSessionName))
hex := hex.EncodeToString(hash.Sum(nil))[:hashLen]

// "1" for the delimiter.
keepPrefixIndex := maxRoleSessionNameLength - len(hex) - 1

// Sanity check. This should never happen since hash length and
// MaxRoleSessionNameLength are both constant.
if keepPrefixIndex < 0 {
keepPrefixIndex = 0
}

ret = fmt.Sprintf("%s-%s", roleSessionName[:keepPrefixIndex], hex)
slog.DebugContext(context.Background(),
"AWS role session name is too long. Using a hash instead.",
"hashed", ret,
"original", roleSessionName,
)
return ret
}
4 changes: 2 additions & 2 deletions lib/cloud/awsconfig/awsconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,12 @@ func testGetConfigIntegration(t *testing.T, provider Provider) {
func TestNewCacheKey(t *testing.T) {
roleChain := []AssumeRole{
{RoleARN: "roleA"},
{RoleARN: "roleB", ExternalID: "abc123"},
{RoleARN: "roleB", ExternalID: "abc123", SessionName: "alice", Tags: map[string]string{"AKey": "AValue"}},
}
got, err := newCacheKey("integration-name", roleChain...)
require.NoError(t, err)
want := strings.TrimSpace(`
{"integration":"integration-name","role_chain":[{"role_arn":"roleA","external_id":""},{"role_arn":"roleB","external_id":"abc123"}]}
{"integration":"integration-name","role_chain":[{"role_arn":"roleA"},{"role_arn":"roleB","external_id":"abc123","session_name":"alice","tags":{"AKey":"AValue"}}]}
`)
require.Equal(t, want, got)
}
Expand Down
13 changes: 13 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ import (
pgrepl "github.com/gravitational/teleport/lib/client/db/postgres/repl"
dbrepl "github.com/gravitational/teleport/lib/client/db/repl"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/cloud/gcp"
"github.com/gravitational/teleport/lib/cloud/imds"
awsimds "github.com/gravitational/teleport/lib/cloud/imds/aws"
Expand Down Expand Up @@ -4889,6 +4890,12 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

return awsoidc.NewSessionV1(ctx, conn.Client, region, integration)
}
awsConfigProvider, err := awsconfig.NewCache(awsconfig.WithDefaults(
awsconfig.WithOIDCIntegrationClient(conn.Client),
))
if err != nil {
return trace.Wrap(err, "unable to create AWS config provider cache")
}

connectionsHandler, err := app.NewConnectionsHandler(process.GracefulExitContext(), &app.ConnectionsHandlerConfig{
Clock: process.Clock,
Expand All @@ -4903,6 +4910,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
CipherSuites: cfg.CipherSuites,
ServiceComponent: teleport.ComponentWebProxy,
AWSSessionProvider: awsSessionGetter,
AWSConfigProvider: awsConfigProvider,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -6218,6 +6226,10 @@ func (process *TeleportProcess) initApps() {
return trace.Wrap(err)
}

awsConfigProvider, err := awsconfig.NewCache()
if err != nil {
return trace.Wrap(err, "unable to create AWS config provider cache")
}
connectionsHandler, err := app.NewConnectionsHandler(process.ExitContext(), &app.ConnectionsHandlerConfig{
Clock: process.Config.Clock,
DataDir: process.Config.DataDir,
Expand All @@ -6232,6 +6244,7 @@ func (process *TeleportProcess) initApps() {
ServiceComponent: teleport.ComponentApp,
Logger: logger,
AWSSessionProvider: awsutils.SessionProviderUsingAmbientCredentials(),
AWSConfigProvider: awsConfigProvider,
})
if err != nil {
return trace.Wrap(err)
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re
}

func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool {
if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil {
if err := awsutils.VerifyAWSSignature(req, m.AWSCredentialsProvider); err != nil {
m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
Expand All @@ -149,7 +149,7 @@ func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter,
aws.ToString(assumedRole.Credentials.SessionToken),
)

if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil {
if err := awsutils.VerifyAWSSignature(req, credentials); err != nil {
m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
Expand Down
30 changes: 20 additions & 10 deletions lib/srv/app/aws/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/httplib/reverseproxy"
Expand All @@ -53,12 +54,12 @@ type signerHandler struct {

// SignerHandlerConfig is the awsSignerHandler configuration.
type SignerHandlerConfig struct {
// AWSConfigProvider provides [aws.Config] for AWS SDK service clients.
AWSConfigProvider awsconfig.Provider
// Log is a logger for the handler.
Log *slog.Logger
// RoundTripper is an http.RoundTripper instance used for requests.
RoundTripper http.RoundTripper
// SigningService is used to sign requests before forwarding them.
*awsutils.SigningService
// Clock is used to override time in tests.
Clock clockwork.Clock
// MaxHTTPRequestBodySize is the limit on how big a request body can be.
Expand All @@ -67,8 +68,8 @@ type SignerHandlerConfig struct {

// CheckAndSetDefaults validates the AwsSignerHandlerConfig.
func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error {
if cfg.SigningService == nil {
return trace.BadParameter("missing SigningService")
if cfg.AWSConfigProvider == nil {
return trace.BadParameter("aws config provider missing")
}
if cfg.RoundTripper == nil {
tr, err := defaults.Transport()
Expand Down Expand Up @@ -165,15 +166,24 @@ func (s *signerHandler) serveCommonRequest(sessCtx *common.SessionContext, w htt
return trace.Wrap(err)
}

signedReq, err := s.SignRequest(s.closeContext, unsignedReq,
awsCfg, err := s.AWSConfigProvider.GetConfig(s.closeContext, re.SigningRegion,
awsconfig.WithDetailedAssumeRole(awsconfig.AssumeRole{
RoleARN: sessCtx.Identity.RouteToApp.AWSRoleARN,
ExternalID: sessCtx.App.GetAWSExternalID(),
SessionName: sessCtx.Identity.Username,
}),
awsconfig.WithCredentialsMaybeIntegration(sessCtx.App.GetIntegration()),
)
if err != nil {
return trace.Wrap(err)
}

signedReq, err := awsutils.SignRequest(s.closeContext, unsignedReq,
&awsutils.SigningCtx{
Clock: s.Clock,
Credentials: awsCfg.Credentials,
SigningName: re.SigningName,
SigningRegion: re.SigningRegion,
Expiry: sessCtx.Identity.Expires,
SessionName: sessCtx.Identity.Username,
AWSRoleArn: sessCtx.Identity.RouteToApp.AWSRoleARN,
AWSExternalID: sessCtx.App.GetAWSExternalID(),
Integration: sessCtx.App.GetIntegration(),
})
if err != nil {
return trace.Wrap(err)
Expand Down
Loading
Loading