Skip to content

Commit

Permalink
Merge pull request #661 from snowflakedb/implementMFAConnectionCaching
Browse files Browse the repository at this point in the history
Add connection caching for mfa and id token
  • Loading branch information
sfc-gh-ext-simba-lb authored Nov 2, 2022
2 parents a2529a2 + db46032 commit 82aecfa
Show file tree
Hide file tree
Showing 194 changed files with 18,870 additions and 16 deletions.
106 changes: 90 additions & 16 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ const (
clientType = "Go"
)

const (
idToken = "ID_TOKEN"
mfaToken = "MFATOKEN"
clientStoreTemporaryCredential = "CLIENT_STORE_TEMPORARY_CREDENTIAL"
clientRequestMfaToken = "CLIENT_REQUEST_MFA_TOKEN"
idTokenAuthenticator = "ID_TOKEN"
)

// AuthType indicates the type of authentication in Snowflake
type AuthType int

Expand All @@ -40,6 +48,8 @@ const (
AuthTypeJwt
// AuthTypeTokenAccessor is to use the provided token accessor and bypass authentication
AuthTypeTokenAccessor
// AuthTypeUsernamePasswordMFA is to use username and password with mfa
AuthTypeUsernamePasswordMFA
)

func determineAuthenticatorType(cfg *Config, value string) error {
Expand All @@ -57,6 +67,9 @@ func determineAuthenticatorType(cfg *Config, value string) error {
} else if upperCaseValue == AuthTypeExternalBrowser.String() {
cfg.Authenticator = AuthTypeExternalBrowser
return nil
} else if upperCaseValue == AuthTypeUsernamePasswordMFA.String() {
cfg.Authenticator = AuthTypeUsernamePasswordMFA
return nil
} else {
// possibly Okta case
oktaURLString, err := url.QueryUnescape(lowerCaseValue)
Expand Down Expand Up @@ -104,6 +117,8 @@ func (authType AuthType) String() string {
return "SNOWFLAKE_JWT"
case AuthTypeTokenAccessor:
return "TOKENACCESSOR"
case AuthTypeUsernamePasswordMFA:
return "USERNAME_PASSWORD_MFA"
default:
return "UNKNOWN"
}
Expand Down Expand Up @@ -168,6 +183,10 @@ type authResponseMain struct {
Validity time.Duration `json:"validityInSeconds,omitempty"`
MasterToken string `json:"masterToken,omitempty"`
MasterValidity time.Duration `json:"masterValidityInSeconds"`
MfaToken string `json:"mfaToken,omitempty"`
MfaTokenValidity time.Duration `json:"mfaTokenValidityInSeconds"`
IDToken string `json:"idToken,omitempty"`
IDTokenValidity time.Duration `json:"idTokenValidityInSeconds"`
DisplayUserName string `json:"displayUserName"`
ServerVersion string `json:"serverVersion"`
FirstLogin bool `json:"firstLogin"`
Expand All @@ -182,6 +201,7 @@ type authResponseMain struct {
SSOURL string `json:"ssoUrl,omitempty"`
ProofKey string `json:"proofKey,omitempty"`
}

type authResponse struct {
Data authResponseMain `json:"data"`
Message string `json:"message"`
Expand Down Expand Up @@ -281,7 +301,6 @@ func authenticate(
}

sessionParameters[sessionClientValidateDefaultParameters] = sc.cfg.ValidateDefaultParameters != ConfigBoolFalse

requestMain := authRequestData{
ClientAppID: clientType,
ClientAppVersion: SnowflakeGoDriverVersion,
Expand All @@ -292,10 +311,15 @@ func authenticate(

switch sc.cfg.Authenticator {
case AuthTypeExternalBrowser:
requestMain.ProofKey = string(proofKey)
requestMain.Token = string(samlResponse)
requestMain.LoginName = sc.cfg.User
requestMain.Authenticator = AuthTypeExternalBrowser.String()
if sc.cfg.IDToken != "" {
requestMain.Authenticator = idTokenAuthenticator
requestMain.Token = sc.cfg.IDToken
} else {
requestMain.ProofKey = string(proofKey)
requestMain.Token = string(samlResponse)
requestMain.LoginName = sc.cfg.User
requestMain.Authenticator = AuthTypeExternalBrowser.String()
}
case AuthTypeOAuth:
requestMain.LoginName = sc.cfg.User
requestMain.Authenticator = AuthTypeOAuth.String()
Expand All @@ -321,6 +345,13 @@ func authenticate(
requestMain.Passcode = sc.cfg.Passcode
requestMain.ExtAuthnDuoMethod = "passcode"
}
case AuthTypeUsernamePasswordMFA:
logger.Info("Username and password MFA")
requestMain.LoginName = sc.cfg.User
requestMain.Password = sc.cfg.Password
if sc.cfg.MfaToken != "" {
requestMain.Token = sc.cfg.MfaToken
}
case AuthTypeTokenAccessor:
logger.Info("Bypass authentication using existing token from token accessor")
sessionInfo := authResponseSessionInfo{
Expand Down Expand Up @@ -370,6 +401,12 @@ func authenticate(
if !respd.Success {
logger.Errorln("Authentication FAILED")
sc.rest.TokenAccessor.SetTokens("", "", -1)
if sessionParameters[clientRequestMfaToken] == "true" {
deleteCredential(sc, mfaToken)
}
if sessionParameters[clientStoreTemporaryCredential] == "true" {
deleteCredential(sc, idToken)
}
code, err := strconv.Atoi(respd.Code)
if err != nil {
code = -1
Expand All @@ -383,6 +420,12 @@ func authenticate(
}
logger.Info("Authentication SUCCESS")
sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
if sc.isClientRequestMfaToken() {
setCredential(sc, mfaToken, respd.Data.MfaToken)
}
if sc.isClientStoreTemporaryCredential() {
setCredential(sc, idToken, respd.Data.IDToken)
}
return &respd.Data, nil
}

Expand Down Expand Up @@ -421,20 +464,43 @@ func authenticateWithConfig(sc *snowflakeConn) error {
var samlResponse []byte
var proofKey []byte
var err error
//var consentCacheIdToken = true

paramBoolValue := "true"
if sc.cfg.Authenticator == AuthTypeExternalBrowser {
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
sc.cfg.Params[clientStoreTemporaryCredential] = &paramBoolValue
}
if sc.isClientStoreTemporaryCredential() {
fillCachedIDToken(sc)
}
}

if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA {
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
sc.cfg.Params[clientRequestMfaToken] = &paramBoolValue
}
if sc.isClientRequestMfaToken() {
fillCachedMfaToken(sc)
}
}

logger.Infof("Authenticating via %v", sc.cfg.Authenticator.String())
switch sc.cfg.Authenticator {
case AuthTypeExternalBrowser:
samlResponse, proofKey, err = authenticateByExternalBrowser(
sc.ctx,
sc.rest,
sc.cfg.Authenticator.String(),
sc.cfg.Application,
sc.cfg.Account,
sc.cfg.User,
sc.cfg.Password)
if err != nil {
sc.cleanup()
return err
if sc.cfg.IDToken == "" {
samlResponse, proofKey, err = authenticateByExternalBrowser(
sc.ctx,
sc.rest,
sc.cfg.Authenticator.String(),
sc.cfg.Application,
sc.cfg.Account,
sc.cfg.User,
sc.cfg.Password)
if err != nil {
sc.cleanup()
return err
}
}
case AuthTypeOkta:
samlResponse, err = authenticateBySAML(
Expand Down Expand Up @@ -463,3 +529,11 @@ func authenticateWithConfig(sc *snowflakeConn) error {
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
return nil
}

func fillCachedIDToken(sc *snowflakeConn) {
getCredential(sc, idToken)
}

func fillCachedMfaToken(sc *snowflakeConn) {
getCredential(sc, mfaToken)
}
36 changes: 36 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,25 @@ func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *url.Values
}, nil
}

func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

return &authResponse{
Success: true,
Data: authResponseMain{
Token: "t",
MasterToken: "m",
MfaToken: "mockedMfaToken",
SessionInfo: authResponseSessionInfo{
DatabaseName: "dbn",
},
},
}, nil
}

func getDefaultSnowflakeConn() *snowflakeConn {
cfg := Config{
Account: "a",
Expand Down Expand Up @@ -469,3 +488,20 @@ func TestUnitAuthenticateJWT(t *testing.T) {
t.Fatalf("invalid token passed")
}
}

func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuth: postAuthCheckUsernamePasswordMfa,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
requestMfaToken := "true"
sc.cfg.Params[clientRequestMfaToken] = &requestMfaToken
sc.rest = sr
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
}
85 changes: 85 additions & 0 deletions cmd/mfa/mfa.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package main

import (
"database/sql"
"flag"
"fmt"
"log"
"os"
"strconv"

sf "github.com/snowflakedb/gosnowflake"
)

// getDSN constructs a DSN based on the test connection parameters
func getDSN() (string, *sf.Config, error) {
env := func(k string, failOnMissing bool) string {
if value := os.Getenv(k); value != "" {
return value
}
if failOnMissing {
log.Fatalf("%v environment variable is not set.", k)
}
return ""
}

account := env("SNOWFLAKE_TEST_ACCOUNT", true)
user := env("SNOWFLAKE_TEST_USER", true)
password := env("SNOWFLAKE_TEST_PASSWORD", true)
host := env("SNOWFLAKE_TEST_HOST", false)
port := env("SNOWFLAKE_TEST_PORT", false)
protocol := env("SNOWFLAKE_TEST_PROTOCOL", false)

portStr, err := strconv.Atoi(port)
if err != nil {
return "", nil, err
}
cfg := &sf.Config{
Account: account,
Authenticator: sf.AuthTypeUsernamePasswordMFA,
User: user,
Host: host,
Password: password,
Port: portStr,
Protocol: protocol,
}

dsn, err := sf.DSN(cfg)
return dsn, cfg, err
}

func main() {
if !flag.Parsed() {
flag.Parse()
}

dsn, cfg, err := getDSN()

if err != nil {
log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err)
}

// The external browser flow should start with the call to Open
db, err := sql.Open("snowflake", dsn)
if err != nil {
log.Fatalf("failed to connect. %v, err: %v", dsn, err)
}
defer db.Close()
query := "SELECT 1"
rows, err := db.Query(query)
if err != nil {
log.Fatalf("failed to run a query. %v, err: %v", query, err)
}
defer rows.Close()
var v int
for rows.Next() {
err := rows.Scan(&v)
if err != nil {
log.Fatalf("failed to get result. err: %v", err)
}
if v != 1 {
log.Fatalf("failed to get 1. got: %v", v)
}
fmt.Printf("Congrats! You have successfully run %v with Snowflake DB!", query)
}
}
16 changes: 16 additions & 0 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool {
return strings.Compare(*v, "true") == 0
}

func (sc *snowflakeConn) isClientStoreTemporaryCredential() bool {
v, ok := sc.cfg.Params[clientStoreTemporaryCredential]
if !ok {
return false
}
return strings.Compare(*v, "true") == 0
}

func (sc *snowflakeConn) isClientRequestMfaToken() bool {
v, ok := sc.cfg.Params[clientRequestMfaToken]
if !ok {
return false
}
return strings.Compare(*v, "true") == 0
}

func (sc *snowflakeConn) startHeartBeat() {
if !sc.isClientSessionKeepAliveEnabled() {
return
Expand Down
3 changes: 3 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ type Config struct {
DisableTelemetry bool // indicates whether to disable telemetry

Tracing string // sets logging level

MfaToken string // Internally used to cache the MFA token
IDToken string // Internally used to cache the Id Token for external browser
}

// ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.18

require (
github.com/Azure/azure-storage-blob-go v0.15.0
github.com/99designs/keyring v1.2.1
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40
github.com/aws/aws-sdk-go-v2 v1.16.16
github.com/aws/aws-sdk-go-v2/credentials v1.12.20
Expand All @@ -18,6 +19,7 @@ require (
)

require (
github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect
github.com/Azure/azure-pipeline-go v0.2.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.8 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 // indirect
Expand All @@ -29,6 +31,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.17 // indirect
github.com/google/flatbuffers v2.0.8+incompatible // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/klauspost/compress v1.15.11 // indirect
github.com/mattn/go-ieproxy v0.0.9 // indirect
Expand Down
Loading

0 comments on commit 82aecfa

Please sign in to comment.