Skip to content

Commit

Permalink
SNOW-833537 Each time retry keypair auth with new token (#845)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus authored Jul 13, 2023
1 parent 25f4b6c commit 7d6e39a
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 127 deletions.
151 changes: 81 additions & 70 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,15 @@ func postAuth(
client *http.Client,
params *url.Values,
headers map[string]string,
body []byte,
bodyCreator bodyCreatorType,
timeout time.Duration) (
data *authResponse, err error) {
params.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
params.Add(requestGUIDKey, NewUUID().String())

fullURL := sr.getFullURL(loginRequestPath, params)
logger.Infof("full URL: %v", fullURL)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, body, timeout, true)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -287,6 +287,23 @@ func authenticate(
samlResponse []byte,
proofKey []byte,
) (resp *authResponseMain, err error) {
if sc.cfg.Authenticator == AuthTypeTokenAccessor {
logger.Info("Bypass authentication using existing token from token accessor")
sessionInfo := authResponseSessionInfo{
DatabaseName: sc.cfg.Database,
SchemaName: sc.cfg.Schema,
WarehouseName: sc.cfg.Warehouse,
RoleName: sc.cfg.Role,
}
token, masterToken, sessionID := sc.cfg.TokenAccessor.GetTokens()
return &authResponseMain{
Token: token,
MasterToken: masterToken,
SessionID: sessionID,
SessionInfo: sessionInfo,
}, nil
}

headers := getHeaders()
clientEnvironment := authRequestClientEnvironment{
Application: sc.cfg.Application,
Expand All @@ -310,6 +327,67 @@ func authenticate(
if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
sessionParameters[clientStoreTemporaryCredential] = true
}
bodyCreator := func() ([]byte, error) {
return createRequestBody(sc, sessionParameters, clientEnvironment, proofKey, samlResponse)
}

params := &url.Values{}
if sc.cfg.Database != "" {
params.Add("databaseName", sc.cfg.Database)
}
if sc.cfg.Schema != "" {
params.Add("schemaName", sc.cfg.Schema)
}
if sc.cfg.Warehouse != "" {
params.Add("warehouse", sc.cfg.Warehouse)
}
if sc.cfg.Role != "" {
params.Add("roleName", sc.cfg.Role)
}

logger.WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v",
params, sc.rest.Protocol, sc.rest.Host, sc.rest.Port, sc.rest.LoginTimeout, sc.cfg.Authenticator.String())

respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, bodyCreator, sc.rest.LoginTimeout)
if err != nil {
return nil, err
}
if !respd.Success {
logger.Errorln("Authentication FAILED")
sc.rest.TokenAccessor.SetTokens("", "", -1)
if sessionParameters[clientRequestMfaToken] == true {
deleteCredential(sc, mfaToken)
}
if sessionParameters[clientStoreTemporaryCredential] == true {
deleteCredential(sc, idToken)
}
code, err := strconv.Atoi(respd.Code)
if err != nil {
code = -1
return nil, err
}
return nil, (&SnowflakeError{
Number: code,
SQLState: SQLStateConnectionRejected,
Message: respd.Message,
}).exceptionTelemetry(sc)
}
logger.Info("Authentication SUCCESS")
sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
if sessionParameters[clientRequestMfaToken] == true {
token := respd.Data.MfaToken
setCredential(sc, mfaToken, token)
}
if sessionParameters[clientStoreTemporaryCredential] == true {
token := respd.Data.IDToken
setCredential(sc, idToken, token)
}
return &respd.Data, nil
}

func createRequestBody(sc *snowflakeConn, sessionParameters map[string]interface{},
clientEnvironment authRequestClientEnvironment, proofKey []byte, samlResponse []byte,
) ([]byte, error) {
requestMain := authRequestData{
ClientAppID: clientType,
ClientAppVersion: SnowflakeGoDriverVersion,
Expand Down Expand Up @@ -362,83 +440,16 @@ func authenticate(
if sc.cfg.MfaToken != "" {
requestMain.Token = sc.cfg.MfaToken
}
case AuthTypeTokenAccessor:
logger.Info("Bypass authentication using existing token from token accessor")
sessionInfo := authResponseSessionInfo{
DatabaseName: sc.cfg.Database,
SchemaName: sc.cfg.Schema,
WarehouseName: sc.cfg.Warehouse,
RoleName: sc.cfg.Role,
}
token, masterToken, sessionID := sc.cfg.TokenAccessor.GetTokens()
return &authResponseMain{
Token: token,
MasterToken: masterToken,
SessionID: sessionID,
SessionInfo: sessionInfo,
}, nil
}

authRequest := authRequest{
Data: requestMain,
}
params := &url.Values{}
if sc.cfg.Database != "" {
params.Add("databaseName", sc.cfg.Database)
}
if sc.cfg.Schema != "" {
params.Add("schemaName", sc.cfg.Schema)
}
if sc.cfg.Warehouse != "" {
params.Add("warehouse", sc.cfg.Warehouse)
}
if sc.cfg.Role != "" {
params.Add("roleName", sc.cfg.Role)
}

jsonBody, err := json.Marshal(authRequest)
if err != nil {
return
}

logger.WithContext(sc.ctx).Infof("PARAMS for Auth: %v, %v, %v, %v, %v, %v",
params, sc.rest.Protocol, sc.rest.Host, sc.rest.Port, sc.rest.LoginTimeout, sc.cfg.Authenticator.String())

respd, err := sc.rest.FuncPostAuth(ctx, sc.rest, sc.rest.getClientFor(sc.cfg.Authenticator), params, headers, jsonBody, sc.rest.LoginTimeout)
if err != nil {
return nil, err
}
if !respd.Success {
logger.Errorln("Authentication FAILED")
sc.rest.TokenAccessor.SetTokens("", "", -1)
if sessionParameters[clientRequestMfaToken] == true {
deleteCredential(sc, mfaToken)
}
if sessionParameters[clientStoreTemporaryCredential] == true {
deleteCredential(sc, idToken)
}
code, err := strconv.Atoi(respd.Code)
if err != nil {
code = -1
return nil, err
}
return nil, (&SnowflakeError{
Number: code,
SQLState: SQLStateConnectionRejected,
Message: respd.Message,
}).exceptionTelemetry(sc)
}
logger.Info("Authentication SUCCESS")
sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
if sessionParameters[clientRequestMfaToken] == true {
token := respd.Data.MfaToken
setCredential(sc, mfaToken, token)
}
if sessionParameters[clientStoreTemporaryCredential] == true {
token := respd.Data.IDToken
setCredential(sc, idToken, token)
}
return &respd.Data, nil
return jsonBody, nil
}

// Generate a JWT token in string given the configuration
Expand Down
58 changes: 36 additions & 22 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,67 +26,70 @@ func TestUnitPostAuth(t *testing.T) {
FuncAuthPost: postAuthTestAfterRenew,
}
var err error
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0)
bodyCreator := func() ([]byte, error) {
return []byte{0x12, 0x34}, nil
}
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err != nil {
t.Fatalf("err: %v", err)
}
sr.FuncAuthPost = postAuthTestError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0)
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
sr.FuncAuthPost = postAuthTestAppBadGatewayError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0)
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
sr.FuncAuthPost = postAuthTestAppForbiddenError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0)
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
sr.FuncAuthPost = postAuthTestAppUnexpectedError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), []byte{0x12, 0x34}, 0)
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
}

func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
func postAuthFailServiceIssue(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
return nil, &SnowflakeError{
Number: ErrCodeServiceUnavailable,
}
}

func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
func postAuthFailWrongAccount(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
return nil, &SnowflakeError{
Number: ErrCodeFailedToConnect,
}
}

func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
func postAuthFailUnknown(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
return nil, &SnowflakeError{
Number: ErrFailedToAuth,
}
}

func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
func postAuthSuccessWithErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
return &authResponse{
Success: false,
Code: "98765",
Message: "wrong!",
}, nil
}

func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
func postAuthSuccessWithInvalidErrorCode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
return &authResponse{
Success: false,
Code: "abcdef",
Message: "wrong!",
}, nil
}

func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*authResponse, error) {
return &authResponse{
Success: true,
Data: authResponseMain{
Expand All @@ -99,8 +102,9 @@ func postAuthSuccess(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *
}, nil
}

func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckSAMLResponse(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -126,9 +130,10 @@ func postAuthCheckOAuth(
_ *snowflakeRestful,
_ *http.Client,
_ *url.Values, _ map[string]string,
jsonBody []byte,
bodyCreator bodyCreatorType,
_ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -153,8 +158,9 @@ func postAuthCheckOAuth(
}, nil
}

func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -173,8 +179,9 @@ func postAuthCheckPasscode(_ context.Context, _ *snowflakeRestful, _ *http.Clien
}, nil
}

func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -195,8 +202,9 @@ func postAuthCheckPasscodeInPassword(_ context.Context, _ *snowflakeRestful, _ *

// JWT token validate callback function to check the JWT token
// It uses the public key paired with the testPrivKey
func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand Down Expand Up @@ -231,8 +239,9 @@ func postAuthCheckJWTToken(_ context.Context, _ *snowflakeRestful, _ *http.Clien
}, nil
}

func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -253,8 +262,9 @@ func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _
}, nil
}

func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -275,8 +285,9 @@ func postAuthCheckUsernamePasswordMfaToken(_ context.Context, _ *snowflakeRestfu
}, nil
}

func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -292,8 +303,9 @@ func postAuthCheckUsernamePasswordMfaFailed(_ context.Context, _ *snowflakeRestf
}, nil
}

func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -314,8 +326,9 @@ func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *htt
}, nil
}

func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand All @@ -336,8 +349,9 @@ func postAuthCheckExternalBrowserToken(_ context.Context, _ *snowflakeRestful, _
}, nil
}

func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, _ *http.Client, _ *url.Values, _ map[string]string, bodyCreator bodyCreatorType, _ time.Duration) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 7d6e39a

Please sign in to comment.