Skip to content

Commit

Permalink
Merge pull request #6 from twharmon/services
Browse files Browse the repository at this point in the history
Add STS and SSM services
  • Loading branch information
twharmon authored Apr 8, 2022
2 parents e44583c + 4b7a1d4 commit 39f4cda
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 0 deletions.
52 changes: 52 additions & 0 deletions serviceprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/ses"
"github.com/aws/aws-sdk-go/service/ses/sesiface"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

type awsServiceProvider struct {
Expand All @@ -17,19 +21,25 @@ type awsServiceProvider struct {
dynamodb *dynamodb.DynamoDB
ses *ses.SES
s3 *s3.S3
sts *sts.STS
ssm *ssm.SSM
}

type AWSServiceProviderConfig struct {
Default *aws.Config
DynamoDB *aws.Config
SES *aws.Config
S3 *aws.Config
STS *aws.Config
SSM *aws.Config
}

type AWSServiceProvider interface {
DynamoDB() dynamodbiface.DynamoDBAPI
SES() sesiface.SESAPI
S3() s3iface.S3API
STS() stsiface.STSAPI
SSM() ssmiface.SSMAPI
}

func (sp *awsServiceProvider) loadSession() {
Expand Down Expand Up @@ -91,6 +101,38 @@ func (sp *awsServiceProvider) loadS3() {
sp.s3 = s3.New(sp.session)
}

func (sp *awsServiceProvider) loadSSM() {
if sp.ssm != nil {
return
}
sp.loadSession()
if sp.config.SSM != nil {
sp.ssm = ssm.New(sp.session, sp.config.SSM)
return
}
if sp.config.Default != nil {
sp.ssm = ssm.New(sp.session, sp.config.Default)
return
}
sp.ssm = ssm.New(sp.session)
}

func (sp *awsServiceProvider) loadSTS() {
if sp.sts != nil {
return
}
sp.loadSession()
if sp.config.STS != nil {
sp.sts = sts.New(sp.session, sp.config.STS)
return
}
if sp.config.Default != nil {
sp.sts = sts.New(sp.session, sp.config.Default)
return
}
sp.sts = sts.New(sp.session)
}

func (sp *awsServiceProvider) DynamoDB() dynamodbiface.DynamoDBAPI {
sp.loadDynamoDB()
return sp.dynamodb
Expand All @@ -105,3 +147,13 @@ func (sp *awsServiceProvider) S3() s3iface.S3API {
sp.loadS3()
return sp.s3
}

func (sp *awsServiceProvider) STS() stsiface.STSAPI {
sp.loadSTS()
return sp.sts
}

func (sp *awsServiceProvider) SSM() ssmiface.SSMAPI {
sp.loadSSM()
return sp.ssm
}
128 changes: 128 additions & 0 deletions serviceprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/ses"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/sts"
)

func TestServiceProviderDBWithConfig(t *testing.T) {
Expand Down Expand Up @@ -208,3 +210,129 @@ func TestServiceProviderSessionCached(t *testing.T) {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSSMWithConfig(t *testing.T) {
sp := &awsServiceProvider{
config: &AWSServiceProviderConfig{
SSM: &aws.Config{
Region: aws.String("bar"),
},
},
}
svc := sp.SSM()
ssm, ok := svc.(*ssm.SSM)
if !ok {
t.Fatalf("expected ok")
}
if ssm == nil {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSSMDefaultSPConfig(t *testing.T) {
sp := &awsServiceProvider{config: &AWSServiceProviderConfig{
Default: &aws.Config{
Region: aws.String("bar"),
},
}}
svc := sp.SSM()
ssm, ok := svc.(*ssm.SSM)
if !ok {
t.Fatalf("expected ok")
}
if ssm == nil {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSSMNoSPConfig(t *testing.T) {
c := &handlerContext{
sp: &awsServiceProvider{config: &AWSServiceProviderConfig{}},
}
svc := c.AWS().SSM()
ssm, ok := svc.(*ssm.SSM)
if !ok {
t.Fatalf("expected ok")
}
if ssm == nil {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSSMCached(t *testing.T) {
c := &handlerContext{
sp: &awsServiceProvider{config: &AWSServiceProviderConfig{}},
}
svc := c.AWS().SSM()
svc = c.AWS().SSM()
ssm, ok := svc.(*ssm.SSM)
if !ok {
t.Fatalf("expected ok")
}
if ssm == nil {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSTSWithConfig(t *testing.T) {
sp := &awsServiceProvider{
config: &AWSServiceProviderConfig{
STS: &aws.Config{
Region: aws.String("bar"),
},
},
}
svc := sp.STS()
sts, ok := svc.(*sts.STS)
if !ok {
t.Fatalf("expected ok")
}
if sts == nil {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSTSDefaultSPConfig(t *testing.T) {
sp := &awsServiceProvider{config: &AWSServiceProviderConfig{
Default: &aws.Config{
Region: aws.String("bar"),
},
}}
svc := sp.STS()
sts, ok := svc.(*sts.STS)
if !ok {
t.Fatalf("expected ok")
}
if sts == nil {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSTSNoSPConfig(t *testing.T) {
c := &handlerContext{
sp: &awsServiceProvider{config: &AWSServiceProviderConfig{}},
}
svc := c.AWS().STS()
sts, ok := svc.(*sts.STS)
if !ok {
t.Fatalf("expected ok")
}
if sts == nil {
t.Fatalf("expected not nil")
}
}

func TestServiceProviderSTSCached(t *testing.T) {
c := &handlerContext{
sp: &awsServiceProvider{config: &AWSServiceProviderConfig{}},
}
svc := c.AWS().STS()
svc = c.AWS().STS()
sts, ok := svc.(*sts.STS)
if !ok {
t.Fatalf("expected ok")
}
if sts == nil {
t.Fatalf("expected not nil")
}
}

0 comments on commit 39f4cda

Please sign in to comment.