Skip to content

Commit

Permalink
Add error fields to RetrieveError
Browse files Browse the repository at this point in the history
  • Loading branch information
hickford committed Mar 6, 2023
1 parent 62b4eed commit cbc7e73
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 12 deletions.
51 changes: 41 additions & 10 deletions internal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,18 @@ type Token struct {
}

// tokenJSON is the struct representing the HTTP response from OAuth2
// providers returning a token in JSON form.
// providers returning a token or error in JSON form.
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
type tokenJSON struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
// error fields
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorUri string `json:"error_uri"`
}

func (e *tokenJSON) expiry() (t time.Time) {
Expand Down Expand Up @@ -236,21 +242,29 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, &RetrieveError{
Response: r,
Body: body,
}

failureStatus := r.StatusCode < 200 || r.StatusCode > 299
retrieveError := &RetrieveError{
Response: r,
Body: body,
// attempt to populate error detail below
}

var token *Token
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content {
case "application/x-www-form-urlencoded", "text/plain":
// some endpoints such as GitHub return a query string https://docs.github.com/en/developers/apps/building-oauth-apps/authorizing-oauth-apps#response-1
vals, err := url.ParseQuery(string(body))
if err != nil {
return nil, err
if failureStatus {
return nil, retrieveError
}
return nil, fmt.Errorf("oauth2: cannot parse response: %v", err)
}
retrieveError.ErrorCode = vals.Get("error")
retrieveError.ErrorDescription = vals.Get("error_description")
retrieveError.ErrorUri = vals.Get("error_uri")
token = &Token{
AccessToken: vals.Get("access_token"),
TokenType: vals.Get("token_type"),
Expand All @@ -263,10 +277,17 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
}
default:
// spec says to return JSON https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
var tj tokenJSON
if err = json.Unmarshal(body, &tj); err != nil {
return nil, err
if failureStatus {
return nil, retrieveError
}
return nil, fmt.Errorf("oauth2: cannot parse json: %v", err)
}
retrieveError.ErrorCode = tj.Error
retrieveError.ErrorDescription = tj.ErrorDescription
retrieveError.ErrorUri = tj.ErrorUri
token = &Token{
AccessToken: tj.AccessToken,
TokenType: tj.TokenType,
Expand All @@ -276,15 +297,25 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
}
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
}
// according to spec, servers should respond status 400 in error case
// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
// but some unorthodox servers respond 200 in error case
if failureStatus || retrieveError.ErrorCode != "" {
return nil, retrieveError
}
if token.AccessToken == "" {
return nil, errors.New("oauth2: server response missing access_token")
}
return token, nil
}

// mirrors oauth2.RetrieveError
type RetrieveError struct {
Response *http.Response
Body []byte
Response *http.Response
Body []byte
ErrorCode string
ErrorDescription string
ErrorUri string
}

func (r *RetrieveError) Error() string {
Expand Down
36 changes: 35 additions & 1 deletion oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ func TestTokenRetrieveError(t *testing.T) {
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
}
w.Header().Set("Content-type", "application/json")
// "The authorization server responds with an HTTP 400 (Bad Request)" https://www.rfc-editor.org/rfc/rfc6749#section-5.2
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error": "invalid_grant"}`))
}))
Expand All @@ -493,7 +494,7 @@ func TestTokenRetrieveError(t *testing.T) {
if err == nil {
t.Fatalf("got no error, expected one")
}
_, ok := err.(*RetrieveError)
re, ok := err.(*RetrieveError)
if !ok {
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
}
Expand All @@ -502,6 +503,39 @@ func TestTokenRetrieveError(t *testing.T) {
if errStr := err.Error(); errStr != expected {
t.Fatalf("got %#v, expected %#v", errStr, expected)
}
expected = "invalid_grant"
if re.ErrorCode != expected {
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
}
}

// TestTokenRetrieveError200 tests handling of unorthodox server that returns 200 in error case
func TestTokenRetrieveError200(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
}
w.Header().Set("Content-type", "application/json")
w.Write([]byte(`{"error": "invalid_grant"}`))
}))
defer ts.Close()
conf := newConf(ts.URL)
_, err := conf.Exchange(context.Background(), "exchange-code")
if err == nil {
t.Fatalf("got no error, expected one")
}
re, ok := err.(*RetrieveError)
if !ok {
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
}
expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "200 OK", `{"error": "invalid_grant"}`)
if errStr := err.Error(); errStr != expected {
t.Fatalf("got %#v, expected %#v", errStr, expected)
}
expected = "invalid_grant"
if re.ErrorCode != expected {
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
}
}

func TestRefreshToken_RefreshTokenReplacement(t *testing.T) {
Expand Down
8 changes: 7 additions & 1 deletion token.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,18 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error)
}

// RetrieveError is the error returned when the token endpoint returns a
// non-2XX HTTP status code.
// non-2XX HTTP status code or populates rfc6749 error parameter.
type RetrieveError struct {
Response *http.Response
// Body is the body that was consumed by reading Response.Body.
// It may be truncated.
Body []byte
// rfc6749 error parameter https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
ErrorCode string
// rfc6749 error_description parameter https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
ErrorDescription string
// rfc6749 error_uri parameter https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
ErrorUri string
}

func (r *RetrieveError) Error() string {
Expand Down

0 comments on commit cbc7e73

Please sign in to comment.