diff --git a/pkg/v1/remote/transport/bearer.go b/pkg/v1/remote/transport/bearer.go index bff2ece01..9c404df52 100644 --- a/pkg/v1/remote/transport/bearer.go +++ b/pkg/v1/remote/transport/bearer.go @@ -92,29 +92,25 @@ func (bt *bearerTransport) RoundTrip(in *http.Request) (*http.Response, error) { // The basic token exchange is attempted first, falling back to the oauth flow. // If the IdentityToken is set, this indicates that we should start with the oauth flow. func (bt *bearerTransport) refresh() error { - first, second := bt.refreshBasic, bt.refreshOauth - auth, err := bt.basic.Authorization() if err != nil { return err } + var content []byte if auth.IdentityToken != "" { // If the secret being stored is an identity token, // the Username should be set to , which indicates // we are using an oauth flow. - first, second = bt.refreshOauth, bt.refreshBasic - } - - content, err := func() ([]byte, error) { - b, err := first() - if err != nil { - b, err = second() - if err != nil { - return nil, err - } + content, err = bt.refreshOauth() + if terr, ok := err.(*Error); ok && terr.StatusCode == http.StatusNotFound { + // Note: Not all token servers implement oauth2. + // If the request to the endpoint returns 404 using the HTTP POST method, + // refer to Token Documentation for using the HTTP GET method supported by all token servers. + content, err = bt.refreshBasic() } - return b, err - }() + } else { + content, err = bt.refreshBasic() + } if err != nil { return err } @@ -209,6 +205,7 @@ func (bt *bearerTransport) refreshOauth() ([]byte, error) { v.Set("grant_type", "refresh_token") v.Set("refresh_token", auth.IdentityToken) } else if auth.Username != "" && auth.Password != "" { + // TODO(#629): This is unreachable. v.Set("grant_type", "password") v.Set("username", auth.Username) v.Set("password", auth.Password) diff --git a/pkg/v1/remote/transport/bearer_test.go b/pkg/v1/remote/transport/bearer_test.go index 01afa1e34..246e90d52 100644 --- a/pkg/v1/remote/transport/bearer_test.go +++ b/pkg/v1/remote/transport/bearer_test.go @@ -16,7 +16,6 @@ package transport import ( "fmt" - "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -198,25 +197,25 @@ func TestBearerTransportTokenRefresh(t *testing.T) { func TestBearerTransportOauthRefresh(t *testing.T) { initialToken := "foo" - refreshedToken := "bar" + accessToken := "bar" + refreshToken := "baz" server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { - b, err := ioutil.ReadAll(r.Body) - if err != nil { + if err := r.ParseForm(); err != nil { t.Fatal(err) } - if len(b) == 0 { - t.Errorf("got empty request body for POST") + if it := r.FormValue("refresh_token"); it != initialToken { + t.Errorf("want %s got %s", initialToken, it) } w.WriteHeader(http.StatusOK) - w.Write([]byte(fmt.Sprintf(`{"access_token": %q}`, refreshedToken))) + w.Write([]byte(fmt.Sprintf(`{"access_token": %q, "refresh_token": %q}`, accessToken, refreshToken))) return } hdr := r.Header.Get("Authorization") - if hdr == "Bearer "+refreshedToken { + if hdr == "Bearer "+accessToken { w.WriteHeader(http.StatusOK) return } @@ -234,11 +233,11 @@ func TestBearerTransportOauthRefresh(t *testing.T) { t.Errorf("Unexpected error during NewRegistry: %v", err) } - bearer := &authn.Bearer{Token: initialToken} + bearer := &authn.Bearer{} transport := &bearerTransport{ inner: http.DefaultTransport, bearer: bearer, - basic: authn.FromConfig(authn.AuthConfig{IdentityToken: "baz"}), + basic: authn.FromConfig(authn.AuthConfig{IdentityToken: initialToken}), registry: registry, realm: server.URL, scheme: "http", @@ -254,32 +253,34 @@ func TestBearerTransportOauthRefresh(t *testing.T) { if res.StatusCode != http.StatusOK { t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK) } - if transport.bearer.Token != refreshedToken { - t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", bearer.Token, refreshedToken) + if transport.bearer.Token != accessToken { + t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", bearer.Token, accessToken) + } + basicAuthConfig, err := transport.basic.Authorization() + if err != nil { + t.Fatal(err) + } + if got, want := basicAuthConfig.IdentityToken, refreshToken; got != want { + t.Errorf("Expected Basic IdentityToken to be refreshed, got %v, want %v", got, want) } } -func TestBearerTransportOauthRefreshToken(t *testing.T) { - initialToken := "initial_token" +func TestBearerTransportOauth404Fallback(t *testing.T) { + basicAuth := "basic_auth" + identityToken := "identity_token" accessToken := "access_token" - refreshToken := "refresh_token" server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { - b, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - } - if len(b) == 0 { - t.Errorf("got empty request body for POST") - } - w.WriteHeader(http.StatusOK) - w.Write([]byte(fmt.Sprintf(`{"access_token": %q, "refresh_token": %q}`, accessToken, refreshToken))) - return + w.WriteHeader(http.StatusNotFound) } hdr := r.Header.Get("Authorization") + if hdr == "Basic "+basicAuth { + w.WriteHeader(http.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"access_token": %q}`, accessToken))) + } if hdr == "Bearer "+accessToken { w.WriteHeader(http.StatusOK) return @@ -298,11 +299,14 @@ func TestBearerTransportOauthRefreshToken(t *testing.T) { t.Errorf("Unexpected error during NewRegistry: %v", err) } - bearer := &authn.Bearer{Token: initialToken} + bearer := &authn.Bearer{} transport := &bearerTransport{ - inner: http.DefaultTransport, - basic: authn.FromConfig(authn.AuthConfig{Username: "user", Password: "pass"}), - bearer: bearer, + inner: http.DefaultTransport, + bearer: bearer, + basic: authn.FromConfig(authn.AuthConfig{ + IdentityToken: identityToken, + Auth: basicAuth, + }), registry: registry, realm: server.URL, scheme: "http",