-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
package common | ||
|
||
import ( | ||
"context" | ||
"encoding/base64" | ||
"fmt" | ||
"net/http" | ||
"net/url" | ||
"os" | ||
"strings" | ||
"time" | ||
|
||
"golang.org/x/oauth2" | ||
"golang.org/x/oauth2/clientcredentials" | ||
) | ||
|
||
func NewClient(ctx context.Context, conf *clientcredentials.Config, secret string, secretFile string) *http.Client { | ||
return oauth2.NewClient(ctx, newTokenSource(ctx, conf, secret, secretFile)) | ||
} | ||
|
||
func newTokenSource(ctx context.Context, conf *clientcredentials.Config, secret string, secretFile string) oauth2.TokenSource { | ||
// normal static secret token | ||
if secretFile == "" { | ||
conf.ClientSecret = secret | ||
return conf.TokenSource(ctx) | ||
} | ||
source := &fileTokenSource{ | ||
ctx: ctx, | ||
conf: conf, | ||
secretFile: secretFile, | ||
} | ||
// dynamic file token source | ||
return oauth2.ReuseTokenSource(nil, source) | ||
} | ||
|
||
type fileTokenSource struct { | ||
ctx context.Context | ||
conf *clientcredentials.Config | ||
secretFile string | ||
} | ||
|
||
// Token refreshes the token by using a new client credentials request. | ||
// tokens received this way do not include a refresh token | ||
func (c *fileTokenSource) Token() (*oauth2.Token, error) { | ||
v := url.Values{ | ||
"grant_type": {"client_credentials"}, | ||
} | ||
if len(c.conf.Scopes) > 0 { | ||
v.Set("scope", strings.Join(c.conf.Scopes, " ")) | ||
} | ||
for k, p := range c.conf.EndpointParams { | ||
// Allow grant_type to be overridden to allow interoperability with | ||
// non-compliant implementations. | ||
if _, ok := v[k]; ok && k != "grant_type" { | ||
return nil, fmt.Errorf("oauth2: cannot overwrite parameter %q", k) | ||
} | ||
v[k] = p | ||
} | ||
|
||
content, err := os.ReadFile(c.secretFile) | ||
if err != nil { | ||
return nil, fmt.Errorf("oauth2: cannot read token file %q: %v", c.secretFile, err) | ||
} | ||
|
||
tk, err := retrieveToken(c.ctx, c.conf.ClientID, string(content), c.conf.TokenURL, v) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return tk, nil | ||
} | ||
|
||
func getClient(ctx context.Context) *http.Client { | ||
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { | ||
return c | ||
} | ||
return nil | ||
} | ||
|
||
func retrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*oauth2.Token, error) { | ||
client := http.DefaultClient | ||
if c := getClient(ctx); c != nil { | ||
client = c | ||
} | ||
|
||
v.Set("client_id", clientID) | ||
v.Set("client_secret", clientSecret) | ||
encoded := v.Encode() | ||
tj := tokenJSON{} | ||
_, err := MakeRequest( | ||
ctx, | ||
HTTPRequest{ | ||
URL: tokenURL, | ||
Method: "POST", | ||
Body: []byte(encoded), | ||
OKCode: []int{200}, | ||
Headers: map[string]string{ | ||
"Content-Type": "application/x-www-form-urlencoded", | ||
}, | ||
}, | ||
&tj, | ||
client, | ||
Backoff{ | ||
Duration: 100 * time.Millisecond, | ||
MaxRetries: 2, | ||
}, | ||
) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
token := &oauth2.Token{ | ||
AccessToken: tj.AccessToken, | ||
TokenType: tj.TokenType, | ||
RefreshToken: tj.RefreshToken, | ||
Expiry: tj.expiry(), | ||
} | ||
|
||
if token != nil && token.RefreshToken == "" { | ||
token.RefreshToken = v.Get("refresh_token") | ||
} | ||
return token, err | ||
} | ||
|
||
type tokenJSON struct { | ||
AccessToken string `json:"access_token"` | ||
TokenType string `json:"token_type"` | ||
RefreshToken string `json:"refresh_token"` | ||
ExpiresIn int `json:"expires_in"` | ||
} | ||
|
||
func (e *tokenJSON) expiry() (t time.Time) { | ||
if v := e.ExpiresIn; v != 0 { | ||
return time.Now().Add(time.Duration(v) * time.Second) | ||
} | ||
return | ||
} | ||
|
||
// BasicAuth returns a base64 encoded string of the user and password | ||
func BasicAuth(user, password string) string { | ||
auth := fmt.Sprintf("%s:%s", user, password) | ||
return base64.StdEncoding.EncodeToString([]byte(auth)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
package common | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"testing" | ||
|
||
"github.com/gin-gonic/gin" | ||
"github.com/stretchr/testify/require" | ||
"golang.org/x/oauth2/clientcredentials" | ||
) | ||
|
||
func ExampleBasicAuth() { | ||
fmt.Println(BasicAuth("username", "password")) | ||
// Output: dXNlcm5hbWU6cGFzc3dvcmQ= | ||
} | ||
|
||
func TestDecodeBasicAuth(t *testing.T) { | ||
out, err := Base64decode("dXNlcm5hbWU6cGFzc3dvcmQ=") | ||
require.NoError(t, err) | ||
require.Equal(t, "username:password", out) | ||
} | ||
|
||
func TestNewClient(t *testing.T) { | ||
secret := "secret" | ||
srv := mockSrv(secret) | ||
t.Cleanup(func() { | ||
srv.Close() | ||
}) | ||
|
||
creds := &clientcredentials.Config{ | ||
ClientID: "clientid", | ||
TokenURL: fmt.Sprintf("%s/oauth2/token", srv.URL), | ||
Scopes: []string{"openid", "email", "groups"}, | ||
EndpointParams: url.Values{ | ||
"groups": []string{"test"}, | ||
}, | ||
} | ||
|
||
ctx := context.Background() | ||
c := NewClient(ctx, creds, secret, "") | ||
|
||
req, err := http.NewRequest("GET", fmt.Sprintf("%s?foo=bar", srv.URL), nil) | ||
require.NoError(t, err) | ||
|
||
resp, err := c.Do(req) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusOK, resp.StatusCode) | ||
} | ||
|
||
func TestNewClientToken(t *testing.T) { | ||
secret := "tokenfile" | ||
srv := mockSrv(secret) | ||
t.Cleanup(func() { | ||
srv.Close() | ||
}) | ||
|
||
creds := &clientcredentials.Config{ | ||
ClientID: "clientid", | ||
TokenURL: fmt.Sprintf("%s/oauth2/token", srv.URL), | ||
Scopes: []string{"openid", "email", "groups"}, | ||
EndpointParams: url.Values{ | ||
"groups": []string{"test"}, | ||
}, | ||
} | ||
|
||
ctx := context.Background() | ||
c := NewClient(ctx, creds, "secret", "./testdata/token") | ||
|
||
req, err := http.NewRequest("GET", fmt.Sprintf("%s?foo=bar", srv.URL), nil) | ||
require.NoError(t, err) | ||
|
||
resp, err := c.Do(req) | ||
require.NoError(t, err) | ||
require.Equal(t, http.StatusOK, resp.StatusCode) | ||
} | ||
|
||
type tokenRequest struct { | ||
GrantType string `form:"grant_type" json:"grant_type"` | ||
Scope string `form:"scope" json:"scope"` | ||
ClientID string `form:"client_id" json:"client_id"` | ||
ClientSecret string `form:"client_secret" json:"client_secret"` | ||
} | ||
|
||
func mockSrv(secret string) *httptest.Server { | ||
r := gin.New() | ||
r.GET("/", func(c *gin.Context) { | ||
c.String(http.StatusOK, "ok") | ||
}) | ||
r.POST("/oauth2/token", func(c *gin.Context) { | ||
var payload tokenRequest | ||
err := c.Bind(&payload) | ||
if err != nil { | ||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||
return | ||
} | ||
|
||
if payload.GrantType != "client_credentials" { | ||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid grant_type"}) | ||
return | ||
} | ||
|
||
requestedGroups := c.Request.Form["groups"] | ||
if len(requestedGroups) != 1 || requestedGroups[0] != "test" { | ||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid groups: %+v", requestedGroups)}) | ||
return | ||
} | ||
|
||
if payload.Scope != "openid email groups" { | ||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid scope: %s", payload.Scope)}) | ||
return | ||
} | ||
|
||
if payload.ClientID != "clientid" || payload.ClientSecret != secret { | ||
c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("invalid client: %+v", payload)}) | ||
return | ||
} | ||
|
||
c.JSON(http.StatusOK, gin.H{ | ||
"access_token": "token", | ||
"token_type": "Bearer", | ||
"expires_in": 3600, | ||
}) | ||
}) | ||
return httptest.NewServer(r) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
tokenfile |