Skip to content

Commit

Permalink
feat: AWS Temporary credential support for SQS eventsource (#2092)
Browse files Browse the repository at this point in the history
* support aws sts credentials for sqs eventsource

Signed-off-by: Harshdeep Singh <[email protected]>
  • Loading branch information
harshdeep-23 authored and whynowy committed Sep 10, 2022
1 parent a1df80c commit 78044d6
Show file tree
Hide file tree
Showing 13 changed files with 576 additions and 420 deletions.
14 changes: 14 additions & 0 deletions api/event-source.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions api/event-source.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions api/jsonschema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2464,6 +2464,10 @@
"$ref": "#/definitions/io.k8s.api.core.v1.SecretKeySelector",
"description": "SecretKey refers K8s secret containing aws secret key"
},
"sessionToken": {
"$ref": "#/definitions/io.k8s.api.core.v1.SecretKeySelector",
"description": "SessionToken refers to K8s secret containing AWS temporary credentials(STS) session token"
},
"waitTimeSeconds": {
"description": "WaitTimeSeconds is The duration (in seconds) for which the call waits for a message to arrive in the queue before returning.",
"format": "int64",
Expand Down
4 changes: 4 additions & 0 deletions api/openapi-spec/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 13 additions & 3 deletions eventsources/common/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func GetAWSCredFromEnvironment(access *corev1.SecretKeySelector, secret *corev1.
}

// GetAWSCredFromVolume reads credential stored in mounted secret volume.
func GetAWSCredFromVolume(access *corev1.SecretKeySelector, secret *corev1.SecretKeySelector) (*credentials.Credentials, error) {
func GetAWSCredFromVolume(access *corev1.SecretKeySelector, secret *corev1.SecretKeySelector, sessionToken *corev1.SecretKeySelector) (*credentials.Credentials, error) {
accessKey, err := common.GetSecretFromVolume(access)
if err != nil {
return nil, errors.Wrap(err, "can not find access key")
Expand All @@ -53,9 +53,19 @@ func GetAWSCredFromVolume(access *corev1.SecretKeySelector, secret *corev1.Secre
if err != nil {
return nil, errors.Wrap(err, "can not find secret key")
}

var token string
if sessionToken != nil {
token, err = common.GetSecretFromVolume(sessionToken)
if err != nil {
return nil, errors.Wrap(err, "can not find session token")
}
}

return credentials.NewStaticCredentialsFromCreds(credentials.Value{
AccessKeyID: accessKey,
SecretAccessKey: secretKey,
SessionToken: token,
}), nil
}

Expand Down Expand Up @@ -97,7 +107,7 @@ func CreateAWSSessionWithCredsInEnv(region string, roleARN string, accessKey *co
}

// CreateAWSSessionWithCredsInVolume based on credentials in mounted volumes, return a aws session
func CreateAWSSessionWithCredsInVolume(region string, roleARN string, accessKey *corev1.SecretKeySelector, secretKey *corev1.SecretKeySelector) (*session.Session, error) {
func CreateAWSSessionWithCredsInVolume(region string, roleARN string, accessKey *corev1.SecretKeySelector, secretKey *corev1.SecretKeySelector, sessionToken *corev1.SecretKeySelector) (*session.Session, error) {
if roleARN != "" {
return GetAWSAssumeRoleCreds(roleARN, region)
}
Expand All @@ -106,7 +116,7 @@ func CreateAWSSessionWithCredsInVolume(region string, roleARN string, accessKey
return GetAWSSessionWithoutCreds(region)
}

creds, err := GetAWSCredFromVolume(accessKey, secretKey)
creds, err := GetAWSCredFromVolume(accessKey, secretKey, sessionToken)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion eventsources/sources/awssns/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (router *Router) PostActivate() error {

snsEventSource := router.eventSource

awsSession, err := commonaws.CreateAWSSessionWithCredsInVolume(snsEventSource.Region, snsEventSource.RoleARN, snsEventSource.AccessKey, snsEventSource.SecretKey)
awsSession, err := commonaws.CreateAWSSessionWithCredsInVolume(snsEventSource.Region, snsEventSource.RoleARN, snsEventSource.AccessKey, snsEventSource.SecretKey, nil)
if err != nil {
return err
}
Expand Down
61 changes: 49 additions & 12 deletions eventsources/sources/awssqs/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
sqslib "github.com/aws/aws-sdk-go/service/sqs"
"github.com/pkg/errors"
Expand Down Expand Up @@ -68,19 +69,9 @@ func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byt
defer sources.Recover(el.GetEventName())

sqsEventSource := &el.SQSEventSource
var awsSession *session.Session
awsSession, err := awscommon.CreateAWSSessionWithCredsInVolume(sqsEventSource.Region, sqsEventSource.RoleARN, sqsEventSource.AccessKey, sqsEventSource.SecretKey)
sqsClient, err := el.createSqsClient()
if err != nil {
log.Errorw("Error creating AWS credentials", zap.Error(err))
return errors.Wrapf(err, "failed to create aws session for %s", el.GetEventName())
}

var sqsClient *sqslib.SQS

if sqsEventSource.Endpoint == "" {
sqsClient = sqslib.New(awsSession)
} else {
sqsClient = sqslib.New(awsSession, &aws.Config{Endpoint: &sqsEventSource.Endpoint, Region: &sqsEventSource.Region})
return err
}

log.Info("fetching queue url...")
Expand Down Expand Up @@ -112,6 +103,17 @@ func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byt
messages, err := fetchMessages(ctx, sqsClient, *queueURL.QueueUrl, 10, sqsEventSource.WaitTimeSeconds)
if err != nil {
log.Errorw("failed to get messages from SQS", zap.Error(err))
awsError, ok := err.(awserr.Error)
if ok && awsError.Code() == "ExpiredToken" && el.SQSEventSource.SessionToken != nil {
log.Info("credentials expired, reading credentials again")
newSqsClient, err := el.createSqsClient()
if err != nil {
log.Errorw("Error creating SQS client", zap.Error(err))
} else if newSqsClient != nil {
sqsClient = newSqsClient
}
}

time.Sleep(2 * time.Second)
continue
}
Expand All @@ -123,6 +125,16 @@ func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byt
})
if err != nil {
log.Errorw("Failed to delete message", zap.Error(err))
awsError, ok := err.(awserr.Error)
if ok && awsError.Code() == "ExpiredToken" && el.SQSEventSource.SessionToken != nil {
log.Info("credentials expired, reading credentials again")
newSqsClient, err := el.createSqsClient()
if err != nil {
log.Errorw("Error creating SQS client", zap.Error(err))
} else if newSqsClient != nil {
sqsClient = newSqsClient
}
}
}
}, log)
}
Expand Down Expand Up @@ -185,3 +197,28 @@ func fetchMessages(ctx context.Context, q *sqslib.SQS, url string, maxSize, wait
}
return result.Messages, nil
}

func (el *EventListener) createAWSSession() (*session.Session, error) {
sqsEventSource := &el.SQSEventSource
awsSession, err := awscommon.CreateAWSSessionWithCredsInVolume(sqsEventSource.Region, sqsEventSource.RoleARN, sqsEventSource.AccessKey, sqsEventSource.SecretKey, sqsEventSource.SessionToken)
if err != nil {
return nil, errors.Wrapf(err, "failed to create aws session for %s", el.GetEventName())
}
return awsSession, nil
}

func (el *EventListener) createSqsClient() (*sqslib.SQS, error) {
awsSession, err := el.createAWSSession()
if err != nil {
return nil, err
}

var sqsClient *sqslib.SQS
if el.SQSEventSource.Endpoint == "" {
sqsClient = sqslib.New(awsSession)
} else {
sqsClient = sqslib.New(awsSession, &aws.Config{Endpoint: &el.SQSEventSource.Endpoint, Region: &el.SQSEventSource.Region})
}

return sqsClient, nil
}
Loading

0 comments on commit 78044d6

Please sign in to comment.