From 34fcd0d6bb8db0e18b2708b46c1b2d0a8ed8e25c Mon Sep 17 00:00:00 2001 From: Jesse Haka Date: Sat, 6 Apr 2024 19:04:16 +0300 Subject: [PATCH] add basic auth support --- auth.go | 106 +++++++++++++++++++++++++++++++++++++-------------- auth_test.go | 16 +++++--- 2 files changed, 88 insertions(+), 34 deletions(-) diff --git a/auth.go b/auth.go index 5de4638..678741d 100644 --- a/auth.go +++ b/auth.go @@ -38,8 +38,17 @@ type fileTokenSource struct { ctx context.Context //nolint:containedctx conf *clientcredentials.Config secretFile string + style authStyle } +type authStyle int + +const ( + authStyleNotKnown authStyle = iota + authStyleInHeader + authStyleInParams +) + // Token refreshes the token by using a new client credentials request. // tokens received this way do not include a refresh token. func (c *fileTokenSource) Token() (*oauth2.Token, error) { @@ -63,11 +72,30 @@ func (c *fileTokenSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("oauth2: cannot read token file %q: %w", c.secretFile, err) } - tk, err := retrieveToken(c.ctx, c.conf.ClientID, string(content), c.conf.TokenURL, v) - if err != nil { - return nil, err + var tk *oauth2.Token + + switch { + case c.style == authStyleNotKnown, c.style == authStyleInHeader: + tk, err = retrieveToken(c.ctx, c.conf.TokenURL, c.conf.ClientID, string(content), v, authStyleInHeader) + if err == nil { + c.style = authStyleInHeader + return tk, nil + } + if c.style == authStyleNotKnown { + tk, err = retrieveToken(c.ctx, c.conf.TokenURL, c.conf.ClientID, string(content), v, authStyleInParams) + if err == nil { + c.style = authStyleInParams + return tk, nil + } + } + case c.style == authStyleInParams: + tk, err = retrieveToken(c.ctx, c.conf.TokenURL, c.conf.ClientID, string(content), v, authStyleInParams) + if err == nil { + c.style = authStyleInParams + return tk, nil + } } - return tk, nil + return nil, err } func getClient(ctx context.Context) *http.Client { @@ -77,35 +105,38 @@ func getClient(ctx context.Context) *http.Client { return nil } -func retrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*oauth2.Token, error) { +func buildHeadersAndBody(clientID, clientSecret string, v url.Values, style authStyle) (map[string]string, url.Values) { + headers := map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + } + switch style { + case authStyleInHeader, authStyleNotKnown: + headers["Authorization"] = "Basic " + BasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) + case authStyleInParams: + v.Set("client_id", clientID) + v.Set("client_secret", clientSecret) + } + return headers, v +} + +func retrieveToken(ctx context.Context, tokenURL, clientID, clientSecret string, v url.Values, style authStyle) (*oauth2.Token, error) { client := http.DefaultClient if c := getClient(ctx); c != nil { client = c } - // TODO: missing support for plain/form post body, missing support for client id and secret in basic auth header - v.Set("client_id", clientID) - v.Set("client_secret", clientSecret) - encoded := v.Encode() - tj := tokenJSON{} - _, err := MakeRequest( - ctx, - HTTPRequest{ - URL: tokenURL, - Method: "POST", - Body: []byte(encoded), - OKCode: []int{200}, - Headers: map[string]string{ - "Content-Type": "application/x-www-form-urlencoded", - }, - }, - &tj, - client, - Backoff{ - Duration: 100 * time.Millisecond, - MaxRetries: 2, - }, - ) + headers, v := buildHeadersAndBody(clientID, clientSecret, v, style) + req := HTTPRequest{ + URL: tokenURL, + Method: "POST", + Body: []byte(v.Encode()), + OKCode: []int{200}, + Headers: headers, + } + + var tj *tokenJSON + var err error + tj, err = makeRequest(ctx, client, req) if err != nil { return nil, err } @@ -126,6 +157,25 @@ func retrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, return token, err } +func makeRequest(ctx context.Context, client *http.Client, req HTTPRequest) (*tokenJSON, error) { + // TODO: missing support for plain/form post body + tj := &tokenJSON{} + _, err := MakeRequest( + ctx, + req, + &tj, + client, + Backoff{ + Duration: 100 * time.Millisecond, + MaxRetries: 2, + }, + ) + if err != nil { + return nil, err + } + return tj, nil +} + type tokenJSON struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` diff --git a/auth_test.go b/auth_test.go index 5dccdc2..aab8194 100644 --- a/auth_test.go +++ b/auth_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "net/url" "testing" + "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -70,12 +71,15 @@ func TestNewClientToken(t *testing.T) { ctx := context.Background() c := NewClient(ctx, creds, "secret", "./testdata/token") - req, err := http.NewRequest("GET", fmt.Sprintf("%s?foo=bar", srv.URL), nil) - require.NoError(t, err) + for i := 0; i < 3; i++ { + req, err := http.NewRequest("GET", fmt.Sprintf("%s?foo=bar", srv.URL), nil) + require.NoError(t, err) - resp, err := c.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) + resp, err := c.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + time.Sleep(1 * time.Second) + } } type tokenRequest struct { @@ -122,7 +126,7 @@ func mockSrv(secret string) *httptest.Server { c.JSON(http.StatusOK, gin.H{ "access_token": "token", "token_type": "Bearer", - "expires_in": 3600, + "expires_in": 1, }) }) return httptest.NewServer(r)