Skip to content

Commit

Permalink
credential process support (#384)
Browse files Browse the repository at this point in the history
* refactor logging to support file based logging to ease debugging on Windows
* add remaining credential lifetime to log
* refactor parameters and function names
* fix spoiled stdout
* unify help message, check return codes
  • Loading branch information
bitte-ein-bit authored Apr 30, 2024
1 parent 90e1f1f commit cfa6f2b
Show file tree
Hide file tree
Showing 30 changed files with 585 additions and 234 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ steps_output.txt
bottle_output.txt

dist/
clisso.yaml
35 changes: 33 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,42 @@ with the app's name as the [profile name][10]. You can use the temporary credent
the profile name as an argument to the AWS CLI (`--profile my-profile`), by setting the
`AWS_PROFILE` environment variable or by configuring any AWS SDK to use the profile.

To save the credentials to a custom file, use the `-w` flag.
To save the credentials to a custom file, use the `--output` flag with a custom path. For example:

To print the credentials to the shell instead of storing them in a file, use the `-s` flag. This
clisso get my-app --output /path/to/credentials

To print the credentials to the shell instead of storing them in a file, use the `--output environment` flag. This
will output shell commands which can be pasted in any shell to use the credentials.

### Running as `credential_process`

AWS CLI v2 introduced the `credential_process` feature which allows you to use an external command to obtain temporal credentials.
Clisso can be used as a `credential_process` command by setting the `--output credential_process` flag. For example:

clisso get my-app --output credential_process

You can use this by adding the following to your `~/.aws/credentials` file:

```ini
[my-app]
credential_process = clisso get my-app --output credential_process
```

The AWS SDK does not cache any credentials obtained using `credential_process`. This means that every time you use the profile, Clisso will be called to obtain new credentials. If you want to cache the credentials, you can use the `--cache` flag. For example:

```ini
[my-app]
credential_process = clisso get my-app --output credential_process --cache
```

Alternatively you can set it in the `~/.clisso.yaml` file:

```yaml
global:
cache:
enable: true
```
### Storing the password in the key chain
> WARNING: Storing the password without having MFA enabled is a security risk. It allows anyone
Expand Down
84 changes: 68 additions & 16 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (
"io"
"time"

"github.com/allcloud-io/clisso/log"
"github.com/go-ini/ini"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)

// Credentials represents a set of temporary credentials received from AWS STS
Expand All @@ -32,11 +33,11 @@ type Profile struct {

const expireKey = "aws_expiration"

// WriteToFile writes credentials to an AWS CLI credentials file
// OutputFile writes credentials to an AWS CLI credentials file
// (https://docs.aws.amazon.com/cli/latest/userguide/cli-config-files.html). In addition, this
// function removes expired temporary credentials from the credentials file.
func WriteToFile(c *Credentials, filename string, section string) error {
log.WithFields(log.Fields{
func OutputFile(c *Credentials, filename string, section string) error {
log.Log.WithFields(logrus.Fields{
"filename": filename,
"section": section,
}).Debug("Writing credentials to file")
Expand Down Expand Up @@ -65,29 +66,29 @@ func WriteToFile(c *Credentials, filename string, section string) error {
// Remove expired credentials.
for _, s := range cfg.Sections() {
if !s.HasKey(expireKey) {
log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey)
log.Log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey)
continue
}
v, err := s.Key(expireKey).TimeFormat(time.RFC3339)
if err != nil {
log.Warnf("Cannot parse date (%v) in profile %s: %s",
log.Log.Warnf("Cannot parse date (%v) in profile %s: %s",
s.Key(expireKey), s.Name(), err)
continue
}
if time.Now().UTC().Unix() > v.Unix() {
log.Tracef("Removing expired credentials for profile %s", s.Name())
log.Log.Tracef("Removing expired credentials for profile %s", s.Name())
cfg.DeleteSection(s.Name())
continue
}
log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339))
log.Log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339))
}

return cfg.SaveTo(filename)
}

// WriteToShell writes (prints) credentials to stdout. If windows is true, Windows syntax will be
// used.
func WriteToShell(c *Credentials, windows bool, w io.Writer) {
// OutputEnvironment writes (prints) credentials to stdout. If windows is true, Windows syntax will be
// used. The output can be used to set environment variables.
func OutputEnvironment(c *Credentials, windows bool, w io.Writer) {
fmt.Print("Please paste the following in your shell:")
if windows {
fmt.Fprintf(
Expand All @@ -108,21 +109,37 @@ func WriteToShell(c *Credentials, windows bool, w io.Writer) {
}
}

// GetValidCredentials returns profiles which have a aws_expiration key but are not yet expired.
func GetValidCredentials(filename string) ([]Profile, error) {
// OutputCredentialProcess writes (prints) credentials to stdout in the format required by the AWS CLI.
// The output can be used to set the credential_process option in the AWS CLI configuration file.
func OutputCredentialProcess(c *Credentials, w io.Writer) {
log.Log.Trace("Writing credentials to stdout in credential_process format")
log.Log.Infof("Credentials expire at %s, in %d Minutes", c.Expiration.Format(time.RFC3339), int(c.Expiration.Sub(time.Now().UTC()).Minutes()))
fmt.Fprintf(
w,
`{ "Version": 1, "AccessKeyId": %q, "SecretAccessKey": %q, "SessionToken": %q, "Expiration": %q }`,
c.AccessKeyID,
c.SecretAccessKey,
c.SessionToken,
// Time must be in ISO8601 format
c.Expiration.Format(time.RFC3339),
)
}

// GetValidProfiles returns profiles which have a aws_expiration key but are not yet expired.
func GetValidProfiles(filename string) ([]Profile, error) {
var profiles []Profile
log.WithField("filename", filename).Trace("Loading AWS credentials file")
log.Log.WithField("filename", filename).Trace("Loading AWS credentials file")
cfg, err := ini.LooseLoad(filename)
if err != nil {
err = fmt.Errorf("%s contains errors: %w", filename, err)
log.WithError(err).Trace("Failed to load AWS credentials file")
log.Log.WithError(err).Trace("Failed to load AWS credentials file")
return nil, err
}
for _, s := range cfg.Sections() {
if s.HasKey(expireKey) {
v, err := s.Key(expireKey).TimeFormat(time.RFC3339)
if err != nil {
log.Warnf("Cannot parse date (%v) in section %s: %s",
log.Log.Warnf("Cannot parse date (%v) in section %s: %s",
s.Key(expireKey), s.Name(), err)
continue
}
Expand All @@ -136,3 +153,38 @@ func GetValidCredentials(filename string) ([]Profile, error) {
}
return profiles, nil
}

// GetValidCredentials returns credentials which have a aws_expiration key but are not yet expired.
// returns a map of profile name to credentials
func GetValidCredentials(filename string) (map[string]Credentials, error) {
credentials := make(map[string]Credentials)
log.Log.WithField("filename", filename).Trace("Loading credentials file")
cfg, err := ini.LooseLoad(filename)
if err != nil {
err = fmt.Errorf("%s contains errors: %w", filename, err)
log.Log.WithError(err).Trace("Failed to load credentials file")
return nil, err
}
for _, s := range cfg.Sections() {
if s.HasKey(expireKey) {
v, err := s.Key(expireKey).TimeFormat(time.RFC3339)
if err != nil {
log.Log.Warnf("Cannot parse date (%v) in section %s: %s",
s.Key(expireKey), s.Name(), err)
continue
}

if time.Now().UTC().Unix() < v.Unix() {
credential := Credentials{
AccessKeyID: s.Key("aws_access_key_id").String(),
SecretAccessKey: s.Key("aws_secret_access_key").String(),
SessionToken: s.Key("aws_session_token").String(),
Expiration: v,
}
credentials[s.Name()] = credential
}

}
}
return credentials, nil
}
25 changes: 14 additions & 11 deletions aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ import (
"testing"
"time"

"github.com/allcloud-io/clisso/log"
"github.com/go-ini/ini"
)

var _ = log.NewLogger("panic", "", false)

func TestWriteToFile(t *testing.T) {
id := "expiredkey"
sec := "expiredsecret"
Expand All @@ -32,7 +35,7 @@ func TestWriteToFile(t *testing.T) {
p := "expiredprofile"

// Write credentials
err := WriteToFile(&c, fn, p)
err := OutputFile(&c, fn, p)
if err != nil {
t.Fatal("Could not write credentials to file: ", err)
}
Expand All @@ -55,7 +58,7 @@ func TestWriteToFile(t *testing.T) {
p = "testprofile"

// Write credentials
err = WriteToFile(&c, fn, p)
err = OutputFile(&c, fn, p)
if err != nil {
t.Fatal("Could not write credentials to file: ", err)
}
Expand Down Expand Up @@ -120,7 +123,7 @@ func TestWriteToFile(t *testing.T) {
}
}

func TestGetValidCredentials(t *testing.T) {
func TestGetValidProfiles(t *testing.T) {
fn := "test_creds.txt"

id := "testkey"
Expand All @@ -139,7 +142,7 @@ func TestGetValidCredentials(t *testing.T) {
p := "expired"

// Write credentials
err := WriteToFile(&c, fn, p)
err := OutputFile(&c, fn, p)
if err != nil {
t.Fatal("Could not write credentials to file: ", err)
}
Expand All @@ -149,7 +152,7 @@ func TestGetValidCredentials(t *testing.T) {
p = "valid"

// Write credentials
err = WriteToFile(&c, fn, p)
err = OutputFile(&c, fn, p)
if err != nil {
t.Fatal("Could not write credentials to file: ", err)
}
Expand Down Expand Up @@ -179,7 +182,7 @@ func TestGetValidCredentials(t *testing.T) {

time.Sleep(time.Duration(1) * time.Second)

profiles, err := GetValidCredentials(fn)
profiles, err := GetValidProfiles(fn)
if err != nil {
t.Fatal("Failed to get NonExpiredCredentials")
}
Expand All @@ -202,13 +205,13 @@ func TestGetValidCredentials(t *testing.T) {
t.Fatalf("Could not remove file %v during cleanup", fn)
}

_, err = GetValidCredentials(fn)
_, err = GetValidProfiles(fn)
if err != nil {
t.Fatal("Function did crash on missing file")
}
}

func TestWriteToShellUnix(t *testing.T) {
func TestOutputUnixEnvironment(t *testing.T) {
id := "testkey"
sec := "testsecret"
tok := "testtoken"
Expand All @@ -222,7 +225,7 @@ func TestWriteToShellUnix(t *testing.T) {
}
var b bytes.Buffer

WriteToShell(&c, false, &b)
OutputEnvironment(&c, false, &b)

got := b.String()
want := fmt.Sprintf(
Expand All @@ -237,7 +240,7 @@ func TestWriteToShellUnix(t *testing.T) {
}
}

func TestWriteToShellWindows(t *testing.T) {
func TestOutputWindowsEnvironment(t *testing.T) {
id := "testkey"
sec := "testsecret"
tok := "testtoken"
Expand All @@ -251,7 +254,7 @@ func TestWriteToShellWindows(t *testing.T) {
}
var b bytes.Buffer

WriteToShell(&c, true, &b)
OutputEnvironment(&c, true, &b)

got := b.String()
want := fmt.Sprintf(
Expand Down
21 changes: 11 additions & 10 deletions aws/sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ import (
"regexp"
"strings"

"github.com/allcloud-io/clisso/log"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go"
"github.com/icza/gog"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)

const (
Expand All @@ -41,13 +42,13 @@ const (
// returns a specific error message to indicate that. In this case we return a custom error to the
// caller to allow special handling such as retrying with a lower duration.
func AssumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, duration int32) (*Credentials, error) {
log.WithFields(log.Fields{
log.Log.WithFields(logrus.Fields{
"PrincipalArn": PrincipalArn,
"RoleArn": RoleArn,
"awsRegion": awsRegion,
"duration": duration,
}).Debug("Assuming role with SAML assertion")
log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion")
log.Log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion")
creds, err := assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion, duration)
if err != nil {
// Check if API error returned by AWS
Expand Down Expand Up @@ -80,15 +81,15 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura

config, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion), config.WithSharedConfigProfile("default"))
if err != nil {
log.WithError(err).Debug("Error loading default configuration")
log.Log.WithError(err).Debug("Error loading default configuration")
return nil, err
}
log.WithField("awsRegion", config.Region).Trace("Loaded default config")
log.Log.WithField("awsRegion", config.Region).Trace("Loaded default config")

// If we request credentials for China we need to provide a Chinese region
idp := regexp.MustCompile(`^arn:aws-cn:iam::\d+:saml-provider\/\S+$`)
if idp.MatchString(PrincipalArn) && !strings.HasPrefix(awsRegion, "cn-") {
log.Trace("Setting region to cn-north-1")
log.Log.Trace("Setting region to cn-north-1")
config.Region = "cn-north-1"
}
svc := sts.NewFromConfig(config, func(o *sts.Options) {
Expand All @@ -98,7 +99,7 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura

aResp, err := svc.AssumeRoleWithSAML(ctx, &input)
if err != nil {
log.WithError(err).Debug("Error assuming role with SAML assertion")
log.Log.WithError(err).Debug("Error assuming role with SAML assertion")
return nil, err
}

Expand All @@ -107,10 +108,10 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura
sessionToken := *aResp.Credentials.SessionToken
expiration := *aResp.Credentials.Expiration

log.WithFields(log.Fields{
log.Log.WithFields(logrus.Fields{
"AccessKeyID": keyID,
"SecretAccessKey": gog.If(log.GetLevel() == log.TraceLevel, secretKey, "<redacted>"),
"SessionToken": gog.If(log.GetLevel() == log.TraceLevel, sessionToken, "<redacted>"),
"SecretAccessKey": gog.If(log.Log.GetLevel() == logrus.TraceLevel, secretKey, "<redacted>"),
"SessionToken": gog.If(log.Log.GetLevel() == logrus.TraceLevel, sessionToken, "<redacted>"),
"Expiration": expiration,
}).Debug("Got temporary credentials")

Expand Down
Loading

0 comments on commit cfa6f2b

Please sign in to comment.