diff --git a/google/internal/externalaccount/basecredentials.go b/google/internal/externalaccount/basecredentials.go index 33288d367..e97210c74 100644 --- a/google/internal/externalaccount/basecredentials.go +++ b/google/internal/externalaccount/basecredentials.go @@ -16,6 +16,8 @@ import ( "golang.org/x/oauth2/google/internal/stsexchange" ) +var maxUnixTime = time.Unix(1<<63-1, 999999999) + // now aliases time.Now for testing var now = func() time.Time { return time.Now().UTC() @@ -241,10 +243,16 @@ func (ts tokenSource) Token() (*oauth2.Token, error) { AccessToken: stsResp.AccessToken, TokenType: stsResp.TokenType, } + + // The RFC8693 doesn't define the explicit 0 of "expires_in" field behavior. + // In practice a lot of sts request is using 0 means "never expire" for an sts token. + // So the logic here shall use a max unix time value. if stsResp.ExpiresIn < 0 { return nil, fmt.Errorf("oauth2/google: got invalid expiry from security token service") - } else if stsResp.ExpiresIn >= 0 { + } else if stsResp.ExpiresIn > 0 { accessToken.Expiry = now().Add(time.Duration(stsResp.ExpiresIn) * time.Second) + } else { + accessToken.Expiry = maxUnixTime } if stsResp.RefreshToken != "" { diff --git a/google/internal/externalaccount/basecredentials_test.go b/google/internal/externalaccount/basecredentials_test.go index 9bdf8e01d..670cfb11c 100644 --- a/google/internal/externalaccount/basecredentials_test.go +++ b/google/internal/externalaccount/basecredentials_test.go @@ -6,6 +6,7 @@ package externalaccount import ( "context" + "encoding/json" "fmt" "io/ioutil" "net/http" @@ -99,15 +100,18 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T return ts.Token() } -func validateToken(t *testing.T, tok *oauth2.Token) { - if got, want := tok.AccessToken, correctAT; got != want { +func validateToken(t *testing.T, tok *oauth2.Token, expectToken *oauth2.Token) { + if expectToken == nil { + return + } + if got, want := tok.AccessToken, expectToken.AccessToken; got != want { t.Errorf("Unexpected access token: got %v, but wanted %v", got, want) } - if got, want := tok.TokenType, "Bearer"; got != want { + if got, want := tok.TokenType, expectToken.TokenType; got != want { t.Errorf("Unexpected TokenType: got %v, but wanted %v", got, want) } - if got, want := tok.Expiry, testNow().Add(time.Duration(3600)*time.Second); got != want { + if got, want := tok.Expiry, expectToken.Expiry; got != want { t.Errorf("Unexpected Expiry: got %v, but wanted %v", got, want) } } @@ -117,30 +121,94 @@ func getExpectedMetricsHeader(source string, saImpersonation bool, configLifetim } func TestToken(t *testing.T) { - config := Config{ - Audience: "32555940559.apps.googleusercontent.com", - SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", - ClientSecret: "notsosecret", - ClientID: "rbrgnognrhongo3bi4gb9ghg9g", - CredentialSource: testBaseCredSource, - Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + type MockSTSResponse struct { + AccessToken string `json:"access_token"` + IssuedTokenType string `json:"issued_token_type"` + TokenType string `json:"token_type"` + ExpiresIn int32 `json:"expires_in,omitempty"` + Scope string `json:"scopre,omitenpty"` } - server := testExchangeTokenServer{ - url: "/", - authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", - contentType: "application/x-www-form-urlencoded", - metricsHeader: getExpectedMetricsHeader("file", false, false), - body: baseCredsRequestBody, - response: baseCredsResponseBody, + testCases := []struct { + name string + responseBody MockSTSResponse + expectToken *oauth2.Token + expectErrorMsg string + }{ + { + name: "happy case", + responseBody: MockSTSResponse{ + AccessToken: correctAT, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "https://www.googleapis.com/auth/cloud-platform", + }, + expectToken: &oauth2.Token{ + AccessToken: correctAT, + TokenType: "Bearer", + Expiry: testNow().Add(time.Duration(3600) * time.Second), + }, + }, + { + name: "happy case, non expire token", + responseBody: MockSTSResponse{ + AccessToken: correctAT, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: 0, + Scope: "https://www.googleapis.com/auth/cloud-platform", + }, + expectToken: &oauth2.Token{ + AccessToken: correctAT, + TokenType: "Bearer", + Expiry: maxUnixTime, + }, + }, + { + name: "negative expiry time", + responseBody: MockSTSResponse{ + AccessToken: correctAT, + IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", + TokenType: "Bearer", + ExpiresIn: -1, + Scope: "https://www.googleapis.com/auth/cloud-platform", + }, + expectToken: nil, + expectErrorMsg: "oauth2/google: got invalid expiry from security token service", + }, } - tok, err := run(t, &config, &server) + for _, testCase := range testCases { + config := Config{ + Audience: "32555940559.apps.googleusercontent.com", + SubjectTokenType: "urn:ietf:params:oauth:token-type:id_token", + ClientSecret: "notsosecret", + ClientID: "rbrgnognrhongo3bi4gb9ghg9g", + CredentialSource: testBaseCredSource, + Scopes: []string{"https://www.googleapis.com/auth/devstorage.full_control"}, + } - if err != nil { - t.Fatalf("Unexpected error: %e", err) + responseBody, _ := json.Marshal(testCase.responseBody) + + server := testExchangeTokenServer{ + url: "/", + authorization: "Basic cmJyZ25vZ25yaG9uZ28zYmk0Z2I5Z2hnOWc6bm90c29zZWNyZXQ=", + contentType: "application/x-www-form-urlencoded", + metricsHeader: getExpectedMetricsHeader("file", false, false), + body: baseCredsRequestBody, + response: string(responseBody), + } + + tok, err := run(t, &config, &server) + + if err != nil { + if err.Error() != testCase.expectErrorMsg { + t.Errorf("Error actual = %v, and Expect = %v", err, testCase.expectErrorMsg) + } + } + validateToken(t, tok, testCase.expectToken) } - validateToken(t, tok) } func TestWorkforcePoolTokenWithClientID(t *testing.T) { @@ -168,7 +236,12 @@ func TestWorkforcePoolTokenWithClientID(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %e", err) } - validateToken(t, tok) + expectToken := oauth2.Token{ + AccessToken: correctAT, + TokenType: "Bearer", + Expiry: testNow().Add(time.Duration(3600) * time.Second), + } + validateToken(t, tok, &expectToken) } func TestWorkforcePoolTokenWithoutClientID(t *testing.T) { @@ -195,7 +268,12 @@ func TestWorkforcePoolTokenWithoutClientID(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %e", err) } - validateToken(t, tok) + expectToken := oauth2.Token{ + AccessToken: correctAT, + TokenType: "Bearer", + Expiry: testNow().Add(time.Duration(3600) * time.Second), + } + validateToken(t, tok, &expectToken) } func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) {