From 8f96c1ce6f402d9c4187889619f4f41895073851 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Fri, 1 Dec 2023 17:29:42 +0100 Subject: [PATCH 01/12] draft version with file cache --- aws/aws.go | 59 +++++++++++++++++++-- aws/aws_test.go | 10 ++-- cmd/get.go | 102 ++++++++++++++++++++++++++++++++----- cmd/status.go | 4 +- okta/get.go | 9 ++-- onelogin/get.go | 14 +++-- spinner/spinner.go | 4 +- spinner/spinner_unix.go | 4 +- spinner/spinner_windows.go | 2 +- 9 files changed, 171 insertions(+), 37 deletions(-) diff --git a/aws/aws.go b/aws/aws.go index 67e84dd..b5a6d14 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -85,9 +85,9 @@ func WriteToFile(c *Credentials, filename string, section string) error { return cfg.SaveTo(filename) } -// WriteToShell writes (prints) credentials to stdout. If windows is true, Windows syntax will be -// used. -func WriteToShell(c *Credentials, windows bool, w io.Writer) { +// WriteToStdOutAsEnvironment writes (prints) credentials to stdout. If windows is true, Windows syntax will be +// used. The output can be used to set environment variables. +func WriteToStdOutAsEnvironment(c *Credentials, windows bool, w io.Writer) { fmt.Print("Please paste the following in your shell:") if windows { fmt.Fprintf( @@ -108,8 +108,22 @@ func WriteToShell(c *Credentials, windows bool, w io.Writer) { } } -// GetValidCredentials returns profiles which have a aws_expiration key but are not yet expired. -func GetValidCredentials(filename string) ([]Profile, error) { +// WriteCredentialsToStdOutAsCredentialProcess writes (prints) credentials to stdout in the format required by the AWS CLI. +// The output can be used to set the credential_process option in the AWS CLI configuration file. +func WriteCredentialsToStdOutAsCredentialProcess(c *Credentials, w io.Writer) { + fmt.Fprintf( + w, + `{ "Version": 1, "AccessKeyId": %q, "SecretAccessKey": %q, "SessionToken": %q, "Expiration": %q }`, + c.AccessKeyID, + c.SecretAccessKey, + c.SessionToken, + // Time must be in ISO8601 format + c.Expiration.Format(time.RFC3339), + ) +} + +// GetValidProfiles returns profiles which have a aws_expiration key but are not yet expired. +func GetValidProfiles(filename string) ([]Profile, error) { var profiles []Profile log.WithField("filename", filename).Trace("Loading AWS credentials file") cfg, err := ini.LooseLoad(filename) @@ -136,3 +150,38 @@ func GetValidCredentials(filename string) ([]Profile, error) { } return profiles, nil } + +// GetValidCredentials returns credentials which have a aws_expiration key but are not yet expired. +// returns a map of profile name to credentials +func GetValidCredentials(filename string) (map[string]Credentials, error) { + credentials := make(map[string]Credentials) + log.WithField("filename", filename).Trace("Loading credentials file") + cfg, err := ini.LooseLoad(filename) + if err != nil { + err = fmt.Errorf("%s contains errors: %w", filename, err) + log.WithError(err).Trace("Failed to load credentials file") + return nil, err + } + for _, s := range cfg.Sections() { + if s.HasKey(expireKey) { + v, err := s.Key(expireKey).TimeFormat(time.RFC3339) + if err != nil { + log.Warnf("Cannot parse date (%v) in section %s: %s", + s.Key(expireKey), s.Name(), err) + continue + } + + if time.Now().UTC().Unix() < v.Unix() { + credential := Credentials{ + AccessKeyID: s.Key("aws_access_key_id").String(), + SecretAccessKey: s.Key("aws_secret_access_key").String(), + SessionToken: s.Key("aws_session_token").String(), + Expiration: v, + } + credentials[s.Name()] = credential + } + + } + } + return credentials, nil +} \ No newline at end of file diff --git a/aws/aws_test.go b/aws/aws_test.go index 8a9b1af..64a03f2 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -120,7 +120,7 @@ func TestWriteToFile(t *testing.T) { } } -func TestGetValidCredentials(t *testing.T) { +func TestGetValidProfiles(t *testing.T) { fn := "test_creds.txt" id := "testkey" @@ -179,7 +179,7 @@ func TestGetValidCredentials(t *testing.T) { time.Sleep(time.Duration(1) * time.Second) - profiles, err := GetValidCredentials(fn) + profiles, err := GetValidProfiles(fn) if err != nil { t.Fatal("Failed to get NonExpiredCredentials") } @@ -202,7 +202,7 @@ func TestGetValidCredentials(t *testing.T) { t.Fatalf("Could not remove file %v during cleanup", fn) } - _, err = GetValidCredentials(fn) + _, err = GetValidProfiles(fn) if err != nil { t.Fatal("Function did crash on missing file") } @@ -222,7 +222,7 @@ func TestWriteToShellUnix(t *testing.T) { } var b bytes.Buffer - WriteToShell(&c, false, &b) + WriteToStdOutAsEnvironment(&c, false, &b) got := b.String() want := fmt.Sprintf( @@ -251,7 +251,7 @@ func TestWriteToShellWindows(t *testing.T) { } var b bytes.Buffer - WriteToShell(&c, true, &b) + WriteToStdOutAsEnvironment(&c, true, &b) got := b.String() want := fmt.Sprintf( diff --git a/cmd/get.go b/cmd/get.go index 2bda408..d8d098b 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -22,18 +22,40 @@ import ( ) var printToShell bool +var printToCredentialProcess bool +var cacheCredentials bool var writeToFile string +var cacheToFile string func init() { RootCmd.AddCommand(cmdGet) cmdGet.Flags().BoolVarP( - &printToShell, "shell", "s", false, "Print credentials to shell", + &printToShell, "shell", "s", false, "Print credentials to shell to be sourced as environment variables", ) + cmdGet.Flags().BoolVarP( + &printToCredentialProcess, "credential_process", "p", false, "Print credentials in the format used by the AWS CLI credential_process", + ) + cmdGet.Flags().BoolVarP( + &cacheCredentials, "cache-credentials", "", false, + "Should credentials be cached to a file if run as a credential_process (default: false)", + ) + err := viper.BindPFlag("global.cache-credentials", cmdGet.Flags().Lookup("cache-credentials")) + if err != nil { + log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) + } + cmdGet.Flags().StringVarP( + &cacheToFile, "cache-file", "", "~/.aws/credentials-cache", + "Write credentials to this file instead of the default (~/.aws/credentials-cache)", + ) + err = viper.BindPFlag("global.credentials-cache-path", cmdGet.Flags().Lookup("cache-file")) + if err != nil { + log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) + } cmdGet.Flags().StringVarP( - &writeToFile, "write-to-file", "w", "", + &writeToFile, "write-to-file", "w", "~/.aws/credentials", "Write credentials to this file instead of the default ($HOME/.aws/credentials)", ) - err := viper.BindPFlag("global.credentials-path", cmdGet.Flags().Lookup("write-to-file")) + err = viper.BindPFlag("global.credentials-path", cmdGet.Flags().Lookup("write-to-file")) if err != nil { log.Fatalf("Error binding flag global.credentials-path: %v", err) } @@ -41,32 +63,45 @@ func init() { // processCredentials prints the given Credentials to a file and/or to the shell. func processCredentials(creds *aws.Credentials, app string) error { + if printToCredentialProcess && printToShell { + return fmt.Errorf("cannot use both --shell and --credential-process") + } if printToShell { // Print credentials to shell using the correct syntax for the OS. - aws.WriteToShell(creds, runtime.GOOS == "windows", os.Stdout) + aws.WriteToStdOutAsEnvironment(creds, runtime.GOOS == "windows", os.Stdout) + return nil + } + + var viperPathString string + if printToCredentialProcess { + aws.WriteCredentialsToStdOutAsCredentialProcess(creds, os.Stdout) + if cacheCredentials { + viperPathString = "global.credentials-cache-path" + } } else { - path, err := homedir.Expand(viper.GetString("global.credentials-path")) + viperPathString = "global.credentials-path" + } + if viperPathString != "" { + path, err := homedir.Expand(viper.GetString(viperPathString)) if err != nil { return fmt.Errorf("expanding config file path: %v", err) } - // Create the `global.credentials-path` directory if it doesn't exist. credsFileParentDir := filepath.Dir(path) if _, err := os.Stat(credsFileParentDir); os.IsNotExist(err) { log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) - - err = os.MkdirAll(credsFileParentDir, 0755) + // Lets default to strict permissions on the folders we create + err = os.MkdirAll(credsFileParentDir, 0700) if err != nil { return fmt.Errorf("creating credentials directory: %v", err) } } - if err = aws.WriteToFile(creds, path, app); err != nil { + if err := aws.WriteToFile(creds, path, app); err != nil { return fmt.Errorf("writing credentials to file: %v", err) } log.Printf("Credentials written successfully to '%s'", path) } - return nil } @@ -100,6 +135,31 @@ func awsRegion(app string) string { return "aws-global" } +func getCachedCredential(app string) (*aws.Credentials, error) { + // get the credentials from the cache file + credentialFile, err := homedir.Expand(viper.GetString("global.credentials-cache-path")) + if err != nil { + log.Fatalf("Failed to expand home: %s", err) + } + + profiles, err := aws.GetValidCredentials(credentialFile) + if err != nil { + log.Fatalf("Failed to retrieve non-expired credentials: %s", err) + } + + if len(profiles) == 0 { + return nil, nil + } + + // find the app we are looking for + for k, p := range profiles { + if k == app { + return &p, nil + } + } + return nil, fmt.Errorf("no valid credentials found for app '%s'", app) +} + var cmdGet = &cobra.Command{ Use: "get", Short: "Get temporary credentials for an app", @@ -141,8 +201,22 @@ If no app is specified, the selected app (if configured) will be assumed.`, awsRegion := awsRegion(app) + if printToCredentialProcess && cacheCredentials { + log.Trace("Using --cache-credentials and --credential-process") + // we need to cache the credentials to a file and return valid credentials instead of constantly hitting the IdPs + credential, err := getCachedCredential(app) + if err != nil { + log.WithError(err).Debugf("Failed to find cached credentials for app '%s'", app) + } + if credential != nil { + aws.WriteCredentialsToStdOutAsCredentialProcess(credential, os.Stdout) + return + } + } + + interactive := !printToShell && !printToCredentialProcess if pType == "onelogin" { - creds, err := onelogin.Get(app, provider, pArn, awsRegion, duration) + creds, err := onelogin.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { log.Fatal("Could not get temporary credentials: ", err) } @@ -152,7 +226,7 @@ If no app is specified, the selected app (if configured) will be assumed.`, log.Fatalf("Error processing credentials: %v", err) } } else if pType == "okta" { - creds, err := okta.Get(app, provider, pArn, awsRegion, duration) + creds, err := okta.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { log.Fatal("Could not get temporary credentials: ", err) } @@ -164,6 +238,8 @@ If no app is specified, the selected app (if configured) will be assumed.`, } else { log.Fatalf("Unsupported identity provider type '%s' for app '%s'", pType, app) } - printStatus() + if interactive { + printStatus() + } }, } diff --git a/cmd/status.go b/cmd/status.go index 75fa790..d6c07dd 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -42,12 +42,12 @@ var cmdStatus = &cobra.Command{ } func printStatus() { - configfile, err := homedir.Expand(viper.GetString("global.credentials-path")) + credentialFile, err := homedir.Expand(viper.GetString("global.credentials-path")) if err != nil { log.Fatalf("Failed to expand home: %s", err) } - profiles, err := aws.GetValidCredentials(configfile) + profiles, err := aws.GetValidProfiles(credentialFile) if err != nil { log.Fatalf("Failed to retrieve non-expired credentials: %s", err) } diff --git a/okta/get.go b/okta/get.go index aa8835d..234d6c4 100644 --- a/okta/get.go +++ b/okta/get.go @@ -31,13 +31,14 @@ var ( ) // Get gets temporary credentials for the given app. -func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credentials, error) { +func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { log.WithFields(log.Fields{ "app": app, "provider": provider, "pArn": pArn, "awsRegion": awsRegion, "duration": duration, + "interactive": interactive, }).Trace("Getting credentials from Okta") // Get provider config p, err := config.GetOktaProvider(provider) @@ -71,7 +72,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential } // Initialize spinner - var s = spinner.New() + var s = spinner.New(interactive) // Get session token s.Start() @@ -114,7 +115,9 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential // https://developer.okta.com/docs/api/resources/authn/#verify-push-factor // Keep polling authentication transactions with WAITING result until the challenge // completes or expires. - fmt.Println("Please approve request on Okta Verify app") + if interactive { + fmt.Println("Please approve request on Okta Verify app") + } s.Start() vfResp, err = c.VerifyFactor(&VerifyFactorParams{ FactorID: factor.ID, diff --git a/onelogin/get.go b/onelogin/get.go index f5d4439..34245e7 100644 --- a/onelogin/get.go +++ b/onelogin/get.go @@ -8,6 +8,7 @@ package onelogin import ( "errors" "fmt" + "os" "strconv" "strings" "time" @@ -40,13 +41,14 @@ var ( // Get gets temporary credentials for the given app. // TODO Move AWS logic outside this function. -func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credentials, error) { +func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { log.WithFields(log.Fields{ "app": app, "provider": provider, "pArn": pArn, "awsRegion": awsRegion, "duration": duration, + "interactive": interactive, }).Trace("Getting credentials from OneLogin") // Read config p, err := config.GetOneLoginProvider(provider) @@ -65,7 +67,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential } // Initialize spinner - var s = spinner.New() + var s = spinner.New(interactive) // Get OneLogin access token s.Start() @@ -155,8 +157,12 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential } pMfa.DoNotNotify = true - - fmt.Println(rMfa.Message) + if interactive { + fmt.Println(rMfa.Message) + } else { + // print to StdErr if we're not interactive + fmt.Fprintln(os.Stderr, rMfa.Message) + } timeout := MFAPushTimeout s.Start() diff --git a/spinner/spinner.go b/spinner/spinner.go index af98f8b..eb16708 100644 --- a/spinner/spinner.go +++ b/spinner/spinner.go @@ -8,8 +8,8 @@ package spinner // This is a wrapper around spinner to disable unsupported operation systems transparently until upstream is fixed. // See https://github.com/briandowns/spinner/issues/52 -func New() SpinnerWrapper { - return new() +func New(interactive bool) SpinnerWrapper { + return new(interactive) } // SpinnerWrapper is used to abstract a spinner so that it can be conveniently disabled on terminals which don't support it. diff --git a/spinner/spinner_unix.go b/spinner/spinner_unix.go index cd07068..ea9e39e 100644 --- a/spinner/spinner_unix.go +++ b/spinner/spinner_unix.go @@ -16,8 +16,8 @@ import ( log "github.com/sirupsen/logrus" ) -func new() SpinnerWrapper { - if log.GetLevel() >= log.DebugLevel { +func new(interactive bool) SpinnerWrapper { + if log.GetLevel() >= log.DebugLevel || !interactive { return &noopSpinner{} } return spinner.New(spinner.CharSets[14], 50*time.Millisecond) diff --git a/spinner/spinner_windows.go b/spinner/spinner_windows.go index 8ecb0f9..fb2e23c 100644 --- a/spinner/spinner_windows.go +++ b/spinner/spinner_windows.go @@ -8,6 +8,6 @@ */ package spinner -func new() SpinnerWrapper { +func new(interactive bool) SpinnerWrapper { return &noopSpinner{} } From 9b0d553363dba0210a5ecb5d50c0c381c2b95958 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Thu, 7 Dec 2023 19:59:31 +0100 Subject: [PATCH 02/12] refactor logging to support file based logging to ease debugging on Windows --- aws/aws.go | 27 +++++++++++---------- aws/sts.go | 21 ++++++++-------- cmd/apps.go | 40 +++++++++++++++---------------- cmd/get.go | 36 ++++++++++++++-------------- cmd/helpers.go | 4 ++-- cmd/providers.go | 30 +++++++++++------------ cmd/root.go | 32 +++++++++++++++---------- cmd/status.go | 10 ++++---- config/config.go | 9 +++---- go.mod | 1 + go.sum | 2 ++ keychain/keychain.go | 8 +++---- log/log.go | 53 +++++++++++++++++++++++++++++++++++++++++ main.go | 15 ------------ okta/client.go | 5 ++-- okta/get.go | 29 +++++++++++----------- onelogin/client.go | 18 +++++++------- onelogin/get.go | 37 ++++++++++++++-------------- saml/saml.go | 33 ++++++++++++------------- spinner/spinner_unix.go | 5 ++-- 20 files changed, 236 insertions(+), 179 deletions(-) create mode 100644 log/log.go diff --git a/aws/aws.go b/aws/aws.go index b5a6d14..ae83944 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -10,8 +10,9 @@ import ( "io" "time" + "github.com/allcloud-io/clisso/log" "github.com/go-ini/ini" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) // Credentials represents a set of temporary credentials received from AWS STS @@ -36,7 +37,7 @@ const expireKey = "aws_expiration" // (https://docs.aws.amazon.com/cli/latest/userguide/cli-config-files.html). In addition, this // function removes expired temporary credentials from the credentials file. func WriteToFile(c *Credentials, filename string, section string) error { - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "filename": filename, "section": section, }).Debug("Writing credentials to file") @@ -65,21 +66,21 @@ func WriteToFile(c *Credentials, filename string, section string) error { // Remove expired credentials. for _, s := range cfg.Sections() { if !s.HasKey(expireKey) { - log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey) + log.Log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey) continue } v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Warnf("Cannot parse date (%v) in profile %s: %s", + log.Log.Warnf("Cannot parse date (%v) in profile %s: %s", s.Key(expireKey), s.Name(), err) continue } if time.Now().UTC().Unix() > v.Unix() { - log.Tracef("Removing expired credentials for profile %s", s.Name()) + log.Log.Tracef("Removing expired credentials for profile %s", s.Name()) cfg.DeleteSection(s.Name()) continue } - log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339)) + log.Log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339)) } return cfg.SaveTo(filename) @@ -125,18 +126,18 @@ func WriteCredentialsToStdOutAsCredentialProcess(c *Credentials, w io.Writer) { // GetValidProfiles returns profiles which have a aws_expiration key but are not yet expired. func GetValidProfiles(filename string) ([]Profile, error) { var profiles []Profile - log.WithField("filename", filename).Trace("Loading AWS credentials file") + log.Log.WithField("filename", filename).Trace("Loading AWS credentials file") cfg, err := ini.LooseLoad(filename) if err != nil { err = fmt.Errorf("%s contains errors: %w", filename, err) - log.WithError(err).Trace("Failed to load AWS credentials file") + log.Log.WithError(err).Trace("Failed to load AWS credentials file") return nil, err } for _, s := range cfg.Sections() { if s.HasKey(expireKey) { v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Warnf("Cannot parse date (%v) in section %s: %s", + log.Log.Warnf("Cannot parse date (%v) in section %s: %s", s.Key(expireKey), s.Name(), err) continue } @@ -155,18 +156,18 @@ func GetValidProfiles(filename string) ([]Profile, error) { // returns a map of profile name to credentials func GetValidCredentials(filename string) (map[string]Credentials, error) { credentials := make(map[string]Credentials) - log.WithField("filename", filename).Trace("Loading credentials file") + log.Log.WithField("filename", filename).Trace("Loading credentials file") cfg, err := ini.LooseLoad(filename) if err != nil { err = fmt.Errorf("%s contains errors: %w", filename, err) - log.WithError(err).Trace("Failed to load credentials file") + log.Log.WithError(err).Trace("Failed to load credentials file") return nil, err } for _, s := range cfg.Sections() { if s.HasKey(expireKey) { v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Warnf("Cannot parse date (%v) in section %s: %s", + log.Log.Warnf("Cannot parse date (%v) in section %s: %s", s.Key(expireKey), s.Name(), err) continue } @@ -184,4 +185,4 @@ func GetValidCredentials(filename string) (map[string]Credentials, error) { } } return credentials, nil -} \ No newline at end of file +} diff --git a/aws/sts.go b/aws/sts.go index 524967e..748fedf 100644 --- a/aws/sts.go +++ b/aws/sts.go @@ -11,12 +11,13 @@ import ( "regexp" "strings" + "github.com/allcloud-io/clisso/log" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) const ( @@ -41,13 +42,13 @@ const ( // returns a specific error message to indicate that. In this case we return a custom error to the // caller to allow special handling such as retrying with a lower duration. func AssumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, duration int32) (*Credentials, error) { - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "PrincipalArn": PrincipalArn, "RoleArn": RoleArn, "awsRegion": awsRegion, "duration": duration, }).Debug("Assuming role with SAML assertion") - log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion") + log.Log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion") creds, err := assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion, duration) if err != nil { // Check if API error returned by AWS @@ -80,15 +81,15 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura config, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion)) if err != nil { - log.WithError(err).Debug("Error loading default configuration") + log.Log.WithError(err).Debug("Error loading default configuration") return nil, err } - log.WithField("awsRegion", config.Region).Trace("Loaded default config") + log.Log.WithField("awsRegion", config.Region).Trace("Loaded default config") // If we request credentials for China we need to provide a Chinese region idp := regexp.MustCompile(`^arn:aws-cn:iam::\d+:saml-provider\/\S+$`) if idp.MatchString(PrincipalArn) && !strings.HasPrefix(awsRegion, "cn-") { - log.Trace("Setting region to cn-north-1") + log.Log.Trace("Setting region to cn-north-1") config.Region = "cn-north-1" } svc := sts.NewFromConfig(config, func(o *sts.Options) { @@ -98,7 +99,7 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura aResp, err := svc.AssumeRoleWithSAML(ctx, &input) if err != nil { - log.WithError(err).Debug("Error assuming role with SAML assertion") + log.Log.WithError(err).Debug("Error assuming role with SAML assertion") return nil, err } @@ -107,10 +108,10 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura sessionToken := *aResp.Credentials.SessionToken expiration := *aResp.Credentials.Expiration - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "AccessKeyID": keyID, - "SecretAccessKey": gog.If(log.GetLevel() == log.TraceLevel, secretKey, ""), - "SessionToken": gog.If(log.GetLevel() == log.TraceLevel, sessionToken, ""), + "SecretAccessKey": gog.If(log.Log.GetLevel() == logrus.TraceLevel, secretKey, ""), + "SessionToken": gog.If(log.Log.GetLevel() == logrus.TraceLevel, sessionToken, ""), "Expiration": expiration, }).Debug("Got temporary credentials") diff --git a/cmd/apps.go b/cmd/apps.go index 5a56bbf..aa70dff 100644 --- a/cmd/apps.go +++ b/cmd/apps.go @@ -10,7 +10,7 @@ import ( "sort" "strconv" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -63,7 +63,7 @@ var cmdAppsList = &cobra.Command{ Long: "List all configured apps.", Run: func(cmd *cobra.Command, args []string) { apps := viper.GetStringMap("apps") - log.Trace("Listing apps") + log.Log.Trace("Listing apps") if len(apps) == 0 { fmt.Println("No apps configured") @@ -105,18 +105,18 @@ var cmdAppsCreateOneLogin = &cobra.Command{ // Verify app doesn't exist if exists := viper.Get("apps." + name); exists != nil { - log.Fatalf("App '%s' already exists", name) + log.Log.Fatalf("App '%s' already exists", name) } // Verify provider exists if exists := viper.Get("providers." + provider); exists == nil { - log.Fatalf("Provider '%s' doesn't exist", provider) + log.Log.Fatalf("Provider '%s' doesn't exist", provider) } // Verify provider type pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType != "onelogin" { - log.Fatalf( + log.Log.Fatalf( "Invalid provider type '%s' for a OneLogin app. Type must be 'onelogin'.", pType, ) @@ -134,9 +134,9 @@ var cmdAppsCreateOneLogin = &cobra.Command{ if duration != 0 { // Duration specified - validate value if duration < 3600 || duration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } - log.Tracef("Setting duration to %d", duration) + log.Log.Tracef("Setting duration to %d", duration) conf["duration"] = strconv.Itoa(duration) } @@ -145,9 +145,9 @@ var cmdAppsCreateOneLogin = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("App '%s' saved to config file", name) + log.Log.Printf("App '%s' saved to config file", name) }, } @@ -161,18 +161,18 @@ var cmdAppsCreateOkta = &cobra.Command{ // Verify app doesn't exist if exists := viper.Get("apps." + name); exists != nil { - log.Fatalf("App '%s' already exists", name) + log.Log.Fatalf("App '%s' already exists", name) } // Verify provider exists if exists := viper.Get("providers." + provider); exists == nil { - log.Fatalf("Provider '%s' doesn't exist", provider) + log.Log.Fatalf("Provider '%s' doesn't exist", provider) } // Verify provider type pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType != "okta" { - log.Fatalf( + log.Log.Fatalf( "Invalid provider type '%s' for an Okta app. Type must be 'okta'.", pType, ) @@ -186,9 +186,9 @@ var cmdAppsCreateOkta = &cobra.Command{ if duration != 0 { // Duration specified - validate value if duration < 3600 || duration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } - log.Tracef("Setting duration to %d", duration) + log.Log.Tracef("Setting duration to %d", duration) conf["duration"] = strconv.Itoa(duration) } @@ -197,9 +197,9 @@ var cmdAppsCreateOkta = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("App '%s' saved to config file", name) + log.Log.Printf("App '%s' saved to config file", name) }, } @@ -213,19 +213,19 @@ var cmdAppsSelect = &cobra.Command{ if app == "" { viper.Set("global.selected-app", "") - log.Println("Unsetting selected app") + log.Log.Println("Unsetting selected app") } else { if exists := viper.Get("apps." + app); exists == nil { - log.Fatalf("App '%s' doesn't exist", app) + log.Log.Fatalf("App '%s' doesn't exist", app) } - log.Printf("Setting selected app to '%s'", app) + log.Log.Printf("Setting selected app to '%s'", app) viper.Set("global.selected-app", app) } // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } }, } diff --git a/cmd/get.go b/cmd/get.go index d8d098b..23a55fa 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -11,8 +11,8 @@ import ( "path/filepath" "runtime" + "github.com/allcloud-io/clisso/log" "github.com/mitchellh/go-homedir" - log "github.com/sirupsen/logrus" "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/okta" @@ -41,7 +41,7 @@ func init() { ) err := viper.BindPFlag("global.cache-credentials", cmdGet.Flags().Lookup("cache-credentials")) if err != nil { - log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) + log.Log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) } cmdGet.Flags().StringVarP( &cacheToFile, "cache-file", "", "~/.aws/credentials-cache", @@ -49,7 +49,7 @@ func init() { ) err = viper.BindPFlag("global.credentials-cache-path", cmdGet.Flags().Lookup("cache-file")) if err != nil { - log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) + log.Log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) } cmdGet.Flags().StringVarP( &writeToFile, "write-to-file", "w", "~/.aws/credentials", @@ -57,7 +57,7 @@ func init() { ) err = viper.BindPFlag("global.credentials-path", cmdGet.Flags().Lookup("write-to-file")) if err != nil { - log.Fatalf("Error binding flag global.credentials-path: %v", err) + log.Log.Fatalf("Error binding flag global.credentials-path: %v", err) } } @@ -89,7 +89,7 @@ func processCredentials(creds *aws.Credentials, app string) error { // Create the `global.credentials-path` directory if it doesn't exist. credsFileParentDir := filepath.Dir(path) if _, err := os.Stat(credsFileParentDir); os.IsNotExist(err) { - log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) + log.Log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) // Lets default to strict permissions on the folders we create err = os.MkdirAll(credsFileParentDir, 0700) if err != nil { @@ -100,7 +100,7 @@ func processCredentials(creds *aws.Credentials, app string) error { if err := aws.WriteToFile(creds, path, app); err != nil { return fmt.Errorf("writing credentials to file: %v", err) } - log.Printf("Credentials written successfully to '%s'", path) + log.Log.Printf("Credentials written successfully to '%s'", path) } return nil } @@ -139,12 +139,12 @@ func getCachedCredential(app string) (*aws.Credentials, error) { // get the credentials from the cache file credentialFile, err := homedir.Expand(viper.GetString("global.credentials-cache-path")) if err != nil { - log.Fatalf("Failed to expand home: %s", err) + log.Log.Fatalf("Failed to expand home: %s", err) } profiles, err := aws.GetValidCredentials(credentialFile) if err != nil { - log.Fatalf("Failed to retrieve non-expired credentials: %s", err) + log.Log.Fatalf("Failed to retrieve non-expired credentials: %s", err) } if len(profiles) == 0 { @@ -175,7 +175,7 @@ If no app is specified, the selected app (if configured) will be assumed.`, selected := viper.GetString("global.selected-app") if selected == "" { // No default app configured. - log.Fatal("No app specified and no default app configured") + log.Log.Fatal("No app specified and no default app configured") } app = selected } else { @@ -185,12 +185,12 @@ If no app is specified, the selected app (if configured) will be assumed.`, provider := viper.GetString(fmt.Sprintf("apps.%s.provider", app)) if provider == "" { - log.Fatalf("Could not get provider for app '%s'", app) + log.Log.Fatalf("Could not get provider for app '%s'", app) } pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType == "" { - log.Fatalf("Could not get provider type for provider '%s'", provider) + log.Log.Fatalf("Could not get provider type for provider '%s'", provider) } // allow preferred "arn" to be specified in the config file for each app @@ -202,11 +202,11 @@ If no app is specified, the selected app (if configured) will be assumed.`, awsRegion := awsRegion(app) if printToCredentialProcess && cacheCredentials { - log.Trace("Using --cache-credentials and --credential-process") + log.Log.Trace("Using --cache-credentials and --credential-process") // we need to cache the credentials to a file and return valid credentials instead of constantly hitting the IdPs credential, err := getCachedCredential(app) if err != nil { - log.WithError(err).Debugf("Failed to find cached credentials for app '%s'", app) + log.Log.WithError(err).Debugf("Failed to find cached credentials for app '%s'", app) } if credential != nil { aws.WriteCredentialsToStdOutAsCredentialProcess(credential, os.Stdout) @@ -218,25 +218,25 @@ If no app is specified, the selected app (if configured) will be assumed.`, if pType == "onelogin" { creds, err := onelogin.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { - log.Fatal("Could not get temporary credentials: ", err) + log.Log.Fatal("Could not get temporary credentials: ", err) } // Process credentials err = processCredentials(creds, app) if err != nil { - log.Fatalf("Error processing credentials: %v", err) + log.Log.Fatalf("Error processing credentials: %v", err) } } else if pType == "okta" { creds, err := okta.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { - log.Fatal("Could not get temporary credentials: ", err) + log.Log.Fatal("Could not get temporary credentials: ", err) } // Process credentials err = processCredentials(creds, app) if err != nil { - log.Fatalf("Error processing credentials: %v", err) + log.Log.Fatalf("Error processing credentials: %v", err) } } else { - log.Fatalf("Unsupported identity provider type '%s' for app '%s'", pType, app) + log.Log.Fatalf("Unsupported identity provider type '%s' for app '%s'", pType, app) } if interactive { printStatus() diff --git a/cmd/helpers.go b/cmd/helpers.go index 3da06e9..59698dc 100644 --- a/cmd/helpers.go +++ b/cmd/helpers.go @@ -6,13 +6,13 @@ package cmd import ( - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" "github.com/spf13/cobra" ) func mandatoryFlag(cmd *cobra.Command, name string) { err := cmd.MarkFlagRequired(name) if err != nil { - log.Fatalf("Error marking flag %s as required: %v", name, err) + log.Log.Fatalf("Error marking flag %s as required: %v", name, err) } } diff --git a/cmd/providers.go b/cmd/providers.go index 6d84333..a9f0c3a 100644 --- a/cmd/providers.go +++ b/cmd/providers.go @@ -12,7 +12,7 @@ import ( "syscall" "github.com/allcloud-io/clisso/keychain" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" "github.com/spf13/cobra" "github.com/spf13/viper" "golang.org/x/term" @@ -77,7 +77,7 @@ var cmdProvidersList = &cobra.Command{ providers := viper.GetStringMap("providers") if len(providers) == 0 { - log.Println("No providers configured") + log.Log.Println("No providers configured") return } @@ -88,7 +88,7 @@ var cmdProvidersList = &cobra.Command{ } sort.Strings(keys) for _, k := range keys { - log.Println(k) + log.Log.Println(k) } }, } @@ -104,16 +104,16 @@ var cmdProvidersPassword = &cobra.Command{ fmt.Printf("Please enter the password for the '%s' provider: ", provider) pass, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { - log.Fatalf("Could not read password") + log.Log.Fatalf("Could not read password") } keyChain := keychain.DefaultKeychain{} err = keyChain.Set(provider, pass) if err != nil { - log.Fatalf("Could not save to keychain: %+v", err) + log.Log.Fatalf("Could not save to keychain: %+v", err) } - log.Printf("Saved password for Provider '%s'", provider) + log.Log.Printf("Saved password for Provider '%s'", provider) }, } @@ -133,13 +133,13 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ // Verify provider doesn't exist if exists := viper.Get("providers." + name); exists != nil { - log.Fatalf("Provider '%s' already exists", name) + log.Log.Fatalf("Provider '%s' already exists", name) } switch region { case "US", "EU": default: - log.Fatal("Region must be either US or EU") + log.Log.Fatal("Region must be either US or EU") } conf := map[string]string{ @@ -153,7 +153,7 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ if providerDuration != 0 { // Duration specified - validate value if providerDuration < 3600 || providerDuration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } conf["duration"] = strconv.Itoa(providerDuration) } @@ -162,9 +162,9 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("Provider '%s' saved to config file", name) + log.Log.Printf("Provider '%s' saved to config file", name) }, } @@ -178,7 +178,7 @@ var cmdProvidersCreateOkta = &cobra.Command{ // Verify provider doesn't exist if exists := viper.Get("providers." + name); exists != nil { - log.Fatalf("Provider '%s' already exists", name) + log.Log.Fatalf("Provider '%s' already exists", name) } conf := map[string]string{ @@ -189,7 +189,7 @@ var cmdProvidersCreateOkta = &cobra.Command{ if providerDuration != 0 { // Duration specified - validate value if providerDuration < 3600 || providerDuration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } conf["duration"] = strconv.Itoa(providerDuration) } @@ -198,8 +198,8 @@ var cmdProvidersCreateOkta = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("Provider '%s' saved to config file", name) + log.Log.Printf("Provider '%s' saved to config file", name) }, } diff --git a/cmd/root.go b/cmd/root.go index 953469d..971acfa 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -9,26 +9,22 @@ import ( "os" "path/filepath" + "github.com/allcloud-io/clisso/log" homedir "github.com/mitchellh/go-homedir" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" ) var cfgFile string +var logFile string var RootCmd = &cobra.Command{ Use: "clisso", Version: "0.0.0", PersistentPreRun: func(cmd *cobra.Command, args []string) { - // get log level flag value logLevelFlag := cmd.Flag("log-level").Value.String() - // parse log level flag and set log level - logLevel, err := log.ParseLevel(logLevelFlag) - if err != nil { - log.Fatalf("Error parsing log level: %v", err) - } - log.SetLevel(logLevel) + log.Log = log.NewLogger(logLevelFlag, logFile, true) }, } @@ -80,7 +76,19 @@ func init() { ) // Add a global log level flag RootCmd.PersistentFlags().String("log-level", "info", "set log level to trace, debug, info, warn, error, fatal or panic") + err := viper.BindPFlag("global.logs.level", RootCmd.PersistentFlags().Lookup("log-level")) + if err != nil { + // log isn't available yet, so we can't use it + logrus.Fatalf("Error binding flag global.logs.level: %v", err) + } + RootCmd.PersistentFlags().StringVarP( + &logFile, "log-file", "", "~/.clisso.log", "log file location (~/.clisso.log)", + ) + err = viper.BindPFlag("global.logs.path", RootCmd.PersistentFlags().Lookup("log-file")) + if err != nil { + logrus.Fatalf("Error binding flag global.logs.path: %v", err) + } RootCmd.SetUsageTemplate(usageTemplate) RootCmd.SetVersionTemplate(versionTemplate) } @@ -91,7 +99,7 @@ func Execute(version, commit, date string) { RootCmd.Version = version + " (" + commit + " " + date + ")" err := RootCmd.Execute() if err != nil { - log.Fatalf("Failed to execute: %v", err) + log.Log.Fatalf("Failed to execute: %v", err) } } @@ -101,7 +109,7 @@ func initConfig() { } else { home, err := homedir.Dir() if err != nil { - log.Fatalf("Error getting home directory: %v", err) + log.Log.Fatalf("Error getting home directory: %v", err) } viper.SetConfigType("yaml") @@ -113,7 +121,7 @@ func initConfig() { if _, err := os.Stat(file); os.IsNotExist(err) { _, err := os.Create(file) if err != nil { - log.Fatalf("Error creating config file: %v", err) + log.Log.Fatalf("Error creating config file: %v", err) } } @@ -122,6 +130,6 @@ func initConfig() { } if err := viper.ReadInConfig(); err != nil { - log.Fatalf("Can't read config: %v", err) + log.Log.Fatalf("Can't read config: %v", err) } } diff --git a/cmd/status.go b/cmd/status.go index d6c07dd..df5fb51 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -11,9 +11,9 @@ import ( "time" "github.com/allcloud-io/clisso/aws" + "github.com/allcloud-io/clisso/log" homedir "github.com/mitchellh/go-homedir" "github.com/olekukonko/tablewriter" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -28,7 +28,7 @@ func init() { ) err := viper.BindPFlag("global.credentials-path", cmdStatus.Flags().Lookup("read-from-file")) if err != nil { - log.Fatalf("Error binding flag global.credentials-path: %v", err) + log.Log.Fatalf("Error binding flag global.credentials-path: %v", err) } } @@ -44,12 +44,12 @@ var cmdStatus = &cobra.Command{ func printStatus() { credentialFile, err := homedir.Expand(viper.GetString("global.credentials-path")) if err != nil { - log.Fatalf("Failed to expand home: %s", err) + log.Log.Fatalf("Failed to expand home: %s", err) } profiles, err := aws.GetValidProfiles(credentialFile) if err != nil { - log.Fatalf("Failed to retrieve non-expired credentials: %s", err) + log.Log.Fatalf("Failed to retrieve non-expired credentials: %s", err) } if len(profiles) == 0 { @@ -60,7 +60,7 @@ func printStatus() { table := tablewriter.NewWriter(os.Stdout) table.SetHeader([]string{"App", "Expire At", "Remaining"}) - log.Print("The following apps currently have valid credentials:") + log.Log.Print("The following apps currently have valid credentials:") for _, p := range profiles { table.Append([]string{p.Name, fmt.Sprintf("%d", p.ExpireAtUnix), p.LifetimeLeft.Round(time.Second).String()}) } diff --git a/config/config.go b/config/config.go index 69b16bc..fc65960 100644 --- a/config/config.go +++ b/config/config.go @@ -9,8 +9,9 @@ import ( "errors" "fmt" + "github.com/allcloud-io/clisso/log" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -27,14 +28,14 @@ type OneLoginProviderConfig struct { // GetOneLoginProvider returns a OneLoginProviderConfig struct containing the configuration for // provider p. func GetOneLoginProvider(p string) (*OneLoginProviderConfig, error) { - log.WithField("provider", p).Trace("Reading OneLogin provider config") + log.Log.WithField("provider", p).Trace("Reading OneLogin provider config") clientSecret := viper.GetString(fmt.Sprintf("providers.%s.client-secret", p)) clientID := viper.GetString(fmt.Sprintf("providers.%s.client-id", p)) subdomain := viper.GetString(fmt.Sprintf("providers.%s.subdomain", p)) username := viper.GetString(fmt.Sprintf("providers.%s.username", p)) region := viper.GetString(fmt.Sprintf("providers.%s.region", p)) - log.WithFields(log.Fields{ - "clientSecret": gog.If(log.GetLevel() == log.TraceLevel, clientSecret, ""), + log.Log.WithFields(logrus.Fields{ + "clientSecret": gog.If(log.Log.GetLevel() == logrus.TraceLevel, clientSecret, ""), "clientID": clientID, "subdomain": subdomain, "username": username, diff --git a/go.mod b/go.mod index 4deea4b..60e1554 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/mattn/go-colorable v0.1.13 github.com/mitchellh/go-homedir v1.1.0 github.com/olekukonko/tablewriter v0.0.5 + github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.18.2 diff --git a/go.sum b/go.sum index 3fe86ee..ec7e594 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 h1:mZHayPoR0lNmnHyvtYjDeq0zlVHn9K/ZXoy17ylucdo= +github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5/go.mod h1:GEXHk5HgEKCvEIIrSpFI3ozzG5xOKA2DVlEX/gGnewM= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= diff --git a/keychain/keychain.go b/keychain/keychain.go index 13be514..3bf823d 100644 --- a/keychain/keychain.go +++ b/keychain/keychain.go @@ -9,7 +9,7 @@ import ( "fmt" "syscall" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" keyring "github.com/zalando/go-keyring" "golang.org/x/term" ) @@ -43,15 +43,15 @@ func (DefaultKeychain) Set(provider string, password []byte) (err error) { // and just ask the user for the password instead. Error could be anything from access denied to // password not found. func (DefaultKeychain) Get(provider string) (pw []byte, err error) { - log.WithField("provider", provider).Trace("Reading password from keychain") + log.Log.WithField("provider", provider).Trace("Reading password from keychain") pass, err := get(provider) if err != nil { - log.WithError(err).Trace("Couldn't read password from keychain") + log.Log.WithError(err).Trace("Couldn't read password from keychain") fmt.Printf("Please enter %s password: ", provider) pass, err = term.ReadPassword(int(syscall.Stdin)) if err != nil { err = fmt.Errorf("couldn't read password from terminal: %w", err) - log.WithError(err).Trace("Couldn't read password from terminal") + log.Log.WithError(err).Trace("Couldn't read password from terminal") return nil, err } } diff --git a/log/log.go b/log/log.go new file mode 100644 index 0000000..780ee1c --- /dev/null +++ b/log/log.go @@ -0,0 +1,53 @@ +package log + +import ( + "io" + "runtime" + + "github.com/mattn/go-colorable" + "github.com/mitchellh/go-homedir" + "github.com/rifflock/lfshook" + "github.com/sirupsen/logrus" +) + +var Log *logrus.Logger + +func NewLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Logger { + if Log != nil { + return Log + } + + // parse log level flag and set log level + logLevel, err := logrus.ParseLevel(logLevelFlag) + if err != nil { + Log.Fatalf("Error parsing log level: %v", err) + } + Log = logrus.New() + Log.SetLevel(logLevel) + + if enableLogFile { + logFile, err := homedir.Expand(logFilePath) + if err != nil { + Log.Fatalf("Error expanding homedir: %v", err) + } + + pathMap := lfshook.PathMap{ + logLevel: logFile, + } + Log.Hooks.Add(lfshook.NewHook( + pathMap, + &logrus.JSONFormatter{}, + )) + Log.Out = io.Discard + } else { + if runtime.GOOS == "windows" { + // Handle terminal colors on Windows machines. + // TODO, check if still required with the switch to logrus + Log.SetOutput(colorable.NewColorableStdout()) + } + Log.SetFormatter(&logrus.TextFormatter{PadLevelText: true}) + } + Log.Infof("Log level set to %s", logLevelFlag) + return Log +} + diff --git a/main.go b/main.go index 115f1c6..2e1bd60 100644 --- a/main.go +++ b/main.go @@ -6,11 +6,6 @@ package main import ( - "runtime" - - "github.com/mattn/go-colorable" - log "github.com/sirupsen/logrus" - "github.com/allcloud-io/clisso/cmd" ) @@ -21,16 +16,6 @@ var ( date = "unknown" ) -func init() { - if runtime.GOOS == "windows" { - // Handle terminal colors on Windows machines. - // TODO, check if still required with the switch to logrus - log.SetOutput(colorable.NewColorableStdout()) - } - - log.SetFormatter(&log.TextFormatter{PadLevelText: true}) -} - func main() { cmd.Execute(version, commit, date) } diff --git a/okta/client.go b/okta/client.go index ff5dc3d..1be57fa 100644 --- a/okta/client.go +++ b/okta/client.go @@ -16,7 +16,8 @@ import ( "time" "github.com/PuerkitoBio/goquery" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" + "github.com/sirupsen/logrus" "golang.org/x/net/publicsuffix" ) @@ -186,7 +187,7 @@ func (c *Client) LaunchApp(p *LaunchAppParams) (*string, error) { // using the client, handles any HTTP-related errors and returns any data as a string. func (c *Client) doRequest(r *http.Request) (string, error) { resp, err := c.Do(r) - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "status": resp.Status, "url": resp.Request.URL, "host": resp.Request.Host, diff --git a/okta/get.go b/okta/get.go index 234d6c4..2cdc527 100644 --- a/okta/get.go +++ b/okta/get.go @@ -12,10 +12,11 @@ import ( "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/config" "github.com/allcloud-io/clisso/keychain" + "github.com/allcloud-io/clisso/log" "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) const ( @@ -32,12 +33,12 @@ var ( // Get gets temporary credentials for the given app. func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { - log.WithFields(log.Fields{ - "app": app, - "provider": provider, - "pArn": pArn, - "awsRegion": awsRegion, - "duration": duration, + log.Log.WithFields(logrus.Fields{ + "app": app, + "provider": provider, + "pArn": pArn, + "awsRegion": awsRegion, + "duration": duration, "interactive": interactive, }).Trace("Getting credentials from Okta") // Get provider config @@ -76,17 +77,17 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Get session token s.Start() - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "Username": user, // print password only in Trace Log Level - "Password": gog.If(log.GetLevel() == log.TraceLevel, string(pass), ""), + "Password": gog.If(log.Log.GetLevel() == logrus.TraceLevel, string(pass), ""), }).Debug("Calling GetSessionToken") resp, err := c.GetSessionToken(&GetSessionTokenParams{ Username: user, Password: string(pass), }) s.Stop() - log.WithField("Status", resp.Status).WithError(err).Trace("GetSessionToken done") + log.Log.WithField("Status", resp.Status).WithError(err).Trace("GetSessionToken done") if err != nil { return nil, fmt.Errorf("getting session token: %v", err) } @@ -100,7 +101,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool case StatusMFARequired: factor := resp.Embedded.Factors[0] stateToken := resp.StateToken - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "factorID": factor.ID, "factorLink": factor.Links.Verify.Href, "stateToken": stateToken, @@ -167,7 +168,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Handle failed MFA verification (verification rejected or timed out) if vfResp.Status != VerifyFactorStatusSuccess { err = fmt.Errorf("MFA verification failed") - log.WithField("status", vfResp.Status).WithError(err).Warn("MFA verification failed") + log.Log.WithField("status", vfResp.Status).WithError(err).Warn("MFA verification failed") return nil, fmt.Errorf("MFA verification failed") } @@ -178,7 +179,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Launch Okta app with session token s.Start() - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "SessionToken": st, "URL": a.URL, }).Trace("Calling LaunchApp") @@ -199,7 +200,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool if err != nil { if err.Error() == aws.ErrDurationExceeded { - log.Warn(aws.DurationExceededMessage) + log.Log.Warn(aws.DurationExceededMessage) s.Start() creds, err = aws.AssumeSAMLRole(arn.Provider, arn.Role, *samlAssertion, awsRegion, 3600) s.Stop() diff --git a/onelogin/client.go b/onelogin/client.go index a113316..4bb7fb4 100644 --- a/onelogin/client.go +++ b/onelogin/client.go @@ -14,7 +14,8 @@ import ( "net/http" "time" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" + "github.com/sirupsen/logrus" ) // Client represents a OneLogin API client. @@ -115,14 +116,13 @@ func (c *Client) doRequest(r *http.Request) (string, error) { resp, err := c.Do(r) if resp != nil { - log.WithFields( - log.Fields{ - "status": resp.Status, - "url": resp.Request.URL, - "host": resp.Request.Host, - "code": resp.StatusCode, - "method": resp.Request.Method, - }).WithError(err).Trace("HTTP request sent") + log.Log.WithFields(logrus.Fields{ + "status": resp.Status, + "url": resp.Request.URL, + "host": resp.Request.Host, + "code": resp.StatusCode, + "method": resp.Request.Method, + }).WithError(err).Trace("HTTP request sent") } if err != nil { diff --git a/onelogin/get.go b/onelogin/get.go index 34245e7..be44960 100644 --- a/onelogin/get.go +++ b/onelogin/get.go @@ -16,10 +16,11 @@ import ( "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/config" "github.com/allcloud-io/clisso/keychain" + "github.com/allcloud-io/clisso/log" "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) const ( @@ -42,12 +43,12 @@ var ( // Get gets temporary credentials for the given app. // TODO Move AWS logic outside this function. func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { - log.WithFields(log.Fields{ - "app": app, - "provider": provider, - "pArn": pArn, - "awsRegion": awsRegion, - "duration": duration, + log.Log.WithFields(logrus.Fields{ + "app": app, + "provider": provider, + "pArn": pArn, + "awsRegion": awsRegion, + "duration": duration, "interactive": interactive, }).Trace("Getting credentials from OneLogin") // Read config @@ -71,7 +72,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool // Get OneLogin access token s.Start() - log.Trace("Generating access token") + log.Log.Trace("Generating access token") token, err := c.GenerateTokens(p.ClientID, p.ClientSecret) s.Stop() if err != nil { @@ -80,7 +81,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool user := p.Username if user == "" { - log.Trace("No username provided") + log.Log.Trace("No username provided") // Get credentials from the user fmt.Print("OneLogin username: ") fmt.Scanln(&user) @@ -101,10 +102,10 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool Subdomain: p.Subdomain, } - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "UsernameOrEmail": user, // print password only in Trace Log Level - "Password": gog.If(log.GetLevel() == log.TraceLevel, string(pass), ""), + "Password": gog.If(log.Log.GetLevel() == logrus.TraceLevel, string(pass), ""), "AppId": a.ID, "Subdomain": p.Subdomain, }).Debug("Calling GenerateSamlAssertion") @@ -116,14 +117,14 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool return nil, fmt.Errorf("generating SAML assertion: %v", err) } - log.WithField("Message", rSaml.Message).Debug("GenerateSamlAssertion is done") + log.Log.WithField("Message", rSaml.Message).Debug("GenerateSamlAssertion is done") var rData string if rSaml.Message != "Success" { st := rSaml.StateToken devices := rSaml.Devices - log.WithField("Devices", devices).Trace("Devices returned by GenerateSamlAssertion") + log.Log.WithField("Devices", devices).Trace("Devices returned by GenerateSamlAssertion") device, err := getDevice(devices) if err != nil { return nil, fmt.Errorf("error getting devices: %s", err) @@ -143,7 +144,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool OtpToken: "", DoNotNotify: false, } - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "AppId": a.ID, "DeviceId": device.DeviceID, "StateToken": st, @@ -168,7 +169,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool s.Start() for strings.Contains(rMfa.Message, "pending") && timeout > 0 { time.Sleep(time.Duration(MFAInterval) * time.Second) - log.Trace("MFAInterval completed, calling VerifyFactor again") + log.Log.Trace("MFAInterval completed, calling VerifyFactor again") rMfa, err = c.VerifyFactor(token, &pMfa) if err != nil { s.Stop() @@ -208,7 +209,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool } } rData = rMfa.Data - log.Trace("Factor is verified") + log.Log.Trace("Factor is verified") } else { rData = rSaml.Data } @@ -224,7 +225,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool if err != nil { if err.Error() == aws.ErrDurationExceeded { - log.Warn(aws.DurationExceededMessage) + log.Log.Warn(aws.DurationExceededMessage) s.Start() creds, err = aws.AssumeSAMLRole(arn.Provider, arn.Role, rData, awsRegion, 3600) s.Stop() @@ -247,7 +248,7 @@ func getDevice(devices []Device) (device *Device, err error) { } if len(devices) == 1 { - log.Trace("Only one MFA device returned by Onelogin, automatically selecting it.") + log.Log.Trace("Only one MFA device returned by Onelogin, automatically selecting it.") device = &Device{DeviceID: devices[0].DeviceID, DeviceType: devices[0].DeviceType} return } diff --git a/saml/saml.go b/saml/saml.go index 5595ace..f66b606 100644 --- a/saml/saml.go +++ b/saml/saml.go @@ -14,8 +14,9 @@ import ( "strconv" "strings" + "github.com/allcloud-io/clisso/log" "github.com/crewjam/saml" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -32,14 +33,14 @@ const idpRegex = `^arn:(?:aws|aws-cn):iam::\d+:saml-provider\/\S+$` func Get(data, pArn string) (a ARN, err error) { samlBody, err := decode(data) if err != nil { - log.WithError(err).Error("Error decoding SAML assertion") + log.Log.WithError(err).Error("Error decoding SAML assertion") return } x := new(saml.Response) err = xml.Unmarshal(samlBody, x) if err != nil { - log.WithError(err).Error("Error parsing SAML assertion") + log.Log.WithError(err).Error("Error parsing SAML assertion") return } @@ -67,10 +68,10 @@ func decode(in string) (b []byte, err error) { } func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { - log.WithField("preferredARN", pArn).Trace("Extracting ARNs from SAML AttributeStatements") + log.Log.WithField("preferredARN", pArn).Trace("Extracting ARNs from SAML AttributeStatements") // check for human readable ARN strings in config accounts := viper.GetStringMap("global.accounts") - log.WithFields(accounts).Trace("Accounts loaded from config") + log.Log.WithFields(accounts).Trace("Accounts loaded from config") arns := make([]ARN, 0) for _, stmt := range stmts { @@ -90,7 +91,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { components := strings.Split(strings.TrimSpace(av.Value), ",") if len(components) != 2 { // Wrong number of components - move on - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "components": components, "length": len(components), "value": av.Value, @@ -101,7 +102,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // people like to put spaces in there, AWS accepts them, let's remove them on our end too. components[0] = strings.TrimSpace(components[0]) components[1] = strings.TrimSpace(components[1]) - log.WithField("components", components).Trace("ARN components extracted from SAML assertion") + log.Log.WithField("components", components).Trace("ARN components extracted from SAML assertion") arn := ARN{} @@ -110,13 +111,13 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // Otherwise it matches it with what is in the .clisso.yaml file if pArn != "" { if components[0] == pArn { - log.Trace("Preferred ARN matches first component") + log.Log.Trace("Preferred ARN matches first component") arn = ARN{components[0], components[1], ""} } else if components[1] == pArn { - log.Trace("Preferred ARN matches second component") + log.Log.Trace("Preferred ARN matches second component") arn = ARN{components[1], components[0], ""} } else { - log.Trace("Preferred ARN does not match either component") + log.Log.Trace("Preferred ARN does not match either component") continue } } else { @@ -125,20 +126,20 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { idp := regexp.MustCompile(idpRegex) if role.MatchString(components[0]) && idp.MatchString(components[1]) { - log.Trace("First component is role, second component is IdP") + log.Log.Trace("First component is role, second component is IdP") arn = ARN{components[0], components[1], ""} } else if role.MatchString(components[1]) && idp.MatchString(components[0]) { - log.Trace("First component is IdP, second component is role") + log.Log.Trace("First component is IdP, second component is role") arn = ARN{components[1], components[0], ""} } else { - log.Trace("Neither component matches expected pattern") + log.Log.Trace("Neither component matches expected pattern") continue } // Look up the human friendly name, if available if len(accounts) > 0 { ids := role.FindStringSubmatch(arn.Role) - log.WithField("matches", ids).Trace("Role regex matches") + log.Log.WithField("matches", ids).Trace("Role regex matches") // if the regex matches we should have 3 entries from the regex match // 1) the matching string @@ -146,7 +147,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // 3) the match for Name // we want to match the Id to any accounts/roles in our config if len(ids) == 3 && accounts[ids[1]] != "" && accounts[ids[1]] != nil { - log.Trace("Found human friendly name for account") + log.Log.Trace("Found human friendly name for account") arn.Name = fmt.Sprintf("%s - %s", accounts[ids[1]].(string), ids[2]) } } @@ -159,7 +160,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { } } } - log.Trace("No statements in SAML assertion or no ARNs found.") + log.Log.Trace("No statements in SAML assertion or no ARNs found.") // Empty :( return arns } diff --git a/spinner/spinner_unix.go b/spinner/spinner_unix.go index ea9e39e..3a2e9e9 100644 --- a/spinner/spinner_unix.go +++ b/spinner/spinner_unix.go @@ -12,12 +12,13 @@ package spinner import ( "time" + "github.com/allcloud-io/clisso/log" "github.com/briandowns/spinner" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) func new(interactive bool) SpinnerWrapper { - if log.GetLevel() >= log.DebugLevel || !interactive { + if log.Log.GetLevel() >= logrus.DebugLevel || !interactive { return &noopSpinner{} } return spinner.New(spinner.CharSets[14], 50*time.Millisecond) From 30862162891401abf3498d5a979da768996c0ce0 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Thu, 7 Dec 2023 20:01:39 +0100 Subject: [PATCH 03/12] add remaining credential lifetime to log --- aws/aws.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aws/aws.go b/aws/aws.go index ae83944..4311c68 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -112,6 +112,8 @@ func WriteToStdOutAsEnvironment(c *Credentials, windows bool, w io.Writer) { // WriteCredentialsToStdOutAsCredentialProcess writes (prints) credentials to stdout in the format required by the AWS CLI. // The output can be used to set the credential_process option in the AWS CLI configuration file. func WriteCredentialsToStdOutAsCredentialProcess(c *Credentials, w io.Writer) { + log.Log.Trace("Writing credentials to stdout in credential_process format") + log.Log.Infof("Credentials expire at %s, in %d Minutes", c.Expiration.Format(time.RFC3339), int(c.Expiration.Sub(time.Now().UTC()).Minutes())) fmt.Fprintf( w, `{ "Version": 1, "AccessKeyId": %q, "SecretAccessKey": %q, "SessionToken": %q, "Expiration": %q }`, From 564d81c2b8693980957c6c4e624a82ee4beaa467 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Thu, 22 Feb 2024 16:40:52 +0100 Subject: [PATCH 04/12] refactor parameters and function names --- aws/aws.go | 12 ++-- aws/aws_test.go | 12 ++-- cmd/get.go | 171 ++++++++++++++++++++++++++++++++++-------------- cmd/root.go | 62 +++++++++++++----- log/log.go | 16 +++-- 5 files changed, 189 insertions(+), 84 deletions(-) diff --git a/aws/aws.go b/aws/aws.go index 4311c68..ba66d6f 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -33,10 +33,10 @@ type Profile struct { const expireKey = "aws_expiration" -// WriteToFile writes credentials to an AWS CLI credentials file +// OutputFile writes credentials to an AWS CLI credentials file // (https://docs.aws.amazon.com/cli/latest/userguide/cli-config-files.html). In addition, this // function removes expired temporary credentials from the credentials file. -func WriteToFile(c *Credentials, filename string, section string) error { +func OutputFile(c *Credentials, filename string, section string) error { log.Log.WithFields(logrus.Fields{ "filename": filename, "section": section, @@ -86,9 +86,9 @@ func WriteToFile(c *Credentials, filename string, section string) error { return cfg.SaveTo(filename) } -// WriteToStdOutAsEnvironment writes (prints) credentials to stdout. If windows is true, Windows syntax will be +// OutputEnvironment writes (prints) credentials to stdout. If windows is true, Windows syntax will be // used. The output can be used to set environment variables. -func WriteToStdOutAsEnvironment(c *Credentials, windows bool, w io.Writer) { +func OutputEnvironment(c *Credentials, windows bool, w io.Writer) { fmt.Print("Please paste the following in your shell:") if windows { fmt.Fprintf( @@ -109,9 +109,9 @@ func WriteToStdOutAsEnvironment(c *Credentials, windows bool, w io.Writer) { } } -// WriteCredentialsToStdOutAsCredentialProcess writes (prints) credentials to stdout in the format required by the AWS CLI. +// OutputCredentialProcess writes (prints) credentials to stdout in the format required by the AWS CLI. // The output can be used to set the credential_process option in the AWS CLI configuration file. -func WriteCredentialsToStdOutAsCredentialProcess(c *Credentials, w io.Writer) { +func OutputCredentialProcess(c *Credentials, w io.Writer) { log.Log.Trace("Writing credentials to stdout in credential_process format") log.Log.Infof("Credentials expire at %s, in %d Minutes", c.Expiration.Format(time.RFC3339), int(c.Expiration.Sub(time.Now().UTC()).Minutes())) fmt.Fprintf( diff --git a/aws/aws_test.go b/aws/aws_test.go index 64a03f2..35a984e 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -32,7 +32,7 @@ func TestWriteToFile(t *testing.T) { p := "expiredprofile" // Write credentials - err := WriteToFile(&c, fn, p) + err := OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -55,7 +55,7 @@ func TestWriteToFile(t *testing.T) { p = "testprofile" // Write credentials - err = WriteToFile(&c, fn, p) + err = OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -139,7 +139,7 @@ func TestGetValidProfiles(t *testing.T) { p := "expired" // Write credentials - err := WriteToFile(&c, fn, p) + err := OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -149,7 +149,7 @@ func TestGetValidProfiles(t *testing.T) { p = "valid" // Write credentials - err = WriteToFile(&c, fn, p) + err = OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -222,7 +222,7 @@ func TestWriteToShellUnix(t *testing.T) { } var b bytes.Buffer - WriteToStdOutAsEnvironment(&c, false, &b) + OutputEnvironment(&c, false, &b) got := b.String() want := fmt.Sprintf( @@ -251,7 +251,7 @@ func TestWriteToShellWindows(t *testing.T) { } var b bytes.Buffer - WriteToStdOutAsEnvironment(&c, true, &b) + OutputEnvironment(&c, true, &b) got := b.String() want := fmt.Sprintf( diff --git a/cmd/get.go b/cmd/get.go index 23a55fa..ed71f96 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -21,6 +21,7 @@ import ( "github.com/spf13/viper" ) +var output string var printToShell bool var printToCredentialProcess bool var cacheCredentials bool @@ -29,82 +30,149 @@ var cacheToFile string func init() { RootCmd.AddCommand(cmdGet) - cmdGet.Flags().BoolVarP( - &printToShell, "shell", "s", false, "Print credentials to shell to be sourced as environment variables", - ) - cmdGet.Flags().BoolVarP( - &printToCredentialProcess, "credential_process", "p", false, "Print credentials in the format used by the AWS CLI credential_process", + cmdGet.Flags().StringVarP( + &output, "output", "o", "~/.aws/credentials", "How or where to output credentials. Two special values are supported `environment` and `credential_process`. All other values are interpreted as file paths (default: $HOME/.aws/credentials)", ) + cmdGet.Flags().BoolVarP( - &cacheCredentials, "cache-credentials", "", false, + &cacheCredentials, "cache-enable", "", false, "Should credentials be cached to a file if run as a credential_process (default: false)", ) - err := viper.BindPFlag("global.cache-credentials", cmdGet.Flags().Lookup("cache-credentials")) - if err != nil { - log.Log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) - } cmdGet.Flags().StringVarP( - &cacheToFile, "cache-file", "", "~/.aws/credentials-cache", - "Write credentials to this file instead of the default (~/.aws/credentials-cache)", + &cacheToFile, "cache-path", "", "~/.aws/credentials-cache", + "Write credentials to this file instead of the default ($HOME/.aws/credentials-cache)", ) - err = viper.BindPFlag("global.credentials-cache-path", cmdGet.Flags().Lookup("cache-file")) - if err != nil { - log.Log.Fatalf("Error binding flag global.credentials-cache-path: %v", err) - } + + // Keep the old flags as is. cmdGet.Flags().StringVarP( - &writeToFile, "write-to-file", "w", "~/.aws/credentials", + &writeToFile, "write-to-file", "w", "", "Write credentials to this file instead of the default ($HOME/.aws/credentials)", ) - err = viper.BindPFlag("global.credentials-path", cmdGet.Flags().Lookup("write-to-file")) + cmdGet.Flags().BoolVarP( + &printToShell, "shell", "s", false, "Print credentials to shell to be sourced as environment variables", + ) + + // Mark the old flag as deprecated. + cmdGet.Flags().MarkDeprecated("write-to-file", "please use output-file instead.") + cmdGet.Flags().MarkDeprecated("shell", "please use output-environment instead.") + + // SetNormalize function to translate the use of `old-flag` to `new-flag` + // cmdGet.Flags().SetNormalizeFunc(normalizeFlagName) + + // cmdGet.MarkFlagsMutuallyExclusive("output-environment", "output-process") + +} + +// func normalizeFlagName(f *pflag.FlagSet, name string) pflag.NormalizedName { +// switch name { +// case "write-to-file": +// name = "output" +// case "shell": +// name = "output" +// writeToFile = "environment" +// } +// return pflag.NormalizedName(name) +// } + +func preferredOutput(cmd *cobra.Command, app string) string { + // Order of preference: + // * output flag + // * write-to-file flag (deprecated) + // * app specific config file + // * global config file + // * default to ~/.aws/credentials + out, err := cmd.Flags().GetString("output") if err != nil { - log.Log.Fatalf("Error binding flag global.credentials-path: %v", err) + log.Log.Warnf("Error getting output flag: %v", err) + } + if out != "" { + return out + } + + out, err = cmd.Flags().GetString("write-to-file") + if err != nil { + log.Log.Warnf("Error getting write-to-file flag: %v", err) + } + if out != "" { + return out + } + + out = viper.GetString(fmt.Sprintf("apps.%s.output", app)) + if out != "" { + return out + } + + out = viper.GetString("global.output") + if out != "" { + return out + } + + return "~/.aws/credentials" +} + +func setOutput(cmd *cobra.Command, app string) { + o := preferredOutput(cmd, app) + writeToFile = "" + switch o { + case "environment": + printToShell = true + case "credential_process": + printToCredentialProcess = true + default: + writeToFile = o } } // processCredentials prints the given Credentials to a file and/or to the shell. func processCredentials(creds *aws.Credentials, app string) error { - if printToCredentialProcess && printToShell { - return fmt.Errorf("cannot use both --shell and --credential-process") - } if printToShell { // Print credentials to shell using the correct syntax for the OS. - aws.WriteToStdOutAsEnvironment(creds, runtime.GOOS == "windows", os.Stdout) - return nil + aws.OutputEnvironment(creds, runtime.GOOS == "windows", os.Stdout) } - var viperPathString string if printToCredentialProcess { - aws.WriteCredentialsToStdOutAsCredentialProcess(creds, os.Stdout) - if cacheCredentials { - viperPathString = "global.credentials-cache-path" - } - } else { - viperPathString = "global.credentials-path" + aws.OutputCredentialProcess(creds, os.Stdout) } - if viperPathString != "" { - path, err := homedir.Expand(viper.GetString(viperPathString)) - if err != nil { - return fmt.Errorf("expanding config file path: %v", err) - } - // Create the `global.credentials-path` directory if it doesn't exist. - credsFileParentDir := filepath.Dir(path) - if _, err := os.Stat(credsFileParentDir); os.IsNotExist(err) { - log.Log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) - // Lets default to strict permissions on the folders we create - err = os.MkdirAll(credsFileParentDir, 0700) - if err != nil { - return fmt.Errorf("creating credentials directory: %v", err) - } + + if cacheCredentials { + if err := writeCredentialsToFile(creds, app, cacheToFile); err != nil { + log.Log.Errorf("writing credentials to file: %v", err) } + } - if err := aws.WriteToFile(creds, path, app); err != nil { + // if writeToFile is set, write the credentials to the file, might be the cache file or the credentials file + if writeToFile != "" { + if err := writeCredentialsToFile(creds, app, writeToFile); err != nil { return fmt.Errorf("writing credentials to file: %v", err) } - log.Log.Printf("Credentials written successfully to '%s'", path) } return nil } +func writeCredentialsToFile(creds *aws.Credentials, app, file string) error { + log.Log.Tracef("Writing credentials to '%s'", file) + path, err := homedir.Expand(file) + if err != nil { + return fmt.Errorf("expanding config file path: %v", err) + } + // Create the `global.credentials-path` directory if it doesn't exist. + credsFileParentDir := filepath.Dir(path) + if _, err := os.Stat(credsFileParentDir); os.IsNotExist(err) { + log.Log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) + // Lets default to strict permissions on the folders we create + err = os.MkdirAll(credsFileParentDir, 0700) + if err != nil { + return fmt.Errorf("creating credentials directory: %v", err) + } + } + + if err := aws.OutputFile(creds, path, app); err != nil { + return fmt.Errorf("writing credentials to file: %v", err) + } + log.Log.Printf("Credentials written successfully to '%s'", path) + return nil +} + // sessionDuration returns a session duration using the following order of preference: // app.duration -> provider.duration -> hardcoded default of 3600 func sessionDuration(app, provider string) int32 { @@ -137,7 +205,8 @@ func awsRegion(app string) string { func getCachedCredential(app string) (*aws.Credentials, error) { // get the credentials from the cache file - credentialFile, err := homedir.Expand(viper.GetString("global.credentials-cache-path")) + log.Log.Tracef("Looking for cached credentials in '%s'", cacheToFile) + credentialFile, err := homedir.Expand(cacheToFile) if err != nil { log.Log.Fatalf("Failed to expand home: %s", err) } @@ -201,15 +270,17 @@ If no app is specified, the selected app (if configured) will be assumed.`, awsRegion := awsRegion(app) + setOutput(cmd, app) + if printToCredentialProcess && cacheCredentials { - log.Log.Trace("Using --cache-credentials and --credential-process") + log.Log.Trace("Using --cache-credentials and --output-process") // we need to cache the credentials to a file and return valid credentials instead of constantly hitting the IdPs credential, err := getCachedCredential(app) if err != nil { log.Log.WithError(err).Debugf("Failed to find cached credentials for app '%s'", app) } if credential != nil { - aws.WriteCredentialsToStdOutAsCredentialProcess(credential, os.Stdout) + aws.OutputCredentialProcess(credential, os.Stdout) return } } diff --git a/cmd/root.go b/cmd/root.go index 971acfa..284d7f7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -6,25 +6,28 @@ package cmd import ( + "fmt" "os" "path/filepath" + "strings" "github.com/allcloud-io/clisso/log" homedir "github.com/mitchellh/go-homedir" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/spf13/viper" ) var cfgFile string var logFile string +var logLevel string var RootCmd = &cobra.Command{ Use: "clisso", Version: "0.0.0", - PersistentPreRun: func(cmd *cobra.Command, args []string) { - logLevelFlag := cmd.Flag("log-level").Value.String() - log.Log = log.NewLogger(logLevelFlag, logFile, true) + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return initConfig(cmd) }, } @@ -70,25 +73,24 @@ one at https://mozilla.org/MPL/2.0/. ` func init() { - cobra.OnInitialize(initConfig) RootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file (default is $HOME/.clisso.yaml)", ) // Add a global log level flag - RootCmd.PersistentFlags().String("log-level", "info", "set log level to trace, debug, info, warn, error, fatal or panic") - err := viper.BindPFlag("global.logs.level", RootCmd.PersistentFlags().Lookup("log-level")) - if err != nil { - // log isn't available yet, so we can't use it - logrus.Fatalf("Error binding flag global.logs.level: %v", err) - } + RootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "", "info", "set log level to trace, debug, info, warn, error, fatal or panic") + // err := viper.BindPFlag("global.log.level", RootCmd.PersistentFlags().Lookup("log-level")) + // if err != nil { + // // log isn't available yet, so we can't use it + // logrus.Fatalf("Error binding flag global.log.level: %v", err) + // } RootCmd.PersistentFlags().StringVarP( &logFile, "log-file", "", "~/.clisso.log", "log file location (~/.clisso.log)", ) - err = viper.BindPFlag("global.logs.path", RootCmd.PersistentFlags().Lookup("log-file")) - if err != nil { - logrus.Fatalf("Error binding flag global.logs.path: %v", err) - } + // err = viper.BindPFlag("global.log.file", RootCmd.PersistentFlags().Lookup("log-file")) + // if err != nil { + // logrus.Fatalf("Error binding flag global.log.file: %v", err) + // } RootCmd.SetUsageTemplate(usageTemplate) RootCmd.SetVersionTemplate(versionTemplate) } @@ -99,11 +101,11 @@ func Execute(version, commit, date string) { RootCmd.Version = version + " (" + commit + " " + date + ")" err := RootCmd.Execute() if err != nil { - log.Log.Fatalf("Failed to execute: %v", err) + logrus.Fatalf("Failed to execute: %v", err) } } -func initConfig() { +func initConfig(cmd *cobra.Command) error { if cfgFile != "" { viper.SetConfigFile(cfgFile) } else { @@ -125,11 +127,35 @@ func initConfig() { } } - // Set default config values - viper.SetDefault("global.credentials-path", filepath.Join(home, ".aws", "credentials")) + // // Set default config values + // viper.SetDefault("global.credentials-path", filepath.Join(home, ".aws", "credentials")) + // viper.SetDefault("global.cache.path", filepath.Join(home, ".aws", "credentials-cache")) } if err := viper.ReadInConfig(); err != nil { log.Log.Fatalf("Can't read config: %v", err) } + bindFlags(cmd, viper.GetViper()) + log.Log = log.NewLogger(logLevel, logFile, logFile != "") + return nil +} + +// Bind each cobra flag to its associated viper configuration (config file and environment variable) +func bindFlags(cmd *cobra.Command, v *viper.Viper) { + cmd.Flags().VisitAll(func(f *pflag.Flag) { + + // Determine the naming convention of the flags when represented in the config file + configName := fmt.Sprintf("global.%s", f.Name) + configName = strings.ReplaceAll(configName, "-", ".") + fmt.Printf("Checking Flag: %s\n", configName) + + // Apply the viper config value to the flag when the flag is not set and viper has a value + if !f.Changed && v.IsSet(configName) { + fmt.Printf("Setting Flag %s by config: %s\n", f.Name, configName) + val := v.Get(configName) + cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) + } else { + fmt.Printf("Using Flag %s default: %v\n", f.Name, f.DefValue) + } + }) } diff --git a/log/log.go b/log/log.go index 780ee1c..daa7e6e 100644 --- a/log/log.go +++ b/log/log.go @@ -14,13 +14,14 @@ var Log *logrus.Logger func NewLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Logger { if Log != nil { + Log.Tracef("Logger already initialized") return Log } // parse log level flag and set log level logLevel, err := logrus.ParseLevel(logLevelFlag) if err != nil { - Log.Fatalf("Error parsing log level: %v", err) + logrus.Fatalf("Error parsing log level: %v", err) } Log = logrus.New() Log.SetLevel(logLevel) @@ -31,8 +32,15 @@ func NewLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Log Log.Fatalf("Error expanding homedir: %v", err) } + // set all log levels to write to the log file pathMap := lfshook.PathMap{ - logLevel: logFile, + logrus.TraceLevel: logFile, + logrus.DebugLevel: logFile, + logrus.InfoLevel: logFile, + logrus.WarnLevel: logFile, + logrus.ErrorLevel: logFile, + logrus.FatalLevel: logFile, + logrus.PanicLevel: logFile, } Log.Hooks.Add(lfshook.NewHook( pathMap, @@ -47,7 +55,7 @@ func NewLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Log } Log.SetFormatter(&logrus.TextFormatter{PadLevelText: true}) } - Log.Infof("Log level set to %s", logLevelFlag) + Log.Warning("This is a warning") + Log.Warnf("Log level set to %s", logLevelFlag) return Log } - From 63005773eb15a7d8a607dd638104a5eefbde5151 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Wed, 28 Feb 2024 14:18:07 +0100 Subject: [PATCH 05/12] fix spoiled stdout --- cmd/root.go | 6 +++--- go.mod | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 284d7f7..a2ee374 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -147,15 +147,15 @@ func bindFlags(cmd *cobra.Command, v *viper.Viper) { // Determine the naming convention of the flags when represented in the config file configName := fmt.Sprintf("global.%s", f.Name) configName = strings.ReplaceAll(configName, "-", ".") - fmt.Printf("Checking Flag: %s\n", configName) + fmt.Fprintf(os.Stderr, "Checking Flag: %s\n", configName) // Apply the viper config value to the flag when the flag is not set and viper has a value if !f.Changed && v.IsSet(configName) { - fmt.Printf("Setting Flag %s by config: %s\n", f.Name, configName) + fmt.Fprintf(os.Stderr, "Setting Flag %s by config: %s\n", f.Name, configName) val := v.Get(configName) cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) } else { - fmt.Printf("Using Flag %s default: %v\n", f.Name, f.DefValue) + fmt.Fprintf(os.Stderr, "Using Flag %s default: %v\n", f.Name, f.DefValue) } }) } diff --git a/go.mod b/go.mod index 60e1554..76bc2d7 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 + github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.18.2 github.com/zalando/go-keyring v0.2.3 golang.org/x/net v0.21.0 @@ -57,7 +58,6 @@ require ( github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.19.0 // indirect From 038bde882735caae330f65d6fb2d094cacc0675f Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Fri, 8 Mar 2024 16:39:39 +0100 Subject: [PATCH 06/12] cleanup code, comment debug statements, remove unused code --- cmd/get.go | 16 +--------------- cmd/root.go | 8 ++++---- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/cmd/get.go b/cmd/get.go index ed71f96..7a2745c 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -56,24 +56,10 @@ func init() { cmdGet.Flags().MarkDeprecated("write-to-file", "please use output-file instead.") cmdGet.Flags().MarkDeprecated("shell", "please use output-environment instead.") - // SetNormalize function to translate the use of `old-flag` to `new-flag` - // cmdGet.Flags().SetNormalizeFunc(normalizeFlagName) - - // cmdGet.MarkFlagsMutuallyExclusive("output-environment", "output-process") + cmdGet.MarkFlagsMutuallyExclusive("output", "shell", "write-to-file") } -// func normalizeFlagName(f *pflag.FlagSet, name string) pflag.NormalizedName { -// switch name { -// case "write-to-file": -// name = "output" -// case "shell": -// name = "output" -// writeToFile = "environment" -// } -// return pflag.NormalizedName(name) -// } - func preferredOutput(cmd *cobra.Command, app string) string { // Order of preference: // * output flag diff --git a/cmd/root.go b/cmd/root.go index a2ee374..05e3534 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -147,15 +147,15 @@ func bindFlags(cmd *cobra.Command, v *viper.Viper) { // Determine the naming convention of the flags when represented in the config file configName := fmt.Sprintf("global.%s", f.Name) configName = strings.ReplaceAll(configName, "-", ".") - fmt.Fprintf(os.Stderr, "Checking Flag: %s\n", configName) + //fmt.Fprintf(os.Stderr, "Checking Flag: %s\n", configName) // Apply the viper config value to the flag when the flag is not set and viper has a value if !f.Changed && v.IsSet(configName) { - fmt.Fprintf(os.Stderr, "Setting Flag %s by config: %s\n", f.Name, configName) + //fmt.Fprintf(os.Stderr, "Setting Flag %s by config: %s\n", f.Name, configName) val := v.Get(configName) cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) - } else { - fmt.Fprintf(os.Stderr, "Using Flag %s default: %v\n", f.Name, f.DefValue) + /*} else { + fmt.Fprintf(os.Stderr, "Using Flag %s default: %v\n", f.Name, f.DefValue)*/ } }) } From 73827e95e505b476cfef074b0982d80e77c1b346 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Fri, 8 Mar 2024 16:57:27 +0100 Subject: [PATCH 07/12] unify help message, check return codes --- cmd/get.go | 24 ++++++++++++++++-------- cmd/root.go | 2 +- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/cmd/get.go b/cmd/get.go index 7a2745c..e5c35b1 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -31,30 +31,38 @@ var cacheToFile string func init() { RootCmd.AddCommand(cmdGet) cmdGet.Flags().StringVarP( - &output, "output", "o", "~/.aws/credentials", "How or where to output credentials. Two special values are supported `environment` and `credential_process`. All other values are interpreted as file paths (default: $HOME/.aws/credentials)", + &output, "output", "o", "$HOME/.aws/credentials", "How or where to output credentials. Two special values are supported 'environment' and 'credential_process'. All other values are interpreted as file paths", ) cmdGet.Flags().BoolVarP( &cacheCredentials, "cache-enable", "", false, - "Should credentials be cached to a file if run as a credential_process (default: false)", + "Should credentials be cached to a file, important when run as a credential_process (default: false)", ) cmdGet.Flags().StringVarP( - &cacheToFile, "cache-path", "", "~/.aws/credentials-cache", - "Write credentials to this file instead of the default ($HOME/.aws/credentials-cache)", + &cacheToFile, "cache-path", "", "$HOME/.aws/credentials-cache", + "Write credentials to this file instead of the default", ) // Keep the old flags as is. cmdGet.Flags().StringVarP( - &writeToFile, "write-to-file", "w", "", - "Write credentials to this file instead of the default ($HOME/.aws/credentials)", + &writeToFile, "write-to-file", "w", "$HOME/.aws/credentials", + "Write credentials to this file instead of the default", ) cmdGet.Flags().BoolVarP( &printToShell, "shell", "s", false, "Print credentials to shell to be sourced as environment variables", ) // Mark the old flag as deprecated. - cmdGet.Flags().MarkDeprecated("write-to-file", "please use output-file instead.") - cmdGet.Flags().MarkDeprecated("shell", "please use output-environment instead.") + err := cmdGet.Flags().MarkDeprecated("write-to-file", "please use output-file instead.") + if err != nil { + // we don't have a logger yet, so we can't use it but need to print the error to the console + fmt.Printf("Error marking flag as deprecated: %v", err) + } + err = cmdGet.Flags().MarkDeprecated("shell", "please use output-environment instead.") + if err != nil { + // we don't have a logger yet, so we can't use it but need to print the error to the console + fmt.Printf("Error marking flag as deprecated: %v", err) + } cmdGet.MarkFlagsMutuallyExclusive("output", "shell", "write-to-file") diff --git a/cmd/root.go b/cmd/root.go index 05e3534..27bc738 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -85,7 +85,7 @@ func init() { // } RootCmd.PersistentFlags().StringVarP( - &logFile, "log-file", "", "~/.clisso.log", "log file location (~/.clisso.log)", + &logFile, "log-file", "", "$HOME/.clisso.log", "log file location", ) // err = viper.BindPFlag("global.log.file", RootCmd.PersistentFlags().Lookup("log-file")) // if err != nil { From edca60c64c3fc0371f337c71d30148b0bee104ba Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Fri, 8 Mar 2024 17:07:12 +0100 Subject: [PATCH 08/12] fix tests --- aws/aws_test.go | 7 +++++-- keychain/keychain_test.go | 4 ++++ okta/client_test.go | 4 ++++ onelogin/client_test.go | 4 ++++ onelogin/endpoints_test.go | 3 +++ saml/saml_test.go | 3 +++ 6 files changed, 23 insertions(+), 2 deletions(-) diff --git a/aws/aws_test.go b/aws/aws_test.go index 35a984e..6f89b61 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -12,9 +12,12 @@ import ( "testing" "time" + "github.com/allcloud-io/clisso/log" "github.com/go-ini/ini" ) +var _ = log.NewLogger("panic","", false) + func TestWriteToFile(t *testing.T) { id := "expiredkey" sec := "expiredsecret" @@ -208,7 +211,7 @@ func TestGetValidProfiles(t *testing.T) { } } -func TestWriteToShellUnix(t *testing.T) { +func TestOutputUnixEnvironment(t *testing.T) { id := "testkey" sec := "testsecret" tok := "testtoken" @@ -237,7 +240,7 @@ func TestWriteToShellUnix(t *testing.T) { } } -func TestWriteToShellWindows(t *testing.T) { +func TestOutputWindowsEnvironment(t *testing.T) { id := "testkey" sec := "testsecret" tok := "testtoken" diff --git a/keychain/keychain_test.go b/keychain/keychain_test.go index 5aafd5f..fe5940c 100644 --- a/keychain/keychain_test.go +++ b/keychain/keychain_test.go @@ -8,8 +8,12 @@ package keychain import ( "math/rand" "testing" + + "github.com/allcloud-io/clisso/log" ) +var _ = log.NewLogger("panic","", false) + func randSeq(n int, letters []rune) []byte { b := make([]rune, n) for i := range b { diff --git a/okta/client_test.go b/okta/client_test.go index 2ba1244..2f11b39 100644 --- a/okta/client_test.go +++ b/okta/client_test.go @@ -10,8 +10,12 @@ import ( "net/http/httptest" "testing" "time" + + "github.com/allcloud-io/clisso/log" ) +var _ = log.NewLogger("panic","", false) + func getTestServer(data string) *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte(data)) diff --git a/onelogin/client_test.go b/onelogin/client_test.go index e90296d..425ea77 100644 --- a/onelogin/client_test.go +++ b/onelogin/client_test.go @@ -10,6 +10,8 @@ import ( "net/http/httptest" "net/url" "testing" + + "github.com/allcloud-io/clisso/log" ) func getTestServer(data string) *httptest.Server { @@ -25,6 +27,8 @@ func getTestServer(data string) *httptest.Server { var c = Client{} +var _ = log.NewLogger("panic","", false) + func TestNewClient(t *testing.T) { for _, test := range []struct { name string diff --git a/onelogin/endpoints_test.go b/onelogin/endpoints_test.go index d891065..6e8d609 100644 --- a/onelogin/endpoints_test.go +++ b/onelogin/endpoints_test.go @@ -8,8 +8,11 @@ package onelogin import ( "net/url" "testing" + + "github.com/allcloud-io/clisso/log" ) +var _ = log.NewLogger("panic","", false) func TestEndpoints_SetBase(t *testing.T) { for _, test := range []struct { name string diff --git a/saml/saml_test.go b/saml/saml_test.go index 30be99d..74ef7f3 100644 --- a/saml/saml_test.go +++ b/saml/saml_test.go @@ -9,9 +9,12 @@ import ( "os" "testing" + "github.com/allcloud-io/clisso/log" "github.com/crewjam/saml" ) +var _ = log.NewLogger("panic","", false) + func TestExtractArns(t *testing.T) { for _, test := range []struct { name string From 3c5f2da74659f52d9665a93d4201bf7ba3c2d49b Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Fri, 8 Mar 2024 17:09:50 +0100 Subject: [PATCH 09/12] check return code --- cmd/root.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cmd/root.go b/cmd/root.go index 27bc738..590e8b0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -153,7 +153,11 @@ func bindFlags(cmd *cobra.Command, v *viper.Viper) { if !f.Changed && v.IsSet(configName) { //fmt.Fprintf(os.Stderr, "Setting Flag %s by config: %s\n", f.Name, configName) val := v.Get(configName) - cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) + err := cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) + if err != nil { + // no logger yet, so print to stderr + fmt.Fprintf(os.Stderr, "Error setting flag %s: %v\n", f.Name, err) + } /*} else { fmt.Fprintf(os.Stderr, "Using Flag %s default: %v\n", f.Name, f.DefValue)*/ } From 19f1dd04fc7e9553883092780b98ccb873345e40 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Tue, 30 Apr 2024 16:16:04 +0200 Subject: [PATCH 10/12] documentation, streamline help --- .gitignore | 1 + README.md | 35 +++++++++++++++++++++++++++++++++-- cmd/get.go | 6 +++--- cmd/root.go | 11 ++++++----- cmd/status.go | 2 +- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index c68230e..be336aa 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ steps_output.txt bottle_output.txt dist/ +clisso.yaml \ No newline at end of file diff --git a/README.md b/README.md index 18eb44c..d8e2096 100644 --- a/README.md +++ b/README.md @@ -283,11 +283,42 @@ with the app's name as the [profile name][10]. You can use the temporary credent the profile name as an argument to the AWS CLI (`--profile my-profile`), by setting the `AWS_PROFILE` environment variable or by configuring any AWS SDK to use the profile. -To save the credentials to a custom file, use the `-w` flag. +To save the credentials to a custom file, use the `--output` flag with a custom path. For example: -To print the credentials to the shell instead of storing them in a file, use the `-s` flag. This + clisso get my-app --output /path/to/credentials + +To print the credentials to the shell instead of storing them in a file, use the `--output environment` flag. This will output shell commands which can be pasted in any shell to use the credentials. +### Running as `credential_process` + +AWS CLI v2 introduced the `credential_process` feature which allows you to use an external command to obtain temporal credentials. +Clisso can be used as a `credential_process` command by setting the `--output credential_process` flag. For example: + + clisso get my-app --output credential_process + +You can use this by adding the following to your `~/.aws/credentials` file: + +```ini +[my-app] +credential_process = clisso get my-app --output credential_process +``` + +The AWS SDK does not cache any credentials obtained using `credential_process`. This means that every time you use the profile, Clisso will be called to obtain new credentials. If you want to cache the credentials, you can use the `--cache` flag. For example: + +```ini +[my-app] +credential_process = clisso get my-app --output credential_process --cache +``` + +Alternatively you can set it in the `~/.clisso.yaml` file: + +```yaml +global: + cache: + enable: true +``` + ### Storing the password in the key chain > WARNING: Storing the password without having MFA enabled is a security risk. It allows anyone diff --git a/cmd/get.go b/cmd/get.go index e5c35b1..af3c2fc 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -31,7 +31,7 @@ var cacheToFile string func init() { RootCmd.AddCommand(cmdGet) cmdGet.Flags().StringVarP( - &output, "output", "o", "$HOME/.aws/credentials", "How or where to output credentials. Two special values are supported 'environment' and 'credential_process'. All other values are interpreted as file paths", + &output, "output", "o", "~/.aws/credentials", "How or where to output credentials. Two special values are supported 'environment' and 'credential_process'. All other values are interpreted as file paths", ) cmdGet.Flags().BoolVarP( @@ -39,13 +39,13 @@ func init() { "Should credentials be cached to a file, important when run as a credential_process (default: false)", ) cmdGet.Flags().StringVarP( - &cacheToFile, "cache-path", "", "$HOME/.aws/credentials-cache", + &cacheToFile, "cache-path", "", "~/.aws/credentials-cache", "Write credentials to this file instead of the default", ) // Keep the old flags as is. cmdGet.Flags().StringVarP( - &writeToFile, "write-to-file", "w", "$HOME/.aws/credentials", + &writeToFile, "write-to-file", "w", "~/.aws/credentials", "Write credentials to this file instead of the default", ) cmdGet.Flags().BoolVarP( diff --git a/cmd/root.go b/cmd/root.go index 590e8b0..6195801 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -74,7 +74,7 @@ one at https://mozilla.org/MPL/2.0/. func init() { RootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", - "config file (default is $HOME/.clisso.yaml)", + "config file (default is ~/.clisso.yaml)", ) // Add a global log level flag RootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "", "info", "set log level to trace, debug, info, warn, error, fatal or panic") @@ -85,7 +85,7 @@ func init() { // } RootCmd.PersistentFlags().StringVarP( - &logFile, "log-file", "", "$HOME/.clisso.log", "log file location", + &logFile, "log-file", "", "~/.clisso.log", "log file location", ) // err = viper.BindPFlag("global.log.file", RootCmd.PersistentFlags().Lookup("log-file")) // if err != nil { @@ -123,7 +123,7 @@ func initConfig(cmd *cobra.Command) error { if _, err := os.Stat(file); os.IsNotExist(err) { _, err := os.Create(file) if err != nil { - log.Log.Fatalf("Error creating config file: %v", err) + panic(fmt.Errorf("can't create config file: %v", err)) } } @@ -133,10 +133,11 @@ func initConfig(cmd *cobra.Command) error { } if err := viper.ReadInConfig(); err != nil { - log.Log.Fatalf("Can't read config: %v", err) + // no logger yet, panic + panic(fmt.Errorf("can't read config: %v", err)) } bindFlags(cmd, viper.GetViper()) - log.Log = log.NewLogger(logLevel, logFile, logFile != "") + _ = log.NewLogger(logLevel, logFile, logFile != "") return nil } diff --git a/cmd/status.go b/cmd/status.go index df5fb51..d0d909e 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -24,7 +24,7 @@ func init() { RootCmd.AddCommand(cmdStatus) cmdStatus.Flags().StringVarP( &readFromFile, "read-from-file", "r", "", - "Read credentials from this file instead of the default ($HOME/.aws/credentials)", + "Read credentials from this file instead of the default (~/.aws/credentials)", ) err := viper.BindPFlag("global.credentials-path", cmdStatus.Flags().Lookup("read-from-file")) if err != nil { From 5950a72c125a0f0e105fde12ad6b890378f1b19a Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Tue, 30 Apr 2024 16:36:09 +0200 Subject: [PATCH 11/12] disable signing hooks --- .goreleaser.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 54081e7..e56a901 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -24,16 +24,16 @@ builds: - darwin goarch: - "amd64" - hooks: - post: gon gon-amd64.json + # hooks: + # post: gon gon-amd64.json - id: macos-arm64 binary: clisso goos: - darwin goarch: - "arm64" - hooks: - post: gon gon-arm64.json + # hooks: + # post: gon gon-arm64.json archives: - format: zip From 80ac4e01199bed90ccd43c67c643f602f6567d97 Mon Sep 17 00:00:00 2001 From: Jonathan Vogt Date: Tue, 30 Apr 2024 16:38:46 +0200 Subject: [PATCH 12/12] go formatting --- aws/aws_test.go | 2 +- cmd/root.go | 2 +- keychain/keychain_test.go | 2 +- okta/client_test.go | 2 +- onelogin/client_test.go | 2 +- onelogin/endpoints_test.go | 3 ++- saml/saml_test.go | 2 +- 7 files changed, 8 insertions(+), 7 deletions(-) diff --git a/aws/aws_test.go b/aws/aws_test.go index 6f89b61..56e3a2a 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -16,7 +16,7 @@ import ( "github.com/go-ini/ini" ) -var _ = log.NewLogger("panic","", false) +var _ = log.NewLogger("panic", "", false) func TestWriteToFile(t *testing.T) { id := "expiredkey" diff --git a/cmd/root.go b/cmd/root.go index 6195801..11caa1d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -159,7 +159,7 @@ func bindFlags(cmd *cobra.Command, v *viper.Viper) { // no logger yet, so print to stderr fmt.Fprintf(os.Stderr, "Error setting flag %s: %v\n", f.Name, err) } - /*} else { + /*} else { fmt.Fprintf(os.Stderr, "Using Flag %s default: %v\n", f.Name, f.DefValue)*/ } }) diff --git a/keychain/keychain_test.go b/keychain/keychain_test.go index fe5940c..6e71d3a 100644 --- a/keychain/keychain_test.go +++ b/keychain/keychain_test.go @@ -12,7 +12,7 @@ import ( "github.com/allcloud-io/clisso/log" ) -var _ = log.NewLogger("panic","", false) +var _ = log.NewLogger("panic", "", false) func randSeq(n int, letters []rune) []byte { b := make([]rune, n) diff --git a/okta/client_test.go b/okta/client_test.go index 2f11b39..6ee9f4f 100644 --- a/okta/client_test.go +++ b/okta/client_test.go @@ -14,7 +14,7 @@ import ( "github.com/allcloud-io/clisso/log" ) -var _ = log.NewLogger("panic","", false) +var _ = log.NewLogger("panic", "", false) func getTestServer(data string) *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/onelogin/client_test.go b/onelogin/client_test.go index 425ea77..8ae6eb5 100644 --- a/onelogin/client_test.go +++ b/onelogin/client_test.go @@ -27,7 +27,7 @@ func getTestServer(data string) *httptest.Server { var c = Client{} -var _ = log.NewLogger("panic","", false) +var _ = log.NewLogger("panic", "", false) func TestNewClient(t *testing.T) { for _, test := range []struct { diff --git a/onelogin/endpoints_test.go b/onelogin/endpoints_test.go index 6e8d609..3748393 100644 --- a/onelogin/endpoints_test.go +++ b/onelogin/endpoints_test.go @@ -12,7 +12,8 @@ import ( "github.com/allcloud-io/clisso/log" ) -var _ = log.NewLogger("panic","", false) +var _ = log.NewLogger("panic", "", false) + func TestEndpoints_SetBase(t *testing.T) { for _, test := range []struct { name string diff --git a/saml/saml_test.go b/saml/saml_test.go index 74ef7f3..0debe5e 100644 --- a/saml/saml_test.go +++ b/saml/saml_test.go @@ -13,7 +13,7 @@ import ( "github.com/crewjam/saml" ) -var _ = log.NewLogger("panic","", false) +var _ = log.NewLogger("panic", "", false) func TestExtractArns(t *testing.T) { for _, test := range []struct {