diff --git a/.gitignore b/.gitignore index c68230e..be336aa 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ steps_output.txt bottle_output.txt dist/ +clisso.yaml \ No newline at end of file diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 54081e7..e56a901 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -24,16 +24,16 @@ builds: - darwin goarch: - "amd64" - hooks: - post: gon gon-amd64.json + # hooks: + # post: gon gon-amd64.json - id: macos-arm64 binary: clisso goos: - darwin goarch: - "arm64" - hooks: - post: gon gon-arm64.json + # hooks: + # post: gon gon-arm64.json archives: - format: zip diff --git a/README.md b/README.md index 18eb44c..d8e2096 100644 --- a/README.md +++ b/README.md @@ -283,11 +283,42 @@ with the app's name as the [profile name][10]. You can use the temporary credent the profile name as an argument to the AWS CLI (`--profile my-profile`), by setting the `AWS_PROFILE` environment variable or by configuring any AWS SDK to use the profile. -To save the credentials to a custom file, use the `-w` flag. +To save the credentials to a custom file, use the `--output` flag with a custom path. For example: -To print the credentials to the shell instead of storing them in a file, use the `-s` flag. This + clisso get my-app --output /path/to/credentials + +To print the credentials to the shell instead of storing them in a file, use the `--output environment` flag. This will output shell commands which can be pasted in any shell to use the credentials. +### Running as `credential_process` + +AWS CLI v2 introduced the `credential_process` feature which allows you to use an external command to obtain temporal credentials. +Clisso can be used as a `credential_process` command by setting the `--output credential_process` flag. For example: + + clisso get my-app --output credential_process + +You can use this by adding the following to your `~/.aws/credentials` file: + +```ini +[my-app] +credential_process = clisso get my-app --output credential_process +``` + +The AWS SDK does not cache any credentials obtained using `credential_process`. This means that every time you use the profile, Clisso will be called to obtain new credentials. If you want to cache the credentials, you can use the `--cache` flag. For example: + +```ini +[my-app] +credential_process = clisso get my-app --output credential_process --cache +``` + +Alternatively you can set it in the `~/.clisso.yaml` file: + +```yaml +global: + cache: + enable: true +``` + ### Storing the password in the key chain > WARNING: Storing the password without having MFA enabled is a security risk. It allows anyone diff --git a/aws/aws.go b/aws/aws.go index 67e84dd..ba66d6f 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -10,8 +10,9 @@ import ( "io" "time" + "github.com/allcloud-io/clisso/log" "github.com/go-ini/ini" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) // Credentials represents a set of temporary credentials received from AWS STS @@ -32,11 +33,11 @@ type Profile struct { const expireKey = "aws_expiration" -// WriteToFile writes credentials to an AWS CLI credentials file +// OutputFile writes credentials to an AWS CLI credentials file // (https://docs.aws.amazon.com/cli/latest/userguide/cli-config-files.html). In addition, this // function removes expired temporary credentials from the credentials file. -func WriteToFile(c *Credentials, filename string, section string) error { - log.WithFields(log.Fields{ +func OutputFile(c *Credentials, filename string, section string) error { + log.Log.WithFields(logrus.Fields{ "filename": filename, "section": section, }).Debug("Writing credentials to file") @@ -65,29 +66,29 @@ func WriteToFile(c *Credentials, filename string, section string) error { // Remove expired credentials. for _, s := range cfg.Sections() { if !s.HasKey(expireKey) { - log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey) + log.Log.Tracef("Skipping profile %s because it does not have an %s key", s.Name(), expireKey) continue } v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Warnf("Cannot parse date (%v) in profile %s: %s", + log.Log.Warnf("Cannot parse date (%v) in profile %s: %s", s.Key(expireKey), s.Name(), err) continue } if time.Now().UTC().Unix() > v.Unix() { - log.Tracef("Removing expired credentials for profile %s", s.Name()) + log.Log.Tracef("Removing expired credentials for profile %s", s.Name()) cfg.DeleteSection(s.Name()) continue } - log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339)) + log.Log.Tracef("Profile %s expires at %s", s.Name(), v.Format(time.RFC3339)) } return cfg.SaveTo(filename) } -// WriteToShell writes (prints) credentials to stdout. If windows is true, Windows syntax will be -// used. -func WriteToShell(c *Credentials, windows bool, w io.Writer) { +// OutputEnvironment writes (prints) credentials to stdout. If windows is true, Windows syntax will be +// used. The output can be used to set environment variables. +func OutputEnvironment(c *Credentials, windows bool, w io.Writer) { fmt.Print("Please paste the following in your shell:") if windows { fmt.Fprintf( @@ -108,21 +109,37 @@ func WriteToShell(c *Credentials, windows bool, w io.Writer) { } } -// GetValidCredentials returns profiles which have a aws_expiration key but are not yet expired. -func GetValidCredentials(filename string) ([]Profile, error) { +// OutputCredentialProcess writes (prints) credentials to stdout in the format required by the AWS CLI. +// The output can be used to set the credential_process option in the AWS CLI configuration file. +func OutputCredentialProcess(c *Credentials, w io.Writer) { + log.Log.Trace("Writing credentials to stdout in credential_process format") + log.Log.Infof("Credentials expire at %s, in %d Minutes", c.Expiration.Format(time.RFC3339), int(c.Expiration.Sub(time.Now().UTC()).Minutes())) + fmt.Fprintf( + w, + `{ "Version": 1, "AccessKeyId": %q, "SecretAccessKey": %q, "SessionToken": %q, "Expiration": %q }`, + c.AccessKeyID, + c.SecretAccessKey, + c.SessionToken, + // Time must be in ISO8601 format + c.Expiration.Format(time.RFC3339), + ) +} + +// GetValidProfiles returns profiles which have a aws_expiration key but are not yet expired. +func GetValidProfiles(filename string) ([]Profile, error) { var profiles []Profile - log.WithField("filename", filename).Trace("Loading AWS credentials file") + log.Log.WithField("filename", filename).Trace("Loading AWS credentials file") cfg, err := ini.LooseLoad(filename) if err != nil { err = fmt.Errorf("%s contains errors: %w", filename, err) - log.WithError(err).Trace("Failed to load AWS credentials file") + log.Log.WithError(err).Trace("Failed to load AWS credentials file") return nil, err } for _, s := range cfg.Sections() { if s.HasKey(expireKey) { v, err := s.Key(expireKey).TimeFormat(time.RFC3339) if err != nil { - log.Warnf("Cannot parse date (%v) in section %s: %s", + log.Log.Warnf("Cannot parse date (%v) in section %s: %s", s.Key(expireKey), s.Name(), err) continue } @@ -136,3 +153,38 @@ func GetValidCredentials(filename string) ([]Profile, error) { } return profiles, nil } + +// GetValidCredentials returns credentials which have a aws_expiration key but are not yet expired. +// returns a map of profile name to credentials +func GetValidCredentials(filename string) (map[string]Credentials, error) { + credentials := make(map[string]Credentials) + log.Log.WithField("filename", filename).Trace("Loading credentials file") + cfg, err := ini.LooseLoad(filename) + if err != nil { + err = fmt.Errorf("%s contains errors: %w", filename, err) + log.Log.WithError(err).Trace("Failed to load credentials file") + return nil, err + } + for _, s := range cfg.Sections() { + if s.HasKey(expireKey) { + v, err := s.Key(expireKey).TimeFormat(time.RFC3339) + if err != nil { + log.Log.Warnf("Cannot parse date (%v) in section %s: %s", + s.Key(expireKey), s.Name(), err) + continue + } + + if time.Now().UTC().Unix() < v.Unix() { + credential := Credentials{ + AccessKeyID: s.Key("aws_access_key_id").String(), + SecretAccessKey: s.Key("aws_secret_access_key").String(), + SessionToken: s.Key("aws_session_token").String(), + Expiration: v, + } + credentials[s.Name()] = credential + } + + } + } + return credentials, nil +} diff --git a/aws/aws_test.go b/aws/aws_test.go index 8a9b1af..56e3a2a 100644 --- a/aws/aws_test.go +++ b/aws/aws_test.go @@ -12,9 +12,12 @@ import ( "testing" "time" + "github.com/allcloud-io/clisso/log" "github.com/go-ini/ini" ) +var _ = log.NewLogger("panic", "", false) + func TestWriteToFile(t *testing.T) { id := "expiredkey" sec := "expiredsecret" @@ -32,7 +35,7 @@ func TestWriteToFile(t *testing.T) { p := "expiredprofile" // Write credentials - err := WriteToFile(&c, fn, p) + err := OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -55,7 +58,7 @@ func TestWriteToFile(t *testing.T) { p = "testprofile" // Write credentials - err = WriteToFile(&c, fn, p) + err = OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -120,7 +123,7 @@ func TestWriteToFile(t *testing.T) { } } -func TestGetValidCredentials(t *testing.T) { +func TestGetValidProfiles(t *testing.T) { fn := "test_creds.txt" id := "testkey" @@ -139,7 +142,7 @@ func TestGetValidCredentials(t *testing.T) { p := "expired" // Write credentials - err := WriteToFile(&c, fn, p) + err := OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -149,7 +152,7 @@ func TestGetValidCredentials(t *testing.T) { p = "valid" // Write credentials - err = WriteToFile(&c, fn, p) + err = OutputFile(&c, fn, p) if err != nil { t.Fatal("Could not write credentials to file: ", err) } @@ -179,7 +182,7 @@ func TestGetValidCredentials(t *testing.T) { time.Sleep(time.Duration(1) * time.Second) - profiles, err := GetValidCredentials(fn) + profiles, err := GetValidProfiles(fn) if err != nil { t.Fatal("Failed to get NonExpiredCredentials") } @@ -202,13 +205,13 @@ func TestGetValidCredentials(t *testing.T) { t.Fatalf("Could not remove file %v during cleanup", fn) } - _, err = GetValidCredentials(fn) + _, err = GetValidProfiles(fn) if err != nil { t.Fatal("Function did crash on missing file") } } -func TestWriteToShellUnix(t *testing.T) { +func TestOutputUnixEnvironment(t *testing.T) { id := "testkey" sec := "testsecret" tok := "testtoken" @@ -222,7 +225,7 @@ func TestWriteToShellUnix(t *testing.T) { } var b bytes.Buffer - WriteToShell(&c, false, &b) + OutputEnvironment(&c, false, &b) got := b.String() want := fmt.Sprintf( @@ -237,7 +240,7 @@ func TestWriteToShellUnix(t *testing.T) { } } -func TestWriteToShellWindows(t *testing.T) { +func TestOutputWindowsEnvironment(t *testing.T) { id := "testkey" sec := "testsecret" tok := "testtoken" @@ -251,7 +254,7 @@ func TestWriteToShellWindows(t *testing.T) { } var b bytes.Buffer - WriteToShell(&c, true, &b) + OutputEnvironment(&c, true, &b) got := b.String() want := fmt.Sprintf( diff --git a/aws/sts.go b/aws/sts.go index 524967e..748fedf 100644 --- a/aws/sts.go +++ b/aws/sts.go @@ -11,12 +11,13 @@ import ( "regexp" "strings" + "github.com/allcloud-io/clisso/log" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) const ( @@ -41,13 +42,13 @@ const ( // returns a specific error message to indicate that. In this case we return a custom error to the // caller to allow special handling such as retrying with a lower duration. func AssumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, duration int32) (*Credentials, error) { - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "PrincipalArn": PrincipalArn, "RoleArn": RoleArn, "awsRegion": awsRegion, "duration": duration, }).Debug("Assuming role with SAML assertion") - log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion") + log.Log.WithField("SAMLAssertion", SAMLAssertion).Trace("SAML assertion") creds, err := assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion, duration) if err != nil { // Check if API error returned by AWS @@ -80,15 +81,15 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura config, err := config.LoadDefaultConfig(ctx, config.WithRegion(awsRegion)) if err != nil { - log.WithError(err).Debug("Error loading default configuration") + log.Log.WithError(err).Debug("Error loading default configuration") return nil, err } - log.WithField("awsRegion", config.Region).Trace("Loaded default config") + log.Log.WithField("awsRegion", config.Region).Trace("Loaded default config") // If we request credentials for China we need to provide a Chinese region idp := regexp.MustCompile(`^arn:aws-cn:iam::\d+:saml-provider\/\S+$`) if idp.MatchString(PrincipalArn) && !strings.HasPrefix(awsRegion, "cn-") { - log.Trace("Setting region to cn-north-1") + log.Log.Trace("Setting region to cn-north-1") config.Region = "cn-north-1" } svc := sts.NewFromConfig(config, func(o *sts.Options) { @@ -98,7 +99,7 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura aResp, err := svc.AssumeRoleWithSAML(ctx, &input) if err != nil { - log.WithError(err).Debug("Error assuming role with SAML assertion") + log.Log.WithError(err).Debug("Error assuming role with SAML assertion") return nil, err } @@ -107,10 +108,10 @@ func assumeSAMLRole(PrincipalArn, RoleArn, SAMLAssertion, awsRegion string, dura sessionToken := *aResp.Credentials.SessionToken expiration := *aResp.Credentials.Expiration - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "AccessKeyID": keyID, - "SecretAccessKey": gog.If(log.GetLevel() == log.TraceLevel, secretKey, ""), - "SessionToken": gog.If(log.GetLevel() == log.TraceLevel, sessionToken, ""), + "SecretAccessKey": gog.If(log.Log.GetLevel() == logrus.TraceLevel, secretKey, ""), + "SessionToken": gog.If(log.Log.GetLevel() == logrus.TraceLevel, sessionToken, ""), "Expiration": expiration, }).Debug("Got temporary credentials") diff --git a/cmd/apps.go b/cmd/apps.go index 5a56bbf..aa70dff 100644 --- a/cmd/apps.go +++ b/cmd/apps.go @@ -10,7 +10,7 @@ import ( "sort" "strconv" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -63,7 +63,7 @@ var cmdAppsList = &cobra.Command{ Long: "List all configured apps.", Run: func(cmd *cobra.Command, args []string) { apps := viper.GetStringMap("apps") - log.Trace("Listing apps") + log.Log.Trace("Listing apps") if len(apps) == 0 { fmt.Println("No apps configured") @@ -105,18 +105,18 @@ var cmdAppsCreateOneLogin = &cobra.Command{ // Verify app doesn't exist if exists := viper.Get("apps." + name); exists != nil { - log.Fatalf("App '%s' already exists", name) + log.Log.Fatalf("App '%s' already exists", name) } // Verify provider exists if exists := viper.Get("providers." + provider); exists == nil { - log.Fatalf("Provider '%s' doesn't exist", provider) + log.Log.Fatalf("Provider '%s' doesn't exist", provider) } // Verify provider type pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType != "onelogin" { - log.Fatalf( + log.Log.Fatalf( "Invalid provider type '%s' for a OneLogin app. Type must be 'onelogin'.", pType, ) @@ -134,9 +134,9 @@ var cmdAppsCreateOneLogin = &cobra.Command{ if duration != 0 { // Duration specified - validate value if duration < 3600 || duration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } - log.Tracef("Setting duration to %d", duration) + log.Log.Tracef("Setting duration to %d", duration) conf["duration"] = strconv.Itoa(duration) } @@ -145,9 +145,9 @@ var cmdAppsCreateOneLogin = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("App '%s' saved to config file", name) + log.Log.Printf("App '%s' saved to config file", name) }, } @@ -161,18 +161,18 @@ var cmdAppsCreateOkta = &cobra.Command{ // Verify app doesn't exist if exists := viper.Get("apps." + name); exists != nil { - log.Fatalf("App '%s' already exists", name) + log.Log.Fatalf("App '%s' already exists", name) } // Verify provider exists if exists := viper.Get("providers." + provider); exists == nil { - log.Fatalf("Provider '%s' doesn't exist", provider) + log.Log.Fatalf("Provider '%s' doesn't exist", provider) } // Verify provider type pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType != "okta" { - log.Fatalf( + log.Log.Fatalf( "Invalid provider type '%s' for an Okta app. Type must be 'okta'.", pType, ) @@ -186,9 +186,9 @@ var cmdAppsCreateOkta = &cobra.Command{ if duration != 0 { // Duration specified - validate value if duration < 3600 || duration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } - log.Tracef("Setting duration to %d", duration) + log.Log.Tracef("Setting duration to %d", duration) conf["duration"] = strconv.Itoa(duration) } @@ -197,9 +197,9 @@ var cmdAppsCreateOkta = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("App '%s' saved to config file", name) + log.Log.Printf("App '%s' saved to config file", name) }, } @@ -213,19 +213,19 @@ var cmdAppsSelect = &cobra.Command{ if app == "" { viper.Set("global.selected-app", "") - log.Println("Unsetting selected app") + log.Log.Println("Unsetting selected app") } else { if exists := viper.Get("apps." + app); exists == nil { - log.Fatalf("App '%s' doesn't exist", app) + log.Log.Fatalf("App '%s' doesn't exist", app) } - log.Printf("Setting selected app to '%s'", app) + log.Log.Printf("Setting selected app to '%s'", app) viper.Set("global.selected-app", app) } // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } }, } diff --git a/cmd/get.go b/cmd/get.go index 2bda408..af3c2fc 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -11,8 +11,8 @@ import ( "path/filepath" "runtime" + "github.com/allcloud-io/clisso/log" "github.com/mitchellh/go-homedir" - log "github.com/sirupsen/logrus" "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/okta" @@ -21,21 +21,99 @@ import ( "github.com/spf13/viper" ) +var output string var printToShell bool +var printToCredentialProcess bool +var cacheCredentials bool var writeToFile string +var cacheToFile string func init() { RootCmd.AddCommand(cmdGet) + cmdGet.Flags().StringVarP( + &output, "output", "o", "~/.aws/credentials", "How or where to output credentials. Two special values are supported 'environment' and 'credential_process'. All other values are interpreted as file paths", + ) + cmdGet.Flags().BoolVarP( - &printToShell, "shell", "s", false, "Print credentials to shell", + &cacheCredentials, "cache-enable", "", false, + "Should credentials be cached to a file, important when run as a credential_process (default: false)", ) cmdGet.Flags().StringVarP( - &writeToFile, "write-to-file", "w", "", - "Write credentials to this file instead of the default ($HOME/.aws/credentials)", + &cacheToFile, "cache-path", "", "~/.aws/credentials-cache", + "Write credentials to this file instead of the default", ) - err := viper.BindPFlag("global.credentials-path", cmdGet.Flags().Lookup("write-to-file")) + + // Keep the old flags as is. + cmdGet.Flags().StringVarP( + &writeToFile, "write-to-file", "w", "~/.aws/credentials", + "Write credentials to this file instead of the default", + ) + cmdGet.Flags().BoolVarP( + &printToShell, "shell", "s", false, "Print credentials to shell to be sourced as environment variables", + ) + + // Mark the old flag as deprecated. + err := cmdGet.Flags().MarkDeprecated("write-to-file", "please use output-file instead.") if err != nil { - log.Fatalf("Error binding flag global.credentials-path: %v", err) + // we don't have a logger yet, so we can't use it but need to print the error to the console + fmt.Printf("Error marking flag as deprecated: %v", err) + } + err = cmdGet.Flags().MarkDeprecated("shell", "please use output-environment instead.") + if err != nil { + // we don't have a logger yet, so we can't use it but need to print the error to the console + fmt.Printf("Error marking flag as deprecated: %v", err) + } + + cmdGet.MarkFlagsMutuallyExclusive("output", "shell", "write-to-file") + +} + +func preferredOutput(cmd *cobra.Command, app string) string { + // Order of preference: + // * output flag + // * write-to-file flag (deprecated) + // * app specific config file + // * global config file + // * default to ~/.aws/credentials + out, err := cmd.Flags().GetString("output") + if err != nil { + log.Log.Warnf("Error getting output flag: %v", err) + } + if out != "" { + return out + } + + out, err = cmd.Flags().GetString("write-to-file") + if err != nil { + log.Log.Warnf("Error getting write-to-file flag: %v", err) + } + if out != "" { + return out + } + + out = viper.GetString(fmt.Sprintf("apps.%s.output", app)) + if out != "" { + return out + } + + out = viper.GetString("global.output") + if out != "" { + return out + } + + return "~/.aws/credentials" +} + +func setOutput(cmd *cobra.Command, app string) { + o := preferredOutput(cmd, app) + writeToFile = "" + switch o { + case "environment": + printToShell = true + case "credential_process": + printToCredentialProcess = true + default: + writeToFile = o } } @@ -43,30 +121,49 @@ func init() { func processCredentials(creds *aws.Credentials, app string) error { if printToShell { // Print credentials to shell using the correct syntax for the OS. - aws.WriteToShell(creds, runtime.GOOS == "windows", os.Stdout) - } else { - path, err := homedir.Expand(viper.GetString("global.credentials-path")) - if err != nil { - return fmt.Errorf("expanding config file path: %v", err) - } + aws.OutputEnvironment(creds, runtime.GOOS == "windows", os.Stdout) + } - // Create the `global.credentials-path` directory if it doesn't exist. - credsFileParentDir := filepath.Dir(path) - if _, err := os.Stat(credsFileParentDir); os.IsNotExist(err) { - log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) + if printToCredentialProcess { + aws.OutputCredentialProcess(creds, os.Stdout) + } - err = os.MkdirAll(credsFileParentDir, 0755) - if err != nil { - return fmt.Errorf("creating credentials directory: %v", err) - } + if cacheCredentials { + if err := writeCredentialsToFile(creds, app, cacheToFile); err != nil { + log.Log.Errorf("writing credentials to file: %v", err) } + } - if err = aws.WriteToFile(creds, path, app); err != nil { + // if writeToFile is set, write the credentials to the file, might be the cache file or the credentials file + if writeToFile != "" { + if err := writeCredentialsToFile(creds, app, writeToFile); err != nil { return fmt.Errorf("writing credentials to file: %v", err) } - log.Printf("Credentials written successfully to '%s'", path) + } + return nil +} + +func writeCredentialsToFile(creds *aws.Credentials, app, file string) error { + log.Log.Tracef("Writing credentials to '%s'", file) + path, err := homedir.Expand(file) + if err != nil { + return fmt.Errorf("expanding config file path: %v", err) + } + // Create the `global.credentials-path` directory if it doesn't exist. + credsFileParentDir := filepath.Dir(path) + if _, err := os.Stat(credsFileParentDir); os.IsNotExist(err) { + log.Log.Warnf("Credentials directory '%s' does not exist - creating it", credsFileParentDir) + // Lets default to strict permissions on the folders we create + err = os.MkdirAll(credsFileParentDir, 0700) + if err != nil { + return fmt.Errorf("creating credentials directory: %v", err) + } } + if err := aws.OutputFile(creds, path, app); err != nil { + return fmt.Errorf("writing credentials to file: %v", err) + } + log.Log.Printf("Credentials written successfully to '%s'", path) return nil } @@ -100,6 +197,32 @@ func awsRegion(app string) string { return "aws-global" } +func getCachedCredential(app string) (*aws.Credentials, error) { + // get the credentials from the cache file + log.Log.Tracef("Looking for cached credentials in '%s'", cacheToFile) + credentialFile, err := homedir.Expand(cacheToFile) + if err != nil { + log.Log.Fatalf("Failed to expand home: %s", err) + } + + profiles, err := aws.GetValidCredentials(credentialFile) + if err != nil { + log.Log.Fatalf("Failed to retrieve non-expired credentials: %s", err) + } + + if len(profiles) == 0 { + return nil, nil + } + + // find the app we are looking for + for k, p := range profiles { + if k == app { + return &p, nil + } + } + return nil, fmt.Errorf("no valid credentials found for app '%s'", app) +} + var cmdGet = &cobra.Command{ Use: "get", Short: "Get temporary credentials for an app", @@ -115,7 +238,7 @@ If no app is specified, the selected app (if configured) will be assumed.`, selected := viper.GetString("global.selected-app") if selected == "" { // No default app configured. - log.Fatal("No app specified and no default app configured") + log.Log.Fatal("No app specified and no default app configured") } app = selected } else { @@ -125,12 +248,12 @@ If no app is specified, the selected app (if configured) will be assumed.`, provider := viper.GetString(fmt.Sprintf("apps.%s.provider", app)) if provider == "" { - log.Fatalf("Could not get provider for app '%s'", app) + log.Log.Fatalf("Could not get provider for app '%s'", app) } pType := viper.GetString(fmt.Sprintf("providers.%s.type", provider)) if pType == "" { - log.Fatalf("Could not get provider type for provider '%s'", provider) + log.Log.Fatalf("Could not get provider type for provider '%s'", provider) } // allow preferred "arn" to be specified in the config file for each app @@ -141,29 +264,47 @@ If no app is specified, the selected app (if configured) will be assumed.`, awsRegion := awsRegion(app) + setOutput(cmd, app) + + if printToCredentialProcess && cacheCredentials { + log.Log.Trace("Using --cache-credentials and --output-process") + // we need to cache the credentials to a file and return valid credentials instead of constantly hitting the IdPs + credential, err := getCachedCredential(app) + if err != nil { + log.Log.WithError(err).Debugf("Failed to find cached credentials for app '%s'", app) + } + if credential != nil { + aws.OutputCredentialProcess(credential, os.Stdout) + return + } + } + + interactive := !printToShell && !printToCredentialProcess if pType == "onelogin" { - creds, err := onelogin.Get(app, provider, pArn, awsRegion, duration) + creds, err := onelogin.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { - log.Fatal("Could not get temporary credentials: ", err) + log.Log.Fatal("Could not get temporary credentials: ", err) } // Process credentials err = processCredentials(creds, app) if err != nil { - log.Fatalf("Error processing credentials: %v", err) + log.Log.Fatalf("Error processing credentials: %v", err) } } else if pType == "okta" { - creds, err := okta.Get(app, provider, pArn, awsRegion, duration) + creds, err := okta.Get(app, provider, pArn, awsRegion, duration, interactive) if err != nil { - log.Fatal("Could not get temporary credentials: ", err) + log.Log.Fatal("Could not get temporary credentials: ", err) } // Process credentials err = processCredentials(creds, app) if err != nil { - log.Fatalf("Error processing credentials: %v", err) + log.Log.Fatalf("Error processing credentials: %v", err) } } else { - log.Fatalf("Unsupported identity provider type '%s' for app '%s'", pType, app) + log.Log.Fatalf("Unsupported identity provider type '%s' for app '%s'", pType, app) + } + if interactive { + printStatus() } - printStatus() }, } diff --git a/cmd/helpers.go b/cmd/helpers.go index 3da06e9..59698dc 100644 --- a/cmd/helpers.go +++ b/cmd/helpers.go @@ -6,13 +6,13 @@ package cmd import ( - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" "github.com/spf13/cobra" ) func mandatoryFlag(cmd *cobra.Command, name string) { err := cmd.MarkFlagRequired(name) if err != nil { - log.Fatalf("Error marking flag %s as required: %v", name, err) + log.Log.Fatalf("Error marking flag %s as required: %v", name, err) } } diff --git a/cmd/providers.go b/cmd/providers.go index 6d84333..a9f0c3a 100644 --- a/cmd/providers.go +++ b/cmd/providers.go @@ -12,7 +12,7 @@ import ( "syscall" "github.com/allcloud-io/clisso/keychain" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" "github.com/spf13/cobra" "github.com/spf13/viper" "golang.org/x/term" @@ -77,7 +77,7 @@ var cmdProvidersList = &cobra.Command{ providers := viper.GetStringMap("providers") if len(providers) == 0 { - log.Println("No providers configured") + log.Log.Println("No providers configured") return } @@ -88,7 +88,7 @@ var cmdProvidersList = &cobra.Command{ } sort.Strings(keys) for _, k := range keys { - log.Println(k) + log.Log.Println(k) } }, } @@ -104,16 +104,16 @@ var cmdProvidersPassword = &cobra.Command{ fmt.Printf("Please enter the password for the '%s' provider: ", provider) pass, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { - log.Fatalf("Could not read password") + log.Log.Fatalf("Could not read password") } keyChain := keychain.DefaultKeychain{} err = keyChain.Set(provider, pass) if err != nil { - log.Fatalf("Could not save to keychain: %+v", err) + log.Log.Fatalf("Could not save to keychain: %+v", err) } - log.Printf("Saved password for Provider '%s'", provider) + log.Log.Printf("Saved password for Provider '%s'", provider) }, } @@ -133,13 +133,13 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ // Verify provider doesn't exist if exists := viper.Get("providers." + name); exists != nil { - log.Fatalf("Provider '%s' already exists", name) + log.Log.Fatalf("Provider '%s' already exists", name) } switch region { case "US", "EU": default: - log.Fatal("Region must be either US or EU") + log.Log.Fatal("Region must be either US or EU") } conf := map[string]string{ @@ -153,7 +153,7 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ if providerDuration != 0 { // Duration specified - validate value if providerDuration < 3600 || providerDuration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } conf["duration"] = strconv.Itoa(providerDuration) } @@ -162,9 +162,9 @@ var cmdProvidersCreateOneLogin = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("Provider '%s' saved to config file", name) + log.Log.Printf("Provider '%s' saved to config file", name) }, } @@ -178,7 +178,7 @@ var cmdProvidersCreateOkta = &cobra.Command{ // Verify provider doesn't exist if exists := viper.Get("providers." + name); exists != nil { - log.Fatalf("Provider '%s' already exists", name) + log.Log.Fatalf("Provider '%s' already exists", name) } conf := map[string]string{ @@ -189,7 +189,7 @@ var cmdProvidersCreateOkta = &cobra.Command{ if providerDuration != 0 { // Duration specified - validate value if providerDuration < 3600 || providerDuration > 43200 { - log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") + log.Log.Fatal("Invalid duration Specified. Valid values: 3600 - 43200") } conf["duration"] = strconv.Itoa(providerDuration) } @@ -198,8 +198,8 @@ var cmdProvidersCreateOkta = &cobra.Command{ // Write config to file err := viper.WriteConfig() if err != nil { - log.Fatalf("Error writing config: %v", err) + log.Log.Fatalf("Error writing config: %v", err) } - log.Printf("Provider '%s' saved to config file", name) + log.Log.Printf("Provider '%s' saved to config file", name) }, } diff --git a/cmd/root.go b/cmd/root.go index 953469d..11caa1d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -6,29 +6,28 @@ package cmd import ( + "fmt" "os" "path/filepath" + "strings" + "github.com/allcloud-io/clisso/log" homedir "github.com/mitchellh/go-homedir" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/spf13/viper" ) var cfgFile string +var logFile string +var logLevel string var RootCmd = &cobra.Command{ Use: "clisso", Version: "0.0.0", - PersistentPreRun: func(cmd *cobra.Command, args []string) { - // get log level flag value - logLevelFlag := cmd.Flag("log-level").Value.String() - // parse log level flag and set log level - logLevel, err := log.ParseLevel(logLevelFlag) - if err != nil { - log.Fatalf("Error parsing log level: %v", err) - } - log.SetLevel(logLevel) + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return initConfig(cmd) }, } @@ -74,13 +73,24 @@ one at https://mozilla.org/MPL/2.0/. ` func init() { - cobra.OnInitialize(initConfig) RootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", - "config file (default is $HOME/.clisso.yaml)", + "config file (default is ~/.clisso.yaml)", ) // Add a global log level flag - RootCmd.PersistentFlags().String("log-level", "info", "set log level to trace, debug, info, warn, error, fatal or panic") - + RootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "", "info", "set log level to trace, debug, info, warn, error, fatal or panic") + // err := viper.BindPFlag("global.log.level", RootCmd.PersistentFlags().Lookup("log-level")) + // if err != nil { + // // log isn't available yet, so we can't use it + // logrus.Fatalf("Error binding flag global.log.level: %v", err) + // } + + RootCmd.PersistentFlags().StringVarP( + &logFile, "log-file", "", "~/.clisso.log", "log file location", + ) + // err = viper.BindPFlag("global.log.file", RootCmd.PersistentFlags().Lookup("log-file")) + // if err != nil { + // logrus.Fatalf("Error binding flag global.log.file: %v", err) + // } RootCmd.SetUsageTemplate(usageTemplate) RootCmd.SetVersionTemplate(versionTemplate) } @@ -91,17 +101,17 @@ func Execute(version, commit, date string) { RootCmd.Version = version + " (" + commit + " " + date + ")" err := RootCmd.Execute() if err != nil { - log.Fatalf("Failed to execute: %v", err) + logrus.Fatalf("Failed to execute: %v", err) } } -func initConfig() { +func initConfig(cmd *cobra.Command) error { if cfgFile != "" { viper.SetConfigFile(cfgFile) } else { home, err := homedir.Dir() if err != nil { - log.Fatalf("Error getting home directory: %v", err) + log.Log.Fatalf("Error getting home directory: %v", err) } viper.SetConfigType("yaml") @@ -113,15 +123,44 @@ func initConfig() { if _, err := os.Stat(file); os.IsNotExist(err) { _, err := os.Create(file) if err != nil { - log.Fatalf("Error creating config file: %v", err) + panic(fmt.Errorf("can't create config file: %v", err)) } } - // Set default config values - viper.SetDefault("global.credentials-path", filepath.Join(home, ".aws", "credentials")) + // // Set default config values + // viper.SetDefault("global.credentials-path", filepath.Join(home, ".aws", "credentials")) + // viper.SetDefault("global.cache.path", filepath.Join(home, ".aws", "credentials-cache")) } if err := viper.ReadInConfig(); err != nil { - log.Fatalf("Can't read config: %v", err) + // no logger yet, panic + panic(fmt.Errorf("can't read config: %v", err)) } + bindFlags(cmd, viper.GetViper()) + _ = log.NewLogger(logLevel, logFile, logFile != "") + return nil +} + +// Bind each cobra flag to its associated viper configuration (config file and environment variable) +func bindFlags(cmd *cobra.Command, v *viper.Viper) { + cmd.Flags().VisitAll(func(f *pflag.Flag) { + + // Determine the naming convention of the flags when represented in the config file + configName := fmt.Sprintf("global.%s", f.Name) + configName = strings.ReplaceAll(configName, "-", ".") + //fmt.Fprintf(os.Stderr, "Checking Flag: %s\n", configName) + + // Apply the viper config value to the flag when the flag is not set and viper has a value + if !f.Changed && v.IsSet(configName) { + //fmt.Fprintf(os.Stderr, "Setting Flag %s by config: %s\n", f.Name, configName) + val := v.Get(configName) + err := cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)) + if err != nil { + // no logger yet, so print to stderr + fmt.Fprintf(os.Stderr, "Error setting flag %s: %v\n", f.Name, err) + } + /*} else { + fmt.Fprintf(os.Stderr, "Using Flag %s default: %v\n", f.Name, f.DefValue)*/ + } + }) } diff --git a/cmd/status.go b/cmd/status.go index 75fa790..d0d909e 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -11,9 +11,9 @@ import ( "time" "github.com/allcloud-io/clisso/aws" + "github.com/allcloud-io/clisso/log" homedir "github.com/mitchellh/go-homedir" "github.com/olekukonko/tablewriter" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -24,11 +24,11 @@ func init() { RootCmd.AddCommand(cmdStatus) cmdStatus.Flags().StringVarP( &readFromFile, "read-from-file", "r", "", - "Read credentials from this file instead of the default ($HOME/.aws/credentials)", + "Read credentials from this file instead of the default (~/.aws/credentials)", ) err := viper.BindPFlag("global.credentials-path", cmdStatus.Flags().Lookup("read-from-file")) if err != nil { - log.Fatalf("Error binding flag global.credentials-path: %v", err) + log.Log.Fatalf("Error binding flag global.credentials-path: %v", err) } } @@ -42,14 +42,14 @@ var cmdStatus = &cobra.Command{ } func printStatus() { - configfile, err := homedir.Expand(viper.GetString("global.credentials-path")) + credentialFile, err := homedir.Expand(viper.GetString("global.credentials-path")) if err != nil { - log.Fatalf("Failed to expand home: %s", err) + log.Log.Fatalf("Failed to expand home: %s", err) } - profiles, err := aws.GetValidCredentials(configfile) + profiles, err := aws.GetValidProfiles(credentialFile) if err != nil { - log.Fatalf("Failed to retrieve non-expired credentials: %s", err) + log.Log.Fatalf("Failed to retrieve non-expired credentials: %s", err) } if len(profiles) == 0 { @@ -60,7 +60,7 @@ func printStatus() { table := tablewriter.NewWriter(os.Stdout) table.SetHeader([]string{"App", "Expire At", "Remaining"}) - log.Print("The following apps currently have valid credentials:") + log.Log.Print("The following apps currently have valid credentials:") for _, p := range profiles { table.Append([]string{p.Name, fmt.Sprintf("%d", p.ExpireAtUnix), p.LifetimeLeft.Round(time.Second).String()}) } diff --git a/config/config.go b/config/config.go index 69b16bc..fc65960 100644 --- a/config/config.go +++ b/config/config.go @@ -9,8 +9,9 @@ import ( "errors" "fmt" + "github.com/allcloud-io/clisso/log" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -27,14 +28,14 @@ type OneLoginProviderConfig struct { // GetOneLoginProvider returns a OneLoginProviderConfig struct containing the configuration for // provider p. func GetOneLoginProvider(p string) (*OneLoginProviderConfig, error) { - log.WithField("provider", p).Trace("Reading OneLogin provider config") + log.Log.WithField("provider", p).Trace("Reading OneLogin provider config") clientSecret := viper.GetString(fmt.Sprintf("providers.%s.client-secret", p)) clientID := viper.GetString(fmt.Sprintf("providers.%s.client-id", p)) subdomain := viper.GetString(fmt.Sprintf("providers.%s.subdomain", p)) username := viper.GetString(fmt.Sprintf("providers.%s.username", p)) region := viper.GetString(fmt.Sprintf("providers.%s.region", p)) - log.WithFields(log.Fields{ - "clientSecret": gog.If(log.GetLevel() == log.TraceLevel, clientSecret, ""), + log.Log.WithFields(logrus.Fields{ + "clientSecret": gog.If(log.Log.GetLevel() == logrus.TraceLevel, clientSecret, ""), "clientID": clientID, "subdomain": subdomain, "username": username, diff --git a/go.mod b/go.mod index 4deea4b..76bc2d7 100644 --- a/go.mod +++ b/go.mod @@ -15,8 +15,10 @@ require ( github.com/mattn/go-colorable v0.1.13 github.com/mitchellh/go-homedir v1.1.0 github.com/olekukonko/tablewriter v0.0.5 + github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 + github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.18.2 github.com/zalando/go-keyring v0.2.3 golang.org/x/net v0.21.0 @@ -56,7 +58,6 @@ require ( github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.19.0 // indirect diff --git a/go.sum b/go.sum index 3fe86ee..ec7e594 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 h1:mZHayPoR0lNmnHyvtYjDeq0zlVHn9K/ZXoy17ylucdo= +github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5/go.mod h1:GEXHk5HgEKCvEIIrSpFI3ozzG5xOKA2DVlEX/gGnewM= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= diff --git a/keychain/keychain.go b/keychain/keychain.go index 13be514..3bf823d 100644 --- a/keychain/keychain.go +++ b/keychain/keychain.go @@ -9,7 +9,7 @@ import ( "fmt" "syscall" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" keyring "github.com/zalando/go-keyring" "golang.org/x/term" ) @@ -43,15 +43,15 @@ func (DefaultKeychain) Set(provider string, password []byte) (err error) { // and just ask the user for the password instead. Error could be anything from access denied to // password not found. func (DefaultKeychain) Get(provider string) (pw []byte, err error) { - log.WithField("provider", provider).Trace("Reading password from keychain") + log.Log.WithField("provider", provider).Trace("Reading password from keychain") pass, err := get(provider) if err != nil { - log.WithError(err).Trace("Couldn't read password from keychain") + log.Log.WithError(err).Trace("Couldn't read password from keychain") fmt.Printf("Please enter %s password: ", provider) pass, err = term.ReadPassword(int(syscall.Stdin)) if err != nil { err = fmt.Errorf("couldn't read password from terminal: %w", err) - log.WithError(err).Trace("Couldn't read password from terminal") + log.Log.WithError(err).Trace("Couldn't read password from terminal") return nil, err } } diff --git a/keychain/keychain_test.go b/keychain/keychain_test.go index 5aafd5f..6e71d3a 100644 --- a/keychain/keychain_test.go +++ b/keychain/keychain_test.go @@ -8,8 +8,12 @@ package keychain import ( "math/rand" "testing" + + "github.com/allcloud-io/clisso/log" ) +var _ = log.NewLogger("panic", "", false) + func randSeq(n int, letters []rune) []byte { b := make([]rune, n) for i := range b { diff --git a/log/log.go b/log/log.go new file mode 100644 index 0000000..daa7e6e --- /dev/null +++ b/log/log.go @@ -0,0 +1,61 @@ +package log + +import ( + "io" + "runtime" + + "github.com/mattn/go-colorable" + "github.com/mitchellh/go-homedir" + "github.com/rifflock/lfshook" + "github.com/sirupsen/logrus" +) + +var Log *logrus.Logger + +func NewLogger(logLevelFlag, logFilePath string, enableLogFile bool) *logrus.Logger { + if Log != nil { + Log.Tracef("Logger already initialized") + return Log + } + + // parse log level flag and set log level + logLevel, err := logrus.ParseLevel(logLevelFlag) + if err != nil { + logrus.Fatalf("Error parsing log level: %v", err) + } + Log = logrus.New() + Log.SetLevel(logLevel) + + if enableLogFile { + logFile, err := homedir.Expand(logFilePath) + if err != nil { + Log.Fatalf("Error expanding homedir: %v", err) + } + + // set all log levels to write to the log file + pathMap := lfshook.PathMap{ + logrus.TraceLevel: logFile, + logrus.DebugLevel: logFile, + logrus.InfoLevel: logFile, + logrus.WarnLevel: logFile, + logrus.ErrorLevel: logFile, + logrus.FatalLevel: logFile, + logrus.PanicLevel: logFile, + } + Log.Hooks.Add(lfshook.NewHook( + pathMap, + &logrus.JSONFormatter{}, + )) + Log.Out = io.Discard + } else { + if runtime.GOOS == "windows" { + // Handle terminal colors on Windows machines. + // TODO, check if still required with the switch to logrus + Log.SetOutput(colorable.NewColorableStdout()) + } + Log.SetFormatter(&logrus.TextFormatter{PadLevelText: true}) + } + Log.Warning("This is a warning") + Log.Warnf("Log level set to %s", logLevelFlag) + return Log +} diff --git a/main.go b/main.go index 115f1c6..2e1bd60 100644 --- a/main.go +++ b/main.go @@ -6,11 +6,6 @@ package main import ( - "runtime" - - "github.com/mattn/go-colorable" - log "github.com/sirupsen/logrus" - "github.com/allcloud-io/clisso/cmd" ) @@ -21,16 +16,6 @@ var ( date = "unknown" ) -func init() { - if runtime.GOOS == "windows" { - // Handle terminal colors on Windows machines. - // TODO, check if still required with the switch to logrus - log.SetOutput(colorable.NewColorableStdout()) - } - - log.SetFormatter(&log.TextFormatter{PadLevelText: true}) -} - func main() { cmd.Execute(version, commit, date) } diff --git a/okta/client.go b/okta/client.go index ff5dc3d..1be57fa 100644 --- a/okta/client.go +++ b/okta/client.go @@ -16,7 +16,8 @@ import ( "time" "github.com/PuerkitoBio/goquery" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" + "github.com/sirupsen/logrus" "golang.org/x/net/publicsuffix" ) @@ -186,7 +187,7 @@ func (c *Client) LaunchApp(p *LaunchAppParams) (*string, error) { // using the client, handles any HTTP-related errors and returns any data as a string. func (c *Client) doRequest(r *http.Request) (string, error) { resp, err := c.Do(r) - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "status": resp.Status, "url": resp.Request.URL, "host": resp.Request.Host, diff --git a/okta/client_test.go b/okta/client_test.go index 2ba1244..6ee9f4f 100644 --- a/okta/client_test.go +++ b/okta/client_test.go @@ -10,8 +10,12 @@ import ( "net/http/httptest" "testing" "time" + + "github.com/allcloud-io/clisso/log" ) +var _ = log.NewLogger("panic", "", false) + func getTestServer(data string) *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte(data)) diff --git a/okta/get.go b/okta/get.go index aa8835d..2cdc527 100644 --- a/okta/get.go +++ b/okta/get.go @@ -12,10 +12,11 @@ import ( "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/config" "github.com/allcloud-io/clisso/keychain" + "github.com/allcloud-io/clisso/log" "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) const ( @@ -31,13 +32,14 @@ var ( ) // Get gets temporary credentials for the given app. -func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credentials, error) { - log.WithFields(log.Fields{ - "app": app, - "provider": provider, - "pArn": pArn, - "awsRegion": awsRegion, - "duration": duration, +func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { + log.Log.WithFields(logrus.Fields{ + "app": app, + "provider": provider, + "pArn": pArn, + "awsRegion": awsRegion, + "duration": duration, + "interactive": interactive, }).Trace("Getting credentials from Okta") // Get provider config p, err := config.GetOktaProvider(provider) @@ -71,21 +73,21 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential } // Initialize spinner - var s = spinner.New() + var s = spinner.New(interactive) // Get session token s.Start() - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "Username": user, // print password only in Trace Log Level - "Password": gog.If(log.GetLevel() == log.TraceLevel, string(pass), ""), + "Password": gog.If(log.Log.GetLevel() == logrus.TraceLevel, string(pass), ""), }).Debug("Calling GetSessionToken") resp, err := c.GetSessionToken(&GetSessionTokenParams{ Username: user, Password: string(pass), }) s.Stop() - log.WithField("Status", resp.Status).WithError(err).Trace("GetSessionToken done") + log.Log.WithField("Status", resp.Status).WithError(err).Trace("GetSessionToken done") if err != nil { return nil, fmt.Errorf("getting session token: %v", err) } @@ -99,7 +101,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential case StatusMFARequired: factor := resp.Embedded.Factors[0] stateToken := resp.StateToken - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "factorID": factor.ID, "factorLink": factor.Links.Verify.Href, "stateToken": stateToken, @@ -114,7 +116,9 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential // https://developer.okta.com/docs/api/resources/authn/#verify-push-factor // Keep polling authentication transactions with WAITING result until the challenge // completes or expires. - fmt.Println("Please approve request on Okta Verify app") + if interactive { + fmt.Println("Please approve request on Okta Verify app") + } s.Start() vfResp, err = c.VerifyFactor(&VerifyFactorParams{ FactorID: factor.ID, @@ -164,7 +168,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential // Handle failed MFA verification (verification rejected or timed out) if vfResp.Status != VerifyFactorStatusSuccess { err = fmt.Errorf("MFA verification failed") - log.WithField("status", vfResp.Status).WithError(err).Warn("MFA verification failed") + log.Log.WithField("status", vfResp.Status).WithError(err).Warn("MFA verification failed") return nil, fmt.Errorf("MFA verification failed") } @@ -175,7 +179,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential // Launch Okta app with session token s.Start() - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "SessionToken": st, "URL": a.URL, }).Trace("Calling LaunchApp") @@ -196,7 +200,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential if err != nil { if err.Error() == aws.ErrDurationExceeded { - log.Warn(aws.DurationExceededMessage) + log.Log.Warn(aws.DurationExceededMessage) s.Start() creds, err = aws.AssumeSAMLRole(arn.Provider, arn.Role, *samlAssertion, awsRegion, 3600) s.Stop() diff --git a/onelogin/client.go b/onelogin/client.go index a113316..4bb7fb4 100644 --- a/onelogin/client.go +++ b/onelogin/client.go @@ -14,7 +14,8 @@ import ( "net/http" "time" - log "github.com/sirupsen/logrus" + "github.com/allcloud-io/clisso/log" + "github.com/sirupsen/logrus" ) // Client represents a OneLogin API client. @@ -115,14 +116,13 @@ func (c *Client) doRequest(r *http.Request) (string, error) { resp, err := c.Do(r) if resp != nil { - log.WithFields( - log.Fields{ - "status": resp.Status, - "url": resp.Request.URL, - "host": resp.Request.Host, - "code": resp.StatusCode, - "method": resp.Request.Method, - }).WithError(err).Trace("HTTP request sent") + log.Log.WithFields(logrus.Fields{ + "status": resp.Status, + "url": resp.Request.URL, + "host": resp.Request.Host, + "code": resp.StatusCode, + "method": resp.Request.Method, + }).WithError(err).Trace("HTTP request sent") } if err != nil { diff --git a/onelogin/client_test.go b/onelogin/client_test.go index e90296d..8ae6eb5 100644 --- a/onelogin/client_test.go +++ b/onelogin/client_test.go @@ -10,6 +10,8 @@ import ( "net/http/httptest" "net/url" "testing" + + "github.com/allcloud-io/clisso/log" ) func getTestServer(data string) *httptest.Server { @@ -25,6 +27,8 @@ func getTestServer(data string) *httptest.Server { var c = Client{} +var _ = log.NewLogger("panic", "", false) + func TestNewClient(t *testing.T) { for _, test := range []struct { name string diff --git a/onelogin/endpoints_test.go b/onelogin/endpoints_test.go index d891065..3748393 100644 --- a/onelogin/endpoints_test.go +++ b/onelogin/endpoints_test.go @@ -8,8 +8,12 @@ package onelogin import ( "net/url" "testing" + + "github.com/allcloud-io/clisso/log" ) +var _ = log.NewLogger("panic", "", false) + func TestEndpoints_SetBase(t *testing.T) { for _, test := range []struct { name string diff --git a/onelogin/get.go b/onelogin/get.go index f5d4439..be44960 100644 --- a/onelogin/get.go +++ b/onelogin/get.go @@ -8,6 +8,7 @@ package onelogin import ( "errors" "fmt" + "os" "strconv" "strings" "time" @@ -15,10 +16,11 @@ import ( "github.com/allcloud-io/clisso/aws" "github.com/allcloud-io/clisso/config" "github.com/allcloud-io/clisso/keychain" + "github.com/allcloud-io/clisso/log" "github.com/allcloud-io/clisso/saml" "github.com/allcloud-io/clisso/spinner" "github.com/icza/gog" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) const ( @@ -40,13 +42,14 @@ var ( // Get gets temporary credentials for the given app. // TODO Move AWS logic outside this function. -func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credentials, error) { - log.WithFields(log.Fields{ - "app": app, - "provider": provider, - "pArn": pArn, - "awsRegion": awsRegion, - "duration": duration, +func Get(app, provider, pArn, awsRegion string, duration int32, interactive bool) (*aws.Credentials, error) { + log.Log.WithFields(logrus.Fields{ + "app": app, + "provider": provider, + "pArn": pArn, + "awsRegion": awsRegion, + "duration": duration, + "interactive": interactive, }).Trace("Getting credentials from OneLogin") // Read config p, err := config.GetOneLoginProvider(provider) @@ -65,11 +68,11 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential } // Initialize spinner - var s = spinner.New() + var s = spinner.New(interactive) // Get OneLogin access token s.Start() - log.Trace("Generating access token") + log.Log.Trace("Generating access token") token, err := c.GenerateTokens(p.ClientID, p.ClientSecret) s.Stop() if err != nil { @@ -78,7 +81,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential user := p.Username if user == "" { - log.Trace("No username provided") + log.Log.Trace("No username provided") // Get credentials from the user fmt.Print("OneLogin username: ") fmt.Scanln(&user) @@ -99,10 +102,10 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential Subdomain: p.Subdomain, } - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "UsernameOrEmail": user, // print password only in Trace Log Level - "Password": gog.If(log.GetLevel() == log.TraceLevel, string(pass), ""), + "Password": gog.If(log.Log.GetLevel() == logrus.TraceLevel, string(pass), ""), "AppId": a.ID, "Subdomain": p.Subdomain, }).Debug("Calling GenerateSamlAssertion") @@ -114,14 +117,14 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential return nil, fmt.Errorf("generating SAML assertion: %v", err) } - log.WithField("Message", rSaml.Message).Debug("GenerateSamlAssertion is done") + log.Log.WithField("Message", rSaml.Message).Debug("GenerateSamlAssertion is done") var rData string if rSaml.Message != "Success" { st := rSaml.StateToken devices := rSaml.Devices - log.WithField("Devices", devices).Trace("Devices returned by GenerateSamlAssertion") + log.Log.WithField("Devices", devices).Trace("Devices returned by GenerateSamlAssertion") device, err := getDevice(devices) if err != nil { return nil, fmt.Errorf("error getting devices: %s", err) @@ -141,7 +144,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential OtpToken: "", DoNotNotify: false, } - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "AppId": a.ID, "DeviceId": device.DeviceID, "StateToken": st, @@ -155,14 +158,18 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential } pMfa.DoNotNotify = true - - fmt.Println(rMfa.Message) + if interactive { + fmt.Println(rMfa.Message) + } else { + // print to StdErr if we're not interactive + fmt.Fprintln(os.Stderr, rMfa.Message) + } timeout := MFAPushTimeout s.Start() for strings.Contains(rMfa.Message, "pending") && timeout > 0 { time.Sleep(time.Duration(MFAInterval) * time.Second) - log.Trace("MFAInterval completed, calling VerifyFactor again") + log.Log.Trace("MFAInterval completed, calling VerifyFactor again") rMfa, err = c.VerifyFactor(token, &pMfa) if err != nil { s.Stop() @@ -202,7 +209,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential } } rData = rMfa.Data - log.Trace("Factor is verified") + log.Log.Trace("Factor is verified") } else { rData = rSaml.Data } @@ -218,7 +225,7 @@ func Get(app, provider, pArn, awsRegion string, duration int32) (*aws.Credential if err != nil { if err.Error() == aws.ErrDurationExceeded { - log.Warn(aws.DurationExceededMessage) + log.Log.Warn(aws.DurationExceededMessage) s.Start() creds, err = aws.AssumeSAMLRole(arn.Provider, arn.Role, rData, awsRegion, 3600) s.Stop() @@ -241,7 +248,7 @@ func getDevice(devices []Device) (device *Device, err error) { } if len(devices) == 1 { - log.Trace("Only one MFA device returned by Onelogin, automatically selecting it.") + log.Log.Trace("Only one MFA device returned by Onelogin, automatically selecting it.") device = &Device{DeviceID: devices[0].DeviceID, DeviceType: devices[0].DeviceType} return } diff --git a/saml/saml.go b/saml/saml.go index 5595ace..f66b606 100644 --- a/saml/saml.go +++ b/saml/saml.go @@ -14,8 +14,9 @@ import ( "strconv" "strings" + "github.com/allcloud-io/clisso/log" "github.com/crewjam/saml" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -32,14 +33,14 @@ const idpRegex = `^arn:(?:aws|aws-cn):iam::\d+:saml-provider\/\S+$` func Get(data, pArn string) (a ARN, err error) { samlBody, err := decode(data) if err != nil { - log.WithError(err).Error("Error decoding SAML assertion") + log.Log.WithError(err).Error("Error decoding SAML assertion") return } x := new(saml.Response) err = xml.Unmarshal(samlBody, x) if err != nil { - log.WithError(err).Error("Error parsing SAML assertion") + log.Log.WithError(err).Error("Error parsing SAML assertion") return } @@ -67,10 +68,10 @@ func decode(in string) (b []byte, err error) { } func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { - log.WithField("preferredARN", pArn).Trace("Extracting ARNs from SAML AttributeStatements") + log.Log.WithField("preferredARN", pArn).Trace("Extracting ARNs from SAML AttributeStatements") // check for human readable ARN strings in config accounts := viper.GetStringMap("global.accounts") - log.WithFields(accounts).Trace("Accounts loaded from config") + log.Log.WithFields(accounts).Trace("Accounts loaded from config") arns := make([]ARN, 0) for _, stmt := range stmts { @@ -90,7 +91,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { components := strings.Split(strings.TrimSpace(av.Value), ",") if len(components) != 2 { // Wrong number of components - move on - log.WithFields(log.Fields{ + log.Log.WithFields(logrus.Fields{ "components": components, "length": len(components), "value": av.Value, @@ -101,7 +102,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // people like to put spaces in there, AWS accepts them, let's remove them on our end too. components[0] = strings.TrimSpace(components[0]) components[1] = strings.TrimSpace(components[1]) - log.WithField("components", components).Trace("ARN components extracted from SAML assertion") + log.Log.WithField("components", components).Trace("ARN components extracted from SAML assertion") arn := ARN{} @@ -110,13 +111,13 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // Otherwise it matches it with what is in the .clisso.yaml file if pArn != "" { if components[0] == pArn { - log.Trace("Preferred ARN matches first component") + log.Log.Trace("Preferred ARN matches first component") arn = ARN{components[0], components[1], ""} } else if components[1] == pArn { - log.Trace("Preferred ARN matches second component") + log.Log.Trace("Preferred ARN matches second component") arn = ARN{components[1], components[0], ""} } else { - log.Trace("Preferred ARN does not match either component") + log.Log.Trace("Preferred ARN does not match either component") continue } } else { @@ -125,20 +126,20 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { idp := regexp.MustCompile(idpRegex) if role.MatchString(components[0]) && idp.MatchString(components[1]) { - log.Trace("First component is role, second component is IdP") + log.Log.Trace("First component is role, second component is IdP") arn = ARN{components[0], components[1], ""} } else if role.MatchString(components[1]) && idp.MatchString(components[0]) { - log.Trace("First component is IdP, second component is role") + log.Log.Trace("First component is IdP, second component is role") arn = ARN{components[1], components[0], ""} } else { - log.Trace("Neither component matches expected pattern") + log.Log.Trace("Neither component matches expected pattern") continue } // Look up the human friendly name, if available if len(accounts) > 0 { ids := role.FindStringSubmatch(arn.Role) - log.WithField("matches", ids).Trace("Role regex matches") + log.Log.WithField("matches", ids).Trace("Role regex matches") // if the regex matches we should have 3 entries from the regex match // 1) the matching string @@ -146,7 +147,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { // 3) the match for Name // we want to match the Id to any accounts/roles in our config if len(ids) == 3 && accounts[ids[1]] != "" && accounts[ids[1]] != nil { - log.Trace("Found human friendly name for account") + log.Log.Trace("Found human friendly name for account") arn.Name = fmt.Sprintf("%s - %s", accounts[ids[1]].(string), ids[2]) } } @@ -159,7 +160,7 @@ func extractArns(stmts []saml.AttributeStatement, pArn string) []ARN { } } } - log.Trace("No statements in SAML assertion or no ARNs found.") + log.Log.Trace("No statements in SAML assertion or no ARNs found.") // Empty :( return arns } diff --git a/saml/saml_test.go b/saml/saml_test.go index 30be99d..0debe5e 100644 --- a/saml/saml_test.go +++ b/saml/saml_test.go @@ -9,9 +9,12 @@ import ( "os" "testing" + "github.com/allcloud-io/clisso/log" "github.com/crewjam/saml" ) +var _ = log.NewLogger("panic", "", false) + func TestExtractArns(t *testing.T) { for _, test := range []struct { name string diff --git a/spinner/spinner.go b/spinner/spinner.go index af98f8b..eb16708 100644 --- a/spinner/spinner.go +++ b/spinner/spinner.go @@ -8,8 +8,8 @@ package spinner // This is a wrapper around spinner to disable unsupported operation systems transparently until upstream is fixed. // See https://github.com/briandowns/spinner/issues/52 -func New() SpinnerWrapper { - return new() +func New(interactive bool) SpinnerWrapper { + return new(interactive) } // SpinnerWrapper is used to abstract a spinner so that it can be conveniently disabled on terminals which don't support it. diff --git a/spinner/spinner_unix.go b/spinner/spinner_unix.go index cd07068..3a2e9e9 100644 --- a/spinner/spinner_unix.go +++ b/spinner/spinner_unix.go @@ -12,12 +12,13 @@ package spinner import ( "time" + "github.com/allcloud-io/clisso/log" "github.com/briandowns/spinner" - log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus" ) -func new() SpinnerWrapper { - if log.GetLevel() >= log.DebugLevel { +func new(interactive bool) SpinnerWrapper { + if log.Log.GetLevel() >= logrus.DebugLevel || !interactive { return &noopSpinner{} } return spinner.New(spinner.CharSets[14], 50*time.Millisecond) diff --git a/spinner/spinner_windows.go b/spinner/spinner_windows.go index 8ecb0f9..fb2e23c 100644 --- a/spinner/spinner_windows.go +++ b/spinner/spinner_windows.go @@ -8,6 +8,6 @@ */ package spinner -func new() SpinnerWrapper { +func new(interactive bool) SpinnerWrapper { return &noopSpinner{} }