Skip to content

Commit

Permalink
add dynamic file tokensource
Browse files Browse the repository at this point in the history
  • Loading branch information
zetaab committed Apr 5, 2024
1 parent 451d2fa commit 3361a3f
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 0 deletions.
142 changes: 142 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package common

import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

func NewClient(ctx context.Context, conf *clientcredentials.Config, secret string, secretFile string) *http.Client {
return oauth2.NewClient(ctx, newTokenSource(ctx, conf, secret, secretFile))
}

func newTokenSource(ctx context.Context, conf *clientcredentials.Config, secret string, secretFile string) oauth2.TokenSource {
// normal static secret token
if secretFile == "" {
conf.ClientSecret = secret
return conf.TokenSource(ctx)
}
source := &fileTokenSource{
ctx: ctx,
conf: conf,
secretFile: secretFile,
}
// dynamic file token source
return oauth2.ReuseTokenSource(nil, source)
}

type fileTokenSource struct {
ctx context.Context
conf *clientcredentials.Config
secretFile string
}

// Token refreshes the token by using a new client credentials request.
// tokens received this way do not include a refresh token
func (c *fileTokenSource) Token() (*oauth2.Token, error) {
v := url.Values{
"grant_type": {"client_credentials"},
}
if len(c.conf.Scopes) > 0 {
v.Set("scope", strings.Join(c.conf.Scopes, " "))
}
for k, p := range c.conf.EndpointParams {
// Allow grant_type to be overridden to allow interoperability with
// non-compliant implementations.
if _, ok := v[k]; ok && k != "grant_type" {
return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k)
}
v[k] = p
}

content, err := os.ReadFile(c.secretFile)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot read token file %q: %v", c.secretFile, err)
}

tk, err := retrieveToken(c.ctx, c.conf.ClientID, string(content), c.conf.TokenURL, v)
if err != nil {
return nil, err
}
return tk, nil
}

func getClient(ctx context.Context) *http.Client {
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
return c
}
return nil
}

func retrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*oauth2.Token, error) {
client := http.DefaultClient
if c := getClient(ctx); c != nil {
client = c
}

v.Set("client_id", clientID)
v.Set("client_secret", clientSecret)
encoded := v.Encode()
tj := tokenJSON{}
_, err := MakeRequest(
ctx,
HTTPRequest{
URL: tokenURL,
Method: "POST",
Body: []byte(encoded),
OKCode: []int{200},
Headers: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
},
&tj,
client,
Backoff{
Duration: 100 * time.Millisecond,
MaxRetries: 2,
},
)
if err != nil {
return nil, err
}

token := &oauth2.Token{
AccessToken: tj.AccessToken,
TokenType: tj.TokenType,
RefreshToken: tj.RefreshToken,
Expiry: tj.expiry(),
}

if token != nil && token.RefreshToken == "" {
token.RefreshToken = v.Get("refresh_token")
}
return token, err
}

type tokenJSON struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}

func (e *tokenJSON) expiry() (t time.Time) {
if v := e.ExpiresIn; v != 0 {
return time.Now().Add(time.Duration(v) * time.Second)
}
return
}

// BasicAuth returns a base64 encoded string of the user and password
func BasicAuth(user, password string) string {
auth := fmt.Sprintf("%s:%s", user, password)
return base64.StdEncoding.EncodeToString([]byte(auth))
}
129 changes: 129 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package common

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2/clientcredentials"
)

func ExampleBasicAuth() {
fmt.Println(BasicAuth("username", "password"))
// Output: dXNlcm5hbWU6cGFzc3dvcmQ=
}

func TestDecodeBasicAuth(t *testing.T) {
out, err := Base64decode("dXNlcm5hbWU6cGFzc3dvcmQ=")
require.NoError(t, err)
require.Equal(t, "username:password", out)
}

func TestNewClient(t *testing.T) {
secret := "secret"
srv := mockSrv(secret)
t.Cleanup(func() {
srv.Close()
})

creds := &clientcredentials.Config{
ClientID: "clientid",
TokenURL: fmt.Sprintf("%s/oauth2/token", srv.URL),
Scopes: []string{"openid", "email", "groups"},
EndpointParams: url.Values{
"groups": []string{"test"},
},
}

ctx := context.Background()
c := NewClient(ctx, creds, secret, "")

req, err := http.NewRequest("GET", fmt.Sprintf("%s?foo=bar", srv.URL), nil)
require.NoError(t, err)

resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
}

func TestNewClientToken(t *testing.T) {
secret := "tokenfile"
srv := mockSrv(secret)
t.Cleanup(func() {
srv.Close()
})

creds := &clientcredentials.Config{
ClientID: "clientid",
TokenURL: fmt.Sprintf("%s/oauth2/token", srv.URL),
Scopes: []string{"openid", "email", "groups"},
EndpointParams: url.Values{
"groups": []string{"test"},
},
}

ctx := context.Background()
c := NewClient(ctx, creds, "secret", "./testdata/token")

req, err := http.NewRequest("GET", fmt.Sprintf("%s?foo=bar", srv.URL), nil)
require.NoError(t, err)

resp, err := c.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
}

type tokenRequest struct {
GrantType string `form:"grant_type" json:"grant_type"`
Scope string `form:"scope" json:"scope"`
ClientID string `form:"client_id" json:"client_id"`
ClientSecret string `form:"client_secret" json:"client_secret"`
}

func mockSrv(secret string) *httptest.Server {
r := gin.New()
r.GET("/", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
})
r.POST("/oauth2/token", func(c *gin.Context) {
var payload tokenRequest
err := c.Bind(&payload)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

if payload.GrantType != "client_credentials" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid grant_type"})
return
}

requestedGroups := c.Request.Form["groups"]
if len(requestedGroups) != 1 || requestedGroups[0] != "test" {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid groups: %+v", requestedGroups)})
return
}

if payload.Scope != "openid email groups" {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid scope: %s", payload.Scope)})
return
}

if payload.ClientID != "clientid" || payload.ClientSecret != secret {
c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("invalid client: %+v", payload)})
return
}

c.JSON(http.StatusOK, gin.H{
"access_token": "token",
"token_type": "Bearer",
"expires_in": 3600,
})
})
return httptest.NewServer(r)
}
1 change: 1 addition & 0 deletions testdata/token
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tokenfile

0 comments on commit 3361a3f

Please sign in to comment.