Skip to content
This repository has been archived by the owner on May 18, 2021. It is now read-only.

Commit

Permalink
fix: make GetRoleARN support GovCloud (#220)
Browse files Browse the repository at this point in the history
GovCloud support was added in #197 / b4e7839 and #204 / 9787f11.
Taking the role as an env var was added in #208 / 396d453, which added
GetRoleARN which ignored region and, thus, broke GovCloud.
Both of these were released as v0.23.0, so no version of aws-okta
supported GovCloud.

Separately, GetRoleARN was split into a lib function and a Provider
function in #218 / e13ae0f, which left behind a code duplication TODO
and an unused Provider function.

This adds Provider.GetRoleARNWithRegion.
  • Loading branch information
raylu authored and nickatsegment committed Oct 1, 2019
1 parent 00cceb3 commit 02acf14
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 26 deletions.
8 changes: 1 addition & 7 deletions cmd/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/99designs/keyring"
"github.com/alessio/shellescape"
"github.com/aws/aws-sdk-go/aws/credentials"
analytics "github.com/segmentio/analytics-go"
"github.com/segmentio/aws-okta/lib"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -99,12 +98,7 @@ func envRun(cmd *cobra.Command, args []string) error {
return err
}

// TODO: deduplicate this code from exec.go
roleARN, err := lib.GetRoleARN(credentials.Value{
AccessKeyID: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
})
roleARN, err := p.GetRoleARNWithRegion(creds)
if err != nil {
return err
}
Expand Down
7 changes: 1 addition & 6 deletions cmd/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cmd
import (
"fmt"

"github.com/aws/aws-sdk-go/aws/credentials"
log "github.com/sirupsen/logrus"

"os"
Expand Down Expand Up @@ -200,11 +199,7 @@ func execRun(cmd *cobra.Command, args []string) error {
return err
}

roleARN, err := lib.GetRoleARN(credentials.Value{
AccessKeyID: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
})
roleARN, err := p.GetRoleARNWithRegion(creds)
if err != nil {
return err
}
Expand Down
42 changes: 29 additions & 13 deletions lib/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,25 +325,41 @@ func (p *Provider) roleSessionName() string {
return fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}

// GetRoleARN uses temporary credentials to call AWS's get-caller-identity and
// returns the assumed role's ARN
func (p *Provider) GetRoleARNWithRegion(creds credentials.Value) (string, error) {
config := aws.Config{Credentials: credentials.NewStaticCredentials(
creds.AccessKeyID,
creds.SecretAccessKey,
creds.SessionToken,
)}
if region := p.profiles[sourceProfile(p.profile, p.profiles)]["region"]; region != "" {
config.WithRegion(region)
}
client := sts.New(aws_session.New(&config))

indentity, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
log.Errorf("Error getting caller identity: %s", err.Error())
return "", err
}
arn := *indentity.Arn
return arn, nil
}

// GetRoleARN uses p to establish temporary credentials then calls
// lib.GetRoleARN with them to get the role's ARN
// lib.GetRoleARN with them to get the role's ARN. It is unused internally and
// is kept for backwards compatability.
func (p *Provider) GetRoleARN() (string, error) {
creds, err := p.getSamlSessionCreds()
if err != nil {
return "", err
}
return GetRoleARN(credentials.Value{
AccessKeyID: *creds.AccessKeyId,
SecretAccessKey: *creds.SecretAccessKey,
SessionToken: *creds.SessionToken,
})
}

// GetRoleARN makes a call to AWS to get-caller-identity and returns the
// assumed role's name and ARN.
func GetRoleARN(c credentials.Value) (string, error) {
client := sts.New(aws_session.New(&aws.Config{Credentials: credentials.NewStaticCredentialsFromCreds(c)}))

client := sts.New(aws_session.New(&aws.Config{Credentials: credentials.NewStaticCredentials(
*creds.AccessKeyId,
*creds.SecretAccessKey,
*creds.SessionToken,
)}))
indentity, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
log.Errorf("Error getting caller identity: %s", err.Error())
Expand Down

0 comments on commit 02acf14

Please sign in to comment.