From 9a92d0fd1ecbca4c732c29c0962369b683e7a472 Mon Sep 17 00:00:00 2001 From: Hiruna Wijesinghe Date: Mon, 7 Jun 2021 02:43:01 +1000 Subject: [PATCH 1/5] initial commit to use okta sessions API, remember session locally + remember MFA --- helper/credentials/saml.go | 8 + pkg/creds/creds.go | 17 +- pkg/provider/okta/okta.go | 340 ++++++++++++++++++++++++++++++++++--- 3 files changed, 331 insertions(+), 34 deletions(-) diff --git a/helper/credentials/saml.go b/helper/credentials/saml.go index 7f080ed9c..0f3ba65a8 100644 --- a/helper/credentials/saml.go +++ b/helper/credentials/saml.go @@ -17,6 +17,14 @@ func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error loginDetails.Username = username loginDetails.Password = password + // If the provider is Okta, check for existing Okta Session Cookie (sid) + if provider == "Okta" { + _, oktaSessionCookie, err := CurrentHelper.Get(loginDetails.URL + "/sessionCookie") + if err == nil { + loginDetails.OktaSessionCookie = oktaSessionCookie + } + } + if provider == "OneLogin" { id, secret, err := CurrentHelper.Get(path.Join(loginDetails.URL, "/auth/oauth2/v2/token")) if err != nil { diff --git a/pkg/creds/creds.go b/pkg/creds/creds.go index ed08611c4..e006216ba 100644 --- a/pkg/creds/creds.go +++ b/pkg/creds/creds.go @@ -2,12 +2,13 @@ package creds // LoginDetails used to authenticate type LoginDetails struct { - ClientID string // used by OneLogin - ClientSecret string // used by OneLogin - Username string - Password string - MFAToken string - DuoMFAOption string - URL string - StateToken string // used by Okta + ClientID string // used by OneLogin + ClientSecret string // used by OneLogin + Username string + Password string + MFAToken string + DuoMFAOption string + URL string + StateToken string // used by Okta + OktaSessionCookie string // used by Okta } diff --git a/pkg/provider/okta/okta.go b/pkg/provider/okta/okta.go index 7361e31d4..0247c61d8 100644 --- a/pkg/provider/okta/okta.go +++ b/pkg/provider/okta/okta.go @@ -21,6 +21,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" + "github.com/versent/saml2aws/v2/helper/credentials" "github.com/versent/saml2aws/v2/pkg/cfg" "github.com/versent/saml2aws/v2/pkg/creds" "github.com/versent/saml2aws/v2/pkg/page" @@ -73,8 +74,14 @@ type AuthRequest struct { // VerifyRequest represents an mfa verify request type VerifyRequest struct { - StateToken string `json:"stateToken"` - PassCode string `json:"passCode,omitempty"` + StateToken string `json:"stateToken"` + PassCode string `json:"passCode,omitempty"` + RememberDevice string `json:"rememberDevice,omitempty"` +} //https://developer.okta.com/docs/reference/api/authn/#verify-security-question-factor + +// SessionRequst holds the SessionToken used to create an Okta Session +type SessionRequst struct { + SessionToken string `json:"sessionToken"` } // mfaChallengeContext is used to hold MFA challenge context in a simple struct. @@ -115,26 +122,274 @@ func New(idpAccount *cfg.IDPAccount) (*Client, error) { type ctxKey string -// Authenticate logs into Okta and returns a SAML response -func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { +func (oc *Client) validateSession(loginDetails *creds.LoginDetails) error { + logger.Debug("validate session func called") + + if loginDetails == nil { + logger.Debugf("unable to validate the okta session, nil input | loginDetails: %v ", loginDetails) + return fmt.Errorf("unable to validate the okta session, nil input") + } + + sessionCookie := loginDetails.OktaSessionCookie oktaURL, err := url.Parse(loginDetails.URL) if err != nil { - return "", errors.Wrap(err, "error building oktaURL") + return errors.Wrap(err, "error building oktaURL") + } + + oktaOrgHost := oktaURL.Host + + sessionReqURL := fmt.Sprintf("https://%s/api/v1/sessions/me", oktaOrgHost) + sessionReqBody := new(bytes.Buffer) + + req, err := http.NewRequest("GET", sessionReqURL, sessionReqBody) + if err != nil { + return errors.Wrap(err, "error building new session request") + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Add("Cookie", fmt.Sprintf("sid=%s", sessionCookie)) + + res, err := oc.client.Do(req) + if err != nil { + return errors.Wrap(err, "error retrieving session response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return errors.Wrap(err, "error retrieving body from response") + } + + resp := string(body) + + if res.StatusCode != 200 { + logger.Debug("invalid okta session") + return fmt.Errorf("invalid okta session") + } else { + sessionResponseStatus := gjson.Get(resp, "status").String() + switch sessionResponseStatus { + case "ACTIVE": + logger.Debug("okta session established") + case "MFA_REQUIRED": + _, err := verifyMfa(oc, oktaOrgHost, loginDetails, resp) + if err != nil { + return errors.Wrap(err, "error verifying MFA") + } + case "MFA_ENROLL": + // Not yet fully implemented, so just return the status as the error string... + return fmt.Errorf("MFA_ENROLL") + } + } + + logger.Debug("valid okta session") + return nil +} + +func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken string) (string, string, error) { + logger.Debug("create session func called") + if loginDetails == nil || sessionToken == "" { + logger.Debugf("unable to create an Okta session, nil input | loginDetails: %v | sessionToken: %s", loginDetails, sessionToken) + return "", "", fmt.Errorf("unable to create an okta session, nil input") + } + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return "", "", errors.Wrap(err, "error building okta url") } oktaOrgHost := oktaURL.Host - //dummy request to set device token cookie ("dt") + //authenticate via okta api + sessionReq := SessionRequst{SessionToken: sessionToken} + sessionReqBody := new(bytes.Buffer) + err = json.NewEncoder(sessionReqBody).Encode(sessionReq) + if err != nil { + return "", "", errors.Wrap(err, "error encoding session req") + } + + sessionReqURL := fmt.Sprintf("https://%s/api/v1/sessions", oktaOrgHost) + + req, err := http.NewRequest("POST", sessionReqURL, sessionReqBody) + if err != nil { + return "", "", errors.Wrap(err, "error building new session request") + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + + res, err := oc.client.Do(req) + if err != nil { + return "", "", errors.Wrap(err, "error retrieving session response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", "", errors.Wrap(err, "error retrieving body from response") + } + + if res.StatusCode == 401 { // https://developer.okta.com/docs/reference/api/sessions/#response-parameters + return "", "", fmt.Errorf("unable to create an Okta session, invalid sessionToken") + } + + resp := string(body) + + oktaSessionExpiresAtStr := gjson.Get(resp, "expiresAt").String() + logger.Debugf("okta session expires at: %s", oktaSessionExpiresAtStr) + + oktaSessionCookie := gjson.Get(resp, "id").String() + + err = credentials.SaveCredentials(loginDetails.URL+"/sessionCookie", loginDetails.Username, oktaSessionCookie) + if err != nil { + log.Printf("error storing okta session token | err: %v", err) //TODO: handle this properly instead of dumping to stdout + } + oktaSessionToken := gjson.Get(resp, "sessionToken").String() + sessionResponseStatus := gjson.Get(resp, "status").String() + switch sessionResponseStatus { + case "ACTIVE": + logger.Debug("okta session established") + case "MFA_REQUIRED": + oktaSessionToken, err = verifyMfa(oc, oktaOrgHost, loginDetails, resp) + if err != nil { + return "", "", errors.Wrap(err, "error verifying MFA") + } + case "MFA_ENROLL": + // Not yet fully implemented, so just return the status as the error string... + return "", "", fmt.Errorf("MFA_ENROLL") + } + + return oktaSessionCookie, oktaSessionToken, nil +} + +func (oc *Client) authWithSession(loginDetails *creds.LoginDetails) (string, error) { + logger.Debug("auth with session func called") + sessionCookie := loginDetails.OktaSessionCookie + err := oc.validateSession(loginDetails) + if err != nil { + modifiedLoginDetails := loginDetails + modifiedLoginDetails.OktaSessionCookie = "" + return oc.Authenticate(modifiedLoginDetails) + } + req, err := http.NewRequest("GET", loginDetails.URL, nil) if err != nil { - return "", errors.Wrap(err, "error building device token request") + return "", errors.Wrap(err, "error building authWithSession request") + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Add("Cookie", fmt.Sprintf("sid=%s", sessionCookie)) + + ctx := context.WithValue(context.Background(), ctxKey("authWithSession"), loginDetails) + + res, err := oc.client.Do(req) + if err != nil { + logger.Debugf("error authing with session: %v", err) } - _, err = oc.client.Do(req) + + body, err := ioutil.ReadAll(res.Body) if err != nil { - return "", errors.Wrap(err, "error retrieving device token") + logger.Debugf("error reading body for auth with session: %v", err) } + if strings.Contains(string(body), "/login/step-up/") { // https://developer.okta.com/docs/reference/api/authn/#step-up-authentication-with-okta-session + logger.Debug("okta step-up prompted, need mfa...") + stateToken, err := getStateTokenFromOktaPageBody(string(body)) + if err != nil { + return "", errors.Wrap(err, "error retrieving saml response") + } + loginDetails.StateToken = stateToken + return oc.Authenticate(loginDetails) + } + + // fmt.Println(resp) + // os.Exit(0) + return oc.follow(ctx, req, loginDetails) +} + +// https://devforum.okta.com/t/how-per-device-remember-me-api-works/3955/3 + +// func (oc *Client) getDeviceTokenFromOkta(loginDetails *creds.LoginDetails) (string, error) { +// //dummy request to set device token cookie ("dt") +// req, err := http.NewRequest("GET", loginDetails.URL, nil) +// if err != nil { +// return "", errors.Wrap(err, "error building device token request") +// } +// resp, err := oc.client.Do(req) +// if err != nil { +// return "", errors.Wrap(err, "error retrieving device token") +// } + +// for _, c := range resp.Cookies() { +// if c.Name == "DT" { // Device token +// return c.Value, nil +// } +// } + +// return "", fmt.Errorf("unable to get a device token from okta") +// } + +func (oc *Client) setDeviceTokenCookie(loginDetails *creds.LoginDetails) error { + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return errors.Wrap(err, "error building oktaURL to set device token") + } + oktaURLScheme := oktaURL.Scheme + oktaURLHost := oktaURL.Host + baseURL := &url.URL{Scheme: oktaURLScheme, Host: oktaURLHost, Path: "/"} + + var cookies []*http.Cookie + cookie := http.Cookie{ + Name: "DT", + Secure: true, + Expires: time.Now().Add(time.Hour * 24 * 30), // 30 Days + Value: fmt.Sprintf("okta_%s_saml2aws", loginDetails.Username), + } + cookies = append(cookies, &cookie) + oc.client.Jar.SetCookies(baseURL, cookies) + + return nil +} + +// Authenticate logs into Okta and returns a SAML response +func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { + // Get Okta session cookie (sid) from login details (if found via login.go) + oktaSessionCookie := loginDetails.OktaSessionCookie + + // Set Okta device token + err := oc.setDeviceTokenCookie(loginDetails) + if err != nil { + return "", errors.Wrap(err, "error setting device token in cookie jar") + } + + // If Okta session cookie is not empty + if oktaSessionCookie != "" && loginDetails.StateToken == "" { + return oc.authWithSession(loginDetails) + } + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return "", errors.Wrap(err, "error building oktaURL") + } + + oktaOrgHost := oktaURL.Host + + // // Get the current list of cookies (if any) + // currentCookies := oc.client.Jar.Cookies(oktaURL) + // // Create Okta session cookie containing sid + // osc := http.Cookie{ + // Name: "sid", + // Value: oktaSessionCookie, + // Raw: fmt.Sprintf("sid=%s", oktaSessionCookie), + // Path: "/", + // HttpOnly: true, + // Secure: true, + // } + // // Add the session cookie to cookie jar + // currentCookies = append(currentCookies, &osc) + // // Set the cookie jar back to the http client + // oc.client.Jar.SetCookies(oktaURL, currentCookies) + //authenticate via okta api authReq := AuthRequest{Username: loginDetails.Username, Password: loginDetails.Password} if loginDetails.StateToken != "" { @@ -148,7 +403,7 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) authSubmitURL := fmt.Sprintf("https://%s/api/v1/authn", oktaOrgHost) - req, err = http.NewRequest("POST", authSubmitURL, authBody) + req, err := http.NewRequest("POST", authSubmitURL, authBody) if err != nil { return "", errors.Wrap(err, "error building authentication request") } @@ -179,36 +434,67 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) } } - //now call saml endpoint - oktaSessionRedirectURL := fmt.Sprintf("https://%s/login/sessionCookieRedirect", oktaOrgHost) - - req, err = http.NewRequest("GET", oktaSessionRedirectURL, nil) - if err != nil { - return "", errors.Wrap(err, "error building authentication request") + if oktaSessionCookie == "" { + oktaSessionCookie, _, err = oc.createSession(loginDetails, oktaSessionToken) + if err != nil { + return "", err + } + loginDetails.OktaSessionCookie = oktaSessionCookie } - q := req.URL.Query() - q.Add("checkAccountSetupComplete", "true") - q.Add("token", oktaSessionToken) - q.Add("redirectUrl", loginDetails.URL) - req.URL.RawQuery = q.Encode() - ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails) - return oc.follow(ctx, req, loginDetails) + return oc.authWithSession(loginDetails) + + // //now call saml endpoint + // oktaSessionRedirectURL := fmt.Sprintf("https://%s/login/sessionCookieRedirect", oktaOrgHost) + + // req, err = http.NewRequest("GET", oktaSessionRedirectURL, nil) + // if err != nil { + // return "", errors.Wrap(err, "error building authentication request") + // } + // q := req.URL.Query() + // q.Add("checkAccountSetupComplete", "true") + // q.Add("token", oktaSessionToken) + // q.Add("redirectUrl", loginDetails.URL) + // req.URL.RawQuery = q.Encode() + + // ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails) + // return oc.follow(ctx, req, loginDetails) } func (oc *Client) follow(ctx context.Context, req *http.Request, loginDetails *creds.LoginDetails) (string, error) { + if ctx.Value(ctxKey("follow")) != nil { + logger.Debug("follow func called from itself") + } + + if ctx.Value(ctxKey("authWithSession")) != nil { + logger.Debug("follow func called from auth with session func") + // oc.client.DisableFollowRedirect() + } res, err := oc.client.Do(req) if err != nil { + logger.Debug("ERROR FOLLOWING") return "", errors.Wrap(err, "error following") } doc, err := goquery.NewDocumentFromReader(res.Body) if err != nil { + logger.Debug("FAILED TO BUILD DOC FROM RESP") return "", errors.Wrap(err, "failed to build document from response") } var handler func(context.Context, *goquery.Document) (context.Context, *http.Request, error) + // if ctx.Value(ctxKey("authWithSession")) != nil { + // logger.Debug("follow func called from auth with session func") + // handler = oc.handleFormRedirect + // modifiedCtx := context.WithValue(context.Background(), ctxKey("follow"), loginDetails) + // modifiedCtx, req, err = handler(modifiedCtx, doc) + // if err != nil { + // return "", err + // } + // return oc.follow(modifiedCtx, req, loginDetails) + // } + if docIsFormRedirectToTarget(doc, oc.targetURL) { logger.WithField("type", "saml-response-to-aws").Debug("doc detect") if samlResponse, ok := extractSAMLResponse(doc); ok { @@ -348,7 +634,7 @@ func getMfaChallengeContext(oc *Client, mfaOption int, resp string) (*mfaChallen } // get signature & callback - verifyReq := VerifyRequest{StateToken: stateToken} + verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: "true"} verifyBody := new(bytes.Buffer) // Login flow is different for YubiKeys ( of course ) @@ -392,6 +678,7 @@ func getMfaChallengeContext(oc *Client, mfaOption int, resp string) (*mfaChallen }, nil } +// TODO: set device token https://developer.okta.com/docs/reference/api/authn/#context-object func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, resp string) (string, error) { stateToken := gjson.Get(resp, "stateToken").String() @@ -426,7 +713,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, if verifyCode == "" { verifyCode = prompter.StringRequired("Enter verification code") } - tokenReq := VerifyRequest{StateToken: stateToken, PassCode: verifyCode} + tokenReq := VerifyRequest{StateToken: stateToken, PassCode: verifyCode, RememberDevice: "true"} tokenBody := new(bytes.Buffer) err = json.NewEncoder(tokenBody).Encode(tokenReq) if err != nil { @@ -465,6 +752,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, // on 'success' status if gjson.Get(body, "status").String() == "SUCCESS" { fmt.Printf(" Approved\n\n") + fmt.Println(gjson.Get(body, "expiresAt").String()) // DEBUG return gjson.Get(body, "sessionToken").String(), nil } @@ -738,7 +1026,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, // extract okta session token - verifyReq := VerifyRequest{StateToken: stateToken} + verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: "true"} verifyBody := new(bytes.Buffer) err = json.NewEncoder(verifyBody).Encode(verifyReq) if err != nil { From 227c29d7dbe71b3d688feb86cf15a56c630757ed Mon Sep 17 00:00:00 2001 From: Hiruna Wijesinghe Date: Mon, 7 Jun 2021 15:00:14 +1000 Subject: [PATCH 2/5] finalize --- cmd/saml2aws/commands/login.go | 6 + cmd/saml2aws/commands/login_test.go | 34 +++ cmd/saml2aws/main.go | 6 +- go.mod | 4 +- go.sum | 9 +- pkg/cfg/cfg.go | 53 ++-- pkg/flags/flags.go | 56 ++-- pkg/provider/okta/okta.go | 379 +++++++++++++++------------- pkg/provider/okta/okta_test.go | 67 +++++ 9 files changed, 390 insertions(+), 224 deletions(-) diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index 91b0c6585..418a38872 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "os" + "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -187,6 +188,11 @@ func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFla return nil, errors.Wrap(err, "error loading saved password") } } + } else { // if user disabled keychain, dont use Okta sessions & dont remember Okta MFA device + if strings.ToLower(account.Provider) == "okta" { + account.DisableSessions = true + account.DisableRememberDevice = true + } } // log.Printf("%s %s", savedUsername, savedPassword) diff --git a/cmd/saml2aws/commands/login_test.go b/cmd/saml2aws/commands/login_test.go index 762ccf424..0e29f3640 100644 --- a/cmd/saml2aws/commands/login_test.go +++ b/cmd/saml2aws/commands/login_test.go @@ -1,6 +1,7 @@ package commands import ( + "fmt" "testing" "time" @@ -29,6 +30,39 @@ func TestResolveLoginDetailsWithFlags(t *testing.T) { assert.Equal(t, &creds.LoginDetails{Username: "wolfeidau", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) } +func TestOktaResolveLoginDetailsWithFlags(t *testing.T) { + + // Default state - user did not supply values for DisableSessions and DisableSessions + commonFlags := &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true} + loginFlags := &flags.LoginExecFlags{CommonFlags: commonFlags} + + idpa := &cfg.IDPAccount{ + URL: "https://id.example.com", + MFA: "none", + Provider: "Okta", + Username: "testuser", + } + loginDetails, err := resolveLoginDetails(idpa, loginFlags) + + assert.Nil(t, err) + assert.False(t, idpa.DisableSessions, fmt.Errorf("default state, DisableSessions should be false")) + assert.False(t, idpa.DisableRememberDevice, fmt.Errorf("default state, DisableRememberDevice should be false")) + assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) + + // User disabled keychain, resolveLoginDetails should set the account's DisableSessions and DisableSessions fields to true + + commonFlags = &flags.CommonFlags{URL: "https://id.example.com", Username: "testuser", Password: "testtestlol", MFAToken: "123456", SkipPrompt: true, DisableKeychain: true} + loginFlags = &flags.LoginExecFlags{CommonFlags: commonFlags} + + loginDetails, err = resolveLoginDetails(idpa, loginFlags) + + assert.Nil(t, err) + assert.True(t, idpa.DisableSessions, fmt.Errorf("user disabled keychain, DisableSessions should be true")) + assert.True(t, idpa.DisableRememberDevice, fmt.Errorf("user disabled keychain, DisableRememberDevice should be true")) + assert.Equal(t, &creds.LoginDetails{Username: "testuser", Password: "testtestlol", URL: "https://id.example.com", MFAToken: "123456"}, loginDetails) + +} + func TestResolveRoleSingleEntry(t *testing.T) { adminRole := &saml2aws.AWSRole{ diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index 603654db1..712d89c00 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -80,7 +80,7 @@ func main() { app.Flag("aws-urn", "The URN used by SAML when you login. (env: SAML2AWS_AWS_URN)").Envar("SAML2AWS_AWS_URN").StringVar(&commonFlags.AmazonWebservicesURN) app.Flag("skip-prompt", "Skip prompting for parameters during login.").BoolVar(&commonFlags.SkipPrompt) app.Flag("session-duration", "The duration of your AWS Session. (env: SAML2AWS_SESSION_DURATION)").Envar("SAML2AWS_SESSION_DURATION").IntVar(&commonFlags.SessionDuration) - app.Flag("disable-keychain", "Do not use keychain at all.").Envar("SAML2AWS_DISABLE_KEYCHAIN").BoolVar(&commonFlags.DisableKeychain) + app.Flag("disable-keychain", "Do not use keychain at all. This will also disable Okta sessions & remembering MFA device. (env: SAML2AWS_DISABLE_KEYCHAIN)").Envar("SAML2AWS_DISABLE_KEYCHAIN").BoolVar(&commonFlags.DisableKeychain) app.Flag("region", "AWS region to use for API requests, e.g. us-east-1, us-gov-west-1, cn-north-1 (env: SAML2AWS_REGION)").Envar("SAML2AWS_REGION").Short('r').StringVar(&commonFlags.Region) // `configure` command and settings @@ -94,6 +94,8 @@ func main() { cmdConfigure.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile) cmdConfigure.Flag("cache-saml", "Caches the SAML response (env: SAML2AWS_CACHE_SAML)").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache) cmdConfigure.Flag("cache-file", "The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE)").Envar("SAML2AWS_SAML_CACHE_FILE").StringVar(&commonFlags.SAMLCacheFile) + cmdConfigure.Flag("disable-sessions", "Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS)").Envar("SAML2AWS_OKTA_DISABLE_SESSIONS").BoolVar(&commonFlags.DisableSessions) + cmdConfigure.Flag("disable-remember-device", "Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE)").Envar("SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE").BoolVar(&commonFlags.DisableRememberDevice) configFlags := commonFlags // `login` command and settings @@ -109,6 +111,8 @@ func main() { cmdLogin.Flag("credentials-file", "The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE)").Envar("SAML2AWS_CREDENTIALS_FILE").StringVar(&commonFlags.CredentialsFile) cmdLogin.Flag("cache-saml", "Caches the SAML response (env: SAML2AWS_CACHE_SAML)").Envar("SAML2AWS_CACHE_SAML").BoolVar(&commonFlags.SAMLCache) cmdLogin.Flag("cache-file", "The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE)").Envar("SAML2AWS_SAML_CACHE_FILE").StringVar(&commonFlags.SAMLCacheFile) + cmdLogin.Flag("disable-sessions", "Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS)").Envar("SAML2AWS_OKTA_DISABLE_SESSIONS").BoolVar(&commonFlags.DisableSessions) + cmdLogin.Flag("disable-remember-device", "Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE)").Envar("SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE").BoolVar(&commonFlags.DisableRememberDevice) // `exec` command and settings cmdExec := app.Command("exec", "Exec the supplied command with env vars from STS token.") diff --git a/go.mod b/go.mod index 3e7bd9e5c..155eec4aa 100644 --- a/go.mod +++ b/go.mod @@ -34,8 +34,8 @@ require ( github.com/tidwall/gjson v1.1.1 github.com/tidwall/match v1.0.0 // indirect golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad // indirect - golang.org/x/net v0.0.0-20210119194325-5f4716e94777 - golang.org/x/sys v0.0.0-20210218155724-8ebf48af031b // indirect + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 + golang.org/x/sys v0.0.0-20210603125802-9665404d3644 // indirect golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf // indirect golang.org/x/text v0.3.5 // indirect gopkg.in/ini.v1 v1.62.0 diff --git a/go.sum b/go.sum index 01956e54f..7fb42cc3c 100644 --- a/go.sum +++ b/go.sum @@ -218,8 +218,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210119194325-5f4716e94777 h1:003p0dJM77cxMSyCPFphvZf/Y5/NXf5fzg6ufd1/Oew= -golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -239,8 +239,9 @@ golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210218155724-8ebf48af031b h1:lAZ0/chPUDWwjqosYR0X4M490zQhMsiJ4K3DbA7o+3g= -golang.org/x/sys v0.0.0-20210218155724-8ebf48af031b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644 h1:CA1DEQ4NdKphKeL70tvsWNdT5oFh1lOjihRcEDROi0I= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M= diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index b2ad37471..b0c53f5b9 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -30,32 +30,35 @@ const ( // IDPAccount saml IDP account type IDPAccount struct { - Name string `ini:"name"` - AppID string `ini:"app_id"` // used by OneLogin and AzureAD - URL string `ini:"url"` - Username string `ini:"username"` - Provider string `ini:"provider"` - MFA string `ini:"mfa"` - SkipVerify bool `ini:"skip_verify"` - Timeout int `ini:"timeout"` - AmazonWebservicesURN string `ini:"aws_urn"` - SessionDuration int `ini:"aws_session_duration"` - Profile string `ini:"aws_profile"` - ResourceID string `ini:"resource_id"` // used by F5APM - Subdomain string `ini:"subdomain"` // used by OneLogin - RoleARN string `ini:"role_arn"` - Region string `ini:"region"` - HttpAttemptsCount string `ini:"http_attempts_count"` - HttpRetryDelay string `ini:"http_retry_delay"` - CredentialsFile string `ini:"credentials_file"` - SAMLCache bool `ini:"saml_cache"` - SAMLCacheFile string `ini:"saml_cache_file"` - TargetURL string `ini:"target_url"` + Name string `ini:"name"` + AppID string `ini:"app_id"` // used by OneLogin and AzureAD + URL string `ini:"url"` + Username string `ini:"username"` + Provider string `ini:"provider"` + MFA string `ini:"mfa"` + SkipVerify bool `ini:"skip_verify"` + Timeout int `ini:"timeout"` + AmazonWebservicesURN string `ini:"aws_urn"` + SessionDuration int `ini:"aws_session_duration"` + Profile string `ini:"aws_profile"` + ResourceID string `ini:"resource_id"` // used by F5APM + Subdomain string `ini:"subdomain"` // used by OneLogin + RoleARN string `ini:"role_arn"` + Region string `ini:"region"` + HttpAttemptsCount string `ini:"http_attempts_count"` + HttpRetryDelay string `ini:"http_retry_delay"` + CredentialsFile string `ini:"credentials_file"` + SAMLCache bool `ini:"saml_cache"` + SAMLCacheFile string `ini:"saml_cache_file"` + TargetURL string `ini:"target_url"` + DisableRememberDevice bool `ini:"disable_remember_device"` // used by Okta + DisableSessions bool `ini:"disable_sessions"` // used by Okta } func (ia IDPAccount) String() string { var appID string var policyID string + var oktaCfg string switch ia.Provider { case "OneLogin": appID = fmt.Sprintf(` @@ -66,9 +69,13 @@ func (ia IDPAccount) String() string { case "AzureAD": appID = fmt.Sprintf(` AppID: %s`, ia.AppID) + case "Okta": + oktaCfg = fmt.Sprintf(` + DisableSessions: %v + DisableRememberDevice: %v`, ia.DisableSessions, ia.DisableSessions) } - return fmt.Sprintf(`account {%s%s + return fmt.Sprintf(`account {%s%s%s URL: %s Username: %s Provider: %s @@ -79,7 +86,7 @@ func (ia IDPAccount) String() string { Profile: %s RoleARN: %s Region: %s -}`, appID, policyID, ia.URL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile, ia.RoleARN, ia.Region) +}`, appID, policyID, oktaCfg, ia.URL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile, ia.RoleARN, ia.Region) } // Validate validate the required / expected fields are set diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index a84b0ec3e..54ab38c82 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -6,30 +6,32 @@ import ( // CommonFlags flags common to all of the `saml2aws` commands (except `help`) type CommonFlags struct { - AppID string - ClientID string - ClientSecret string - ConfigFile string - IdpAccount string - IdpProvider string - MFA string - MFAToken string - URL string - Username string - Password string - RoleArn string - AmazonWebservicesURN string - SessionDuration int - SkipPrompt bool - SkipVerify bool - Profile string - Subdomain string - ResourceID string - DisableKeychain bool - Region string - CredentialsFile string - SAMLCache bool - SAMLCacheFile string + AppID string + ClientID string + ClientSecret string + ConfigFile string + IdpAccount string + IdpProvider string + MFA string + MFAToken string + URL string + Username string + Password string + RoleArn string + AmazonWebservicesURN string + SessionDuration int + SkipPrompt bool + SkipVerify bool + Profile string + Subdomain string + ResourceID string + DisableKeychain bool + Region string + CredentialsFile string + SAMLCache bool + SAMLCacheFile string + DisableRememberDevice bool + DisableSessions bool } // LoginExecFlags flags for the Login / Exec commands @@ -106,4 +108,10 @@ func ApplyFlagOverrides(commonFlags *CommonFlags, account *cfg.IDPAccount) { if commonFlags.SAMLCacheFile != "" { account.SAMLCacheFile = commonFlags.SAMLCacheFile } + if commonFlags.DisableRememberDevice { + account.DisableRememberDevice = commonFlags.DisableRememberDevice + } + if commonFlags.DisableSessions { + account.DisableSessions = commonFlags.DisableSessions + } } diff --git a/pkg/provider/okta/okta.go b/pkg/provider/okta/okta.go index 0247c61d8..fb0cb6497 100644 --- a/pkg/provider/okta/okta.go +++ b/pkg/provider/okta/okta.go @@ -13,6 +13,7 @@ import ( "net/http/cookiejar" "net/url" "regexp" + "strconv" "strings" "time" @@ -60,9 +61,11 @@ var ( type Client struct { provider.ValidateBase - client *provider.HTTPClient - mfa string - targetURL string + client *provider.HTTPClient + mfa string + targetURL string + disableSessions bool + rememberDevice bool } // AuthRequest represents an mfa okta request @@ -76,8 +79,12 @@ type AuthRequest struct { type VerifyRequest struct { StateToken string `json:"stateToken"` PassCode string `json:"passCode,omitempty"` - RememberDevice string `json:"rememberDevice,omitempty"` -} //https://developer.okta.com/docs/reference/api/authn/#verify-security-question-factor + RememberDevice string `json:"rememberDevice,omitempty"` // This is needed to remember Okta MFA device +} + +// Articles referencing the Okta MFA + remembering device +// https://developer.okta.com/docs/reference/api/authn/#verify-security-question-factor +// https://devforum.okta.com/t/how-per-device-remember-me-api-works/3955/3 // SessionRequst holds the SessionToken used to create an Okta Session type SessionRequst struct { @@ -113,78 +120,29 @@ func New(idpAccount *cfg.IDPAccount) (*Client, error) { } client.Jar = jar - return &Client{ - client: client, - mfa: idpAccount.MFA, - targetURL: idpAccount.TargetURL, - }, nil -} - -type ctxKey string - -func (oc *Client) validateSession(loginDetails *creds.LoginDetails) error { - logger.Debug("validate session func called") - - if loginDetails == nil { - logger.Debugf("unable to validate the okta session, nil input | loginDetails: %v ", loginDetails) - return fmt.Errorf("unable to validate the okta session, nil input") - } - - sessionCookie := loginDetails.OktaSessionCookie - - oktaURL, err := url.Parse(loginDetails.URL) - if err != nil { - return errors.Wrap(err, "error building oktaURL") - } - - oktaOrgHost := oktaURL.Host - - sessionReqURL := fmt.Sprintf("https://%s/api/v1/sessions/me", oktaOrgHost) - sessionReqBody := new(bytes.Buffer) + disableSessions := idpAccount.DisableSessions + rememberDevice := !idpAccount.DisableRememberDevice - req, err := http.NewRequest("GET", sessionReqURL, sessionReqBody) - if err != nil { - return errors.Wrap(err, "error building new session request") + if idpAccount.DisableSessions { // if user disabled sessions, also dont remember device + rememberDevice = false } - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - req.Header.Add("Cookie", fmt.Sprintf("sid=%s", sessionCookie)) - res, err := oc.client.Do(req) - if err != nil { - return errors.Wrap(err, "error retrieving session response") - } - - body, err := ioutil.ReadAll(res.Body) - if err != nil { - return errors.Wrap(err, "error retrieving body from response") - } - - resp := string(body) - - if res.StatusCode != 200 { - logger.Debug("invalid okta session") - return fmt.Errorf("invalid okta session") - } else { - sessionResponseStatus := gjson.Get(resp, "status").String() - switch sessionResponseStatus { - case "ACTIVE": - logger.Debug("okta session established") - case "MFA_REQUIRED": - _, err := verifyMfa(oc, oktaOrgHost, loginDetails, resp) - if err != nil { - return errors.Wrap(err, "error verifying MFA") - } - case "MFA_ENROLL": - // Not yet fully implemented, so just return the status as the error string... - return fmt.Errorf("MFA_ENROLL") - } - } + // Debug the disableSessions and rememberDevice values + logger.Debugf("okta | disableSessions: %v", disableSessions) + logger.Debugf("okta | rememberDevice: %v", rememberDevice) - logger.Debug("valid okta session") - return nil + return &Client{ + client: client, + mfa: idpAccount.MFA, + targetURL: idpAccount.TargetURL, + disableSessions: disableSessions, + rememberDevice: rememberDevice, + }, nil } +type ctxKey string + +// createSession calls the Okta sessions API to create a new session using the sessionToken passed in func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken string) (string, string, error) { logger.Debug("create session func called") if loginDetails == nil || sessionToken == "" { @@ -227,8 +185,11 @@ func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken s return "", "", errors.Wrap(err, "error retrieving body from response") } - if res.StatusCode == 401 { // https://developer.okta.com/docs/reference/api/sessions/#response-parameters - return "", "", fmt.Errorf("unable to create an Okta session, invalid sessionToken") + if res.StatusCode == 200 { // https://developer.okta.com/docs/reference/api/sessions/#response-parameters + if res.StatusCode == 401 { + return "", "", fmt.Errorf("unable to create an Okta session, invalid sessionToken") + } + return "", "", fmt.Errorf("unable to create an Okta session, HTTP Code: %d", res.StatusCode) } resp := string(body) @@ -240,8 +201,9 @@ func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken s err = credentials.SaveCredentials(loginDetails.URL+"/sessionCookie", loginDetails.Username, oktaSessionCookie) if err != nil { - log.Printf("error storing okta session token | err: %v", err) //TODO: handle this properly instead of dumping to stdout + return "", "", fmt.Errorf("error storing okta session token | err: %v", err) } + oktaSessionToken := gjson.Get(resp, "sessionToken").String() sessionResponseStatus := gjson.Get(resp, "status").String() switch sessionResponseStatus { @@ -253,13 +215,79 @@ func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken s return "", "", errors.Wrap(err, "error verifying MFA") } case "MFA_ENROLL": - // Not yet fully implemented, so just return the status as the error string... + // Not yet fully implemented, most likely no need, so just return the status as the error string... return "", "", fmt.Errorf("MFA_ENROLL") } return oktaSessionCookie, oktaSessionToken, nil } +// validateSession calls the Okta session API to check if the session is valid +// returns an error if the session is NOT valid +func (oc *Client) validateSession(loginDetails *creds.LoginDetails) error { + logger.Debug("validate session func called") + + if loginDetails == nil { + logger.Debug("unable to validate the okta session, nil input") + return fmt.Errorf("unable to validate the okta session, nil input") + } + + sessionCookie := loginDetails.OktaSessionCookie + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return errors.Wrap(err, "error building oktaURL") + } + + oktaOrgHost := oktaURL.Host + + sessionReqURL := fmt.Sprintf("https://%s/api/v1/sessions/me", oktaOrgHost) // This api endpoint returns user details + sessionReqBody := new(bytes.Buffer) + + req, err := http.NewRequest("GET", sessionReqURL, sessionReqBody) + if err != nil { + return errors.Wrap(err, "error building new session request") + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Add("Cookie", fmt.Sprintf("sid=%s", sessionCookie)) + + res, err := oc.client.Do(req) + if err != nil { + return errors.Wrap(err, "error retrieving session response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return errors.Wrap(err, "error retrieving body from response") + } + + resp := string(body) + + if res.StatusCode != 200 { + logger.Debug("invalid okta session") + return fmt.Errorf("invalid okta session") + } else { + sessionResponseStatus := gjson.Get(resp, "status").String() + switch sessionResponseStatus { + case "ACTIVE": + logger.Debug("okta session established") + case "MFA_REQUIRED": + _, err := verifyMfa(oc, oktaOrgHost, loginDetails, resp) + if err != nil { + return errors.Wrap(err, "error verifying MFA") + } + case "MFA_ENROLL": + // Not yet fully implemented, so just return the status as the error string... + return fmt.Errorf("MFA_ENROLL") + } + } + + logger.Debug("valid okta session") + return nil +} + +// authWithSession authenticates user via sessions API -> direct to target URL using follow func func (oc *Client) authWithSession(loginDetails *creds.LoginDetails) (string, error) { logger.Debug("auth with session func called") sessionCookie := loginDetails.OktaSessionCookie @@ -291,6 +319,7 @@ func (oc *Client) authWithSession(loginDetails *creds.LoginDetails) (string, err logger.Debugf("error reading body for auth with session: %v", err) } + // This usually happens if using an active session (> 5 mins) but MFA was NOT remembered if strings.Contains(string(body), "/login/step-up/") { // https://developer.okta.com/docs/reference/api/authn/#step-up-authentication-with-okta-session logger.Debug("okta step-up prompted, need mfa...") stateToken, err := getStateTokenFromOktaPageBody(string(body)) @@ -301,38 +330,48 @@ func (oc *Client) authWithSession(loginDetails *creds.LoginDetails) (string, err return oc.Authenticate(loginDetails) } - // fmt.Println(resp) - // os.Exit(0) return oc.follow(ctx, req, loginDetails) } -// https://devforum.okta.com/t/how-per-device-remember-me-api-works/3955/3 +// getDeviceTokenFromOkta creates a dummy HTTP call to Okta and returns the device token +// cookie value +// This function is not currently used and but can be used in the future +func (oc *Client) getDeviceTokenFromOkta(loginDetails *creds.LoginDetails) (string, error) { + //dummy request to set device token cookie ("dt") + req, err := http.NewRequest("GET", loginDetails.URL, nil) + if err != nil { + return "", errors.Wrap(err, "error building device token request") + } + resp, err := oc.client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving device token") + } -// func (oc *Client) getDeviceTokenFromOkta(loginDetails *creds.LoginDetails) (string, error) { -// //dummy request to set device token cookie ("dt") -// req, err := http.NewRequest("GET", loginDetails.URL, nil) -// if err != nil { -// return "", errors.Wrap(err, "error building device token request") -// } -// resp, err := oc.client.Do(req) -// if err != nil { -// return "", errors.Wrap(err, "error retrieving device token") -// } - -// for _, c := range resp.Cookies() { -// if c.Name == "DT" { // Device token -// return c.Value, nil -// } -// } - -// return "", fmt.Errorf("unable to get a device token from okta") -// } + for _, c := range resp.Cookies() { + if c.Name == "DT" { // Device token + return c.Value, nil + } + } + + return "", fmt.Errorf("unable to get a device token from okta") +} +// setDeviceTokenCookie sets the DT cookie in the HTTP Client cookie jar +// using the okta__saml2aws, we reduce making an extra api call +// this func can be uplifted in the future to set custom device tokens or used with +// getDeviceTokenFromOkta function func (oc *Client) setDeviceTokenCookie(loginDetails *creds.LoginDetails) error { + // getDeviceTokenFromOkta is not used but doing this to keep the function code + // uncommented (avoid linting issues) + if false { + dt, _ := oc.getDeviceTokenFromOkta(loginDetails) + logger.Debugf("getDeviceTokenFromOkta is not yet implemented: dt: %s", dt) + } + oktaURL, err := url.Parse(loginDetails.URL) if err != nil { - return errors.Wrap(err, "error building oktaURL to set device token") + return errors.Wrap(err, "error building oktaURL to set device token cookie") } oktaURLScheme := oktaURL.Scheme oktaURLHost := oktaURL.Host @@ -342,8 +381,8 @@ func (oc *Client) setDeviceTokenCookie(loginDetails *creds.LoginDetails) error { cookie := http.Cookie{ Name: "DT", Secure: true, - Expires: time.Now().Add(time.Hour * 24 * 30), // 30 Days - Value: fmt.Sprintf("okta_%s_saml2aws", loginDetails.Username), + Expires: time.Now().Add(time.Hour * 24 * 30), // 30 Days -> this time might not matter as this cookie is set on every saml2aws login request + Value: fmt.Sprintf("okta_%s_saml2aws", loginDetails.Username), // Okta recommends using an UUID but this should be unique enough. Also, this is key to remembering Okta MFA device } cookies = append(cookies, &cookie) oc.client.Jar.SetCookies(baseURL, cookies) @@ -351,45 +390,16 @@ func (oc *Client) setDeviceTokenCookie(loginDetails *creds.LoginDetails) error { return nil } -// Authenticate logs into Okta and returns a SAML response -func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { - // Get Okta session cookie (sid) from login details (if found via login.go) - oktaSessionCookie := loginDetails.OktaSessionCookie - - // Set Okta device token - err := oc.setDeviceTokenCookie(loginDetails) - if err != nil { - return "", errors.Wrap(err, "error setting device token in cookie jar") - } - - // If Okta session cookie is not empty - if oktaSessionCookie != "" && loginDetails.StateToken == "" { - return oc.authWithSession(loginDetails) - } +// primaryAuth creates the Okta Primary Authentication request +// returns the authStatus, sessionToken, http response and a error +func (oc *Client) primaryAuth(loginDetails *creds.LoginDetails) (string, string, string, error) { oktaURL, err := url.Parse(loginDetails.URL) if err != nil { - return "", errors.Wrap(err, "error building oktaURL") + return "", "", "", errors.Wrap(err, "error building oktaURL") } oktaOrgHost := oktaURL.Host - - // // Get the current list of cookies (if any) - // currentCookies := oc.client.Jar.Cookies(oktaURL) - // // Create Okta session cookie containing sid - // osc := http.Cookie{ - // Name: "sid", - // Value: oktaSessionCookie, - // Raw: fmt.Sprintf("sid=%s", oktaSessionCookie), - // Path: "/", - // HttpOnly: true, - // Secure: true, - // } - // // Add the session cookie to cookie jar - // currentCookies = append(currentCookies, &osc) - // // Set the cookie jar back to the http client - // oc.client.Jar.SetCookies(oktaURL, currentCookies) - //authenticate via okta api authReq := AuthRequest{Username: loginDetails.Username, Password: loginDetails.Password} if loginDetails.StateToken != "" { @@ -398,14 +408,14 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) authBody := new(bytes.Buffer) err = json.NewEncoder(authBody).Encode(authReq) if err != nil { - return "", errors.Wrap(err, "error encoding authreq") + return "", "", "", errors.Wrap(err, "error encoding authreq") } authSubmitURL := fmt.Sprintf("https://%s/api/v1/authn", oktaOrgHost) req, err := http.NewRequest("POST", authSubmitURL, authBody) if err != nil { - return "", errors.Wrap(err, "error building authentication request") + return "", "", "", errors.Wrap(err, "error building authentication request") } req.Header.Add("Content-Type", "application/json") @@ -413,12 +423,12 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) res, err := oc.client.Do(req) if err != nil { - return "", errors.Wrap(err, "error retrieving auth response") + return "", "", "", errors.Wrap(err, "error retrieving auth response") } body, err := ioutil.ReadAll(res.Body) if err != nil { - return "", errors.Wrap(err, "error retrieving body from response") + return "", "", "", errors.Wrap(err, "error retrieving body from response") } resp := string(body) @@ -426,14 +436,72 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) authStatus := gjson.Get(resp, "status").String() oktaSessionToken := gjson.Get(resp, "sessionToken").String() + return authStatus, oktaSessionToken, resp, nil +} + +// Authenticate logs into Okta and returns a SAML response +func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { + + // Set Okta device token + err := oc.setDeviceTokenCookie(loginDetails) + if err != nil { + return "", errors.Wrap(err, "error setting device token in cookie jar") + } + + // Get Okta session cookie (sid) from login details (if found via login.go) + oktaSessionCookie := loginDetails.OktaSessionCookie + + // If user disabled sessions, do not use sessions API + if !oc.disableSessions { + // If Okta session cookie is not empty + // Note on checking StateToken: StateToken is set in the follow func + // if the follow func calls this function (Authenticate), it means the session requires MFA to continue + // so don't call authWithSession, instead flow through to create the primary authentication call + if oktaSessionCookie != "" && loginDetails.StateToken == "" { + return oc.authWithSession(loginDetails) + } + } + + oktaURL, err := url.Parse(loginDetails.URL) + if err != nil { + return "", errors.Wrap(err, "error building oktaURL") + } + + oktaOrgHost := oktaURL.Host + + authStatus, oktaSessionToken, primaryAuthResp, err := oc.primaryAuth(loginDetails) + if err != nil { + return "", err + } + // mfa required if authStatus == "MFA_REQUIRED" { - oktaSessionToken, err = verifyMfa(oc, oktaOrgHost, loginDetails, resp) + oktaSessionToken, err = verifyMfa(oc, oktaOrgHost, loginDetails, primaryAuthResp) if err != nil { return "", errors.Wrap(err, "error verifying MFA") } } + // if user disabled sessions, default to using standard login WITHOUT sessions + if oc.disableSessions { + //now call saml endpoint + oktaSessionRedirectURL := fmt.Sprintf("https://%s/login/sessionCookieRedirect", oktaOrgHost) + + req, err := http.NewRequest("GET", oktaSessionRedirectURL, nil) + if err != nil { + return "", errors.Wrap(err, "error building authentication request") + } + q := req.URL.Query() + q.Add("checkAccountSetupComplete", "true") + q.Add("token", oktaSessionToken) + q.Add("redirectUrl", loginDetails.URL) + req.URL.RawQuery = q.Encode() + + ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails) + return oc.follow(ctx, req, loginDetails) + } + + // Only reaches here if user DID NOT DISABLE okta sessions if oktaSessionCookie == "" { oktaSessionCookie, _, err = oc.createSession(loginDetails, oktaSessionToken) if err != nil { @@ -443,22 +511,6 @@ func (oc *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) } return oc.authWithSession(loginDetails) - - // //now call saml endpoint - // oktaSessionRedirectURL := fmt.Sprintf("https://%s/login/sessionCookieRedirect", oktaOrgHost) - - // req, err = http.NewRequest("GET", oktaSessionRedirectURL, nil) - // if err != nil { - // return "", errors.Wrap(err, "error building authentication request") - // } - // q := req.URL.Query() - // q.Add("checkAccountSetupComplete", "true") - // q.Add("token", oktaSessionToken) - // q.Add("redirectUrl", loginDetails.URL) - // req.URL.RawQuery = q.Encode() - - // ctx := context.WithValue(context.Background(), ctxKey("login"), loginDetails) - // return oc.follow(ctx, req, loginDetails) } func (oc *Client) follow(ctx context.Context, req *http.Request, loginDetails *creds.LoginDetails) (string, error) { @@ -468,7 +520,6 @@ func (oc *Client) follow(ctx context.Context, req *http.Request, loginDetails *c if ctx.Value(ctxKey("authWithSession")) != nil { logger.Debug("follow func called from auth with session func") - // oc.client.DisableFollowRedirect() } res, err := oc.client.Do(req) @@ -484,17 +535,6 @@ func (oc *Client) follow(ctx context.Context, req *http.Request, loginDetails *c var handler func(context.Context, *goquery.Document) (context.Context, *http.Request, error) - // if ctx.Value(ctxKey("authWithSession")) != nil { - // logger.Debug("follow func called from auth with session func") - // handler = oc.handleFormRedirect - // modifiedCtx := context.WithValue(context.Background(), ctxKey("follow"), loginDetails) - // modifiedCtx, req, err = handler(modifiedCtx, doc) - // if err != nil { - // return "", err - // } - // return oc.follow(modifiedCtx, req, loginDetails) - // } - if docIsFormRedirectToTarget(doc, oc.targetURL) { logger.WithField("type", "saml-response-to-aws").Debug("doc detect") if samlResponse, ok := extractSAMLResponse(doc); ok { @@ -634,7 +674,7 @@ func getMfaChallengeContext(oc *Client, mfaOption int, resp string) (*mfaChallen } // get signature & callback - verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: "true"} + verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: strconv.FormatBool(oc.rememberDevice)} verifyBody := new(bytes.Buffer) // Login flow is different for YubiKeys ( of course ) @@ -678,7 +718,6 @@ func getMfaChallengeContext(oc *Client, mfaOption int, resp string) (*mfaChallen }, nil } -// TODO: set device token https://developer.okta.com/docs/reference/api/authn/#context-object func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, resp string) (string, error) { stateToken := gjson.Get(resp, "stateToken").String() @@ -713,7 +752,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, if verifyCode == "" { verifyCode = prompter.StringRequired("Enter verification code") } - tokenReq := VerifyRequest{StateToken: stateToken, PassCode: verifyCode, RememberDevice: "true"} + tokenReq := VerifyRequest{StateToken: stateToken, PassCode: verifyCode, RememberDevice: strconv.FormatBool(oc.rememberDevice)} tokenBody := new(bytes.Buffer) err = json.NewEncoder(tokenBody).Encode(tokenReq) if err != nil { @@ -1026,7 +1065,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, // extract okta session token - verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: "true"} + verifyReq := VerifyRequest{StateToken: stateToken, RememberDevice: strconv.FormatBool(oc.rememberDevice)} verifyBody := new(bytes.Buffer) err = json.NewEncoder(verifyBody).Encode(verifyReq) if err != nil { diff --git a/pkg/provider/okta/okta_test.go b/pkg/provider/okta/okta_test.go index 4eadaf56e..9bae1e4e9 100644 --- a/pkg/provider/okta/okta_test.go +++ b/pkg/provider/okta/okta_test.go @@ -2,9 +2,13 @@ package okta import ( "errors" + "fmt" + "net/url" "testing" "github.com/stretchr/testify/assert" + "github.com/versent/saml2aws/v2/pkg/cfg" + "github.com/versent/saml2aws/v2/pkg/creds" ) type stateTokenTests struct { @@ -47,3 +51,66 @@ func TestGetStateTokenFromOktaPageBody(t *testing.T) { }) } } + +func TestSetDeviceTokenCookie(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + loginDetails := &creds.LoginDetails{ + Username: "user@example.com", + Password: "abc123", + URL: "https://idp.example.com/abcd", + } + + oc, err := New(idpAccount) + assert.Nil(t, err) + + err = oc.setDeviceTokenCookie(loginDetails) + assert.Nil(t, err) + + expectedDT := fmt.Sprintf("okta_%s_saml2aws", loginDetails.Username) + actualDT := "" + for _, c := range oc.client.Jar.Cookies(&url.URL{Scheme: "https", Host: "idp.example.com", Path: "/abc"}) { + if c.Name == "DT" { + actualDT = c.Value + } + } + assert.NotEqual(t, actualDT, "") + assert.Equal(t, expectedDT, actualDT) + +} + +func TestOktaCfgFlagsDefaultState(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + oc, err := New(idpAccount) + assert.Nil(t, err) + + assert.False(t, oc.disableSessions, fmt.Errorf("disableSessions should be false by default")) + assert.True(t, oc.rememberDevice, fmt.Errorf("rememberDevice should be true by default")) +} + +func TestOktaCfgFlagsCustomState(t *testing.T) { + idpAccount := cfg.NewIDPAccount() + idpAccount.URL = "https://idp.example.com/abcd" + idpAccount.Username = "user@example.com" + + idpAccount.DisableRememberDevice = true + oc, err := New(idpAccount) + assert.Nil(t, err) + + assert.False(t, oc.disableSessions, fmt.Errorf("disableSessions should be false by default")) + assert.False(t, oc.rememberDevice, fmt.Errorf("DisableRememberDevice was set to true, so rememberDevice should be false")) + + idpAccount.DisableSessions = true + + oc, err = New(idpAccount) + assert.Nil(t, err) + + assert.True(t, oc.disableSessions, fmt.Errorf("DisableSessions was set to true so disableSessions should be true")) + assert.False(t, oc.rememberDevice, fmt.Errorf("DisablDisableSessionseRememberDevice was set to true, so rememberDevice should be false")) + +} From 7e77cfb654e47d996085424f5b34a495843d201d Mon Sep 17 00:00:00 2001 From: Hiruna Wijesinghe Date: Mon, 7 Jun 2021 15:17:23 +1000 Subject: [PATCH 3/5] add details to readme --- README.md | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 286100e41..fedaadcaa 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,8 @@ Commands: --config=CONFIG Path/filename of saml2aws config file (env: SAML2AWS_CONFIGFILE) --cache-saml Caches the SAML response (env: SAML2AWS_CACHE_SAML) --cache-file=CACHE-FILE The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE) + --disable-sessions Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS) + --disable-remember-device Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE) login [] Login to a SAML 2.0 IDP and convert the SAML assertion to an STS token. @@ -199,7 +201,8 @@ Commands: The file that will cache the credentials retrieved from AWS. When not specified, will use the default AWS credentials file location. (env: SAML2AWS_CREDENTIALS_FILE) --cache-saml Caches the SAML response (env: SAML2AWS_CACHE_SAML) --cache-file=CACHE-FILE The location of the SAML cache file (env: SAML2AWS_SAML_CACHE_FILE) - + --disable-sessions Do not use Okta sessions. Uses Okta sessions by default. (env: SAML2AWS_OKTA_DISABLE_SESSIONS) + --disable-remember-device Do not remember Okta MFA device. Remembers MFA device by default. (env: SAML2AWS_OKTA_DISABLE_REMEMBER_DEVICE) exec [] [...] Exec the supplied command with env vars from STS token. @@ -677,6 +680,23 @@ there is a file per saml2aws profile, the cache directory is called `saml2aws` a You can toggle `--cache-saml` during `login` or during `list-roles`, and you can set it once during `configure` and use it implicitly. +# Okta Sessions + +This requires the use of the keychain (local credentials store). If you disabled the keychain using `--disable-keychain`, Okta sessions will also be disabled. + +Okta sessions are enabled by default. This will store the Okta session locally and save your device for MFA. This means that if the session has not yet expired, you will not be prompted for MFA. + +* To disable remembering the device, you can toggle `--disable-remember-device` during `login` or `configure` commands. +* To disable using Okta sessions, you can toggle `--disable-sessions` during `login` or `configure` commands. + * This will also disable the Okta MFA remember device feature + +Use the `--force` flag during `login` command to prompt for AWS role selection. + +If Okta sessions are disabled via any of the methods mentioned above, the login process will default to the standard authentication process (without using sessions). + +Please note that your Okta session duration and MFA policies are governed by your Okta host organization. + + # License This code is Copyright (c) 2018 [Versent](http://versent.com.au) and released under the MIT license. All rights not explicitly granted in the MIT license are reserved. See the included LICENSE.md file for more details. From 4693b7d9c8c9e8b1ff98426f324c3fe5b1623ad9 Mon Sep 17 00:00:00 2001 From: Hiruna Wijesinghe Date: Tue, 8 Jun 2021 10:56:39 +1000 Subject: [PATCH 4/5] fixed operator --- pkg/provider/okta/okta.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/provider/okta/okta.go b/pkg/provider/okta/okta.go index fb0cb6497..6c12a21be 100644 --- a/pkg/provider/okta/okta.go +++ b/pkg/provider/okta/okta.go @@ -185,7 +185,7 @@ func (oc *Client) createSession(loginDetails *creds.LoginDetails, sessionToken s return "", "", errors.Wrap(err, "error retrieving body from response") } - if res.StatusCode == 200 { // https://developer.okta.com/docs/reference/api/sessions/#response-parameters + if res.StatusCode != 200 { // https://developer.okta.com/docs/reference/api/sessions/#response-parameters if res.StatusCode == 401 { return "", "", fmt.Errorf("unable to create an Okta session, invalid sessionToken") } From 752a5b786a70ba1bc67b5d877b041fe5b65a1afa Mon Sep 17 00:00:00 2001 From: Hiruna Wijesinghe Date: Tue, 8 Jun 2021 15:06:48 +1000 Subject: [PATCH 5/5] change fmt println to logger debug output --- pkg/provider/okta/okta.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/provider/okta/okta.go b/pkg/provider/okta/okta.go index 6c12a21be..9855fa85b 100644 --- a/pkg/provider/okta/okta.go +++ b/pkg/provider/okta/okta.go @@ -791,7 +791,7 @@ func verifyMfa(oc *Client, oktaOrgHost string, loginDetails *creds.LoginDetails, // on 'success' status if gjson.Get(body, "status").String() == "SUCCESS" { fmt.Printf(" Approved\n\n") - fmt.Println(gjson.Get(body, "expiresAt").String()) // DEBUG + logger.Debugf("func verifyMfa | okta exiry: %s", gjson.Get(body, "expiresAt").String()) // DEBUG return gjson.Get(body, "sessionToken").String(), nil }