Skip to content

Commit

Permalink
Add Custom Provider for SecureAuth IdP (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
markafarrell authored Jun 27, 2022
1 parent 5bd90ee commit c817ca9
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 1 deletion.
19 changes: 19 additions & 0 deletions claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@ import (
"github.com/ryanuber/go-glob"
)

// setClaim sets a claim value from allClaims given a provided claim string.
// If this string is a valid JSONPointer, it will be interpreted as such to locate
// the claim. Otherwise, the claim string will be used directly.
func setClaim(logger log.Logger, allClaims map[string]interface{}, claim string, val interface{}) interface{} {
var err error

if !strings.HasPrefix(claim, "/") {
allClaims[claim] = val
} else {
val, err = pointerstructure.Set(allClaims, claim, val)
if err != nil {
logger.Warn(fmt.Sprintf("unable to set %s in claims: %s", claim, err.Error()))
return nil
}
}

return val
}

// getClaim returns a claim value from allClaims given a provided claim string.
// If this string is a valid JSONPointer, it will be interpreted as such to locate
// the claim. Otherwise, the claim string will be used directly.
Expand Down
44 changes: 44 additions & 0 deletions claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,50 @@ func TestGetClaim(t *testing.T) {
}
}

func TestSetClaim(t *testing.T) {
data := `{
"a": 42,
"b": "bar",
"c": {
"d": 95,
"e": [
"dog",
"cat",
"bird"
],
"f": {
"g": "zebra"
}
}
}`
var claims map[string]interface{}
if err := json.Unmarshal([]byte(data), &claims); err != nil {
t.Fatal(err)
}

tests := []struct {
claim string
value interface{}
}{
{"a", float64(43)},
{"/a", float64(43)},
{"b", "foo"},
{"/c/d", float64(96)},
{"/c/e/1", "dog"},
{"/c/f/g", "elephant"},
}

for _, test := range tests {
_ = setClaim(hclog.NewNullLogger(), claims, test.claim, test.value)

v := getClaim(hclog.NewNullLogger(), claims, test.claim)

if diff := deep.Equal(v, test.value); diff != nil {
t.Fatal(diff)
}
}
}

func TestExtractMetadata(t *testing.T) {
emptyMap := make(map[string]string)

Expand Down
2 changes: 1 addition & 1 deletion path_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func (b *jwtAuthBackend) fetchGroups(ctx context.Context, pConfig CustomProvider

// Add groups obtained by provider-specific fetching to the claims
// so that they can be used for bound_claims validation on the role.
allClaims["groups"] = groupsRaw
setClaim(b.Logger(), allClaims, role.GroupsClaim, groupsRaw)
}
}
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)
Expand Down
1 change: 1 addition & 0 deletions provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func ProviderMap() map[string]CustomProvider {
return map[string]CustomProvider{
"azure": &AzureProvider{},
"gsuite": &GSuiteProvider{},
"secureauth": &SecureAuthProvider{},
}
}

Expand Down
45 changes: 45 additions & 0 deletions provider_secureauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package jwtauth

import (
"context"
"fmt"
"strings"

"golang.org/x/oauth2"
)

// SecureAuthProvider is used for SecureAuth-specific configuration
type SecureAuthProvider struct {
}

// Initialize anything in the SecureAuthProvider struct - satisfying the CustomProvider interface
func (a *SecureAuthProvider) Initialize(_ context.Context, _ *jwtConfig) error {
return nil
}

// SensitiveKeys - satisfying the CustomProvider interface
func (a *SecureAuthProvider) SensitiveKeys() []string {
return []string{}
}

// FetchGroups - custom groups fetching for secureauth - satisfying GroupsFetcher interface
// SecureAuth by default will return groups not as a json list but as a list of comma seperated strings
// We need to convert this to a json list
func (a *SecureAuthProvider) FetchGroups(_ context.Context, b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole, _ oauth2.TokenSource) (interface{}, error) {
groupsClaimRaw := getClaim(b.Logger(), allClaims, role.GroupsClaim)

if groupsClaimRaw != nil {
// Try to convert the comma seperated list of strings into a list
if groupsstr, ok := groupsClaimRaw.(string); ok {
rawsecureauthGroups := strings.Split(groupsstr, ",")

var secureauthGroups = make([]interface{}, 0, len(rawsecureauthGroups))
for group := range rawsecureauthGroups {
secureauthGroups = append(secureauthGroups, rawsecureauthGroups[group])
}
groupsClaimRaw = secureauthGroups
}
}
b.Logger().Debug(fmt.Sprintf("post: groups claim raw is %v", groupsClaimRaw))
return groupsClaimRaw, nil
}
145 changes: 145 additions & 0 deletions provider_secureauth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package jwtauth

import (
"bytes"
"context"
"encoding/pem"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

type secureauthServer struct {
t *testing.T
server *httptest.Server
}

func newsecureauthServer(t *testing.T) *secureauthServer {
a := new(secureauthServer)
a.t = t
a.server = httptest.NewTLSServer(a)

return a
}

func (a *secureauthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")

switch r.URL.Path {
case "/.well-known/openid-configuration":
w.Write([]byte(strings.Replace(`
{
"issuer": "%s",
"authorization_endpoint": "%s/auth",
"token_endpoint": "%s/oauth2/v2.0/token",
"jwks_uri": "%s/certs",
"userinfo_endpoint": "%s/userinfo"
}`, "%s", a.server.URL, -1)))
default:
a.t.Fatalf("unexpected path: %q", r.URL.Path)
}
}

// getTLSCert returns the certificate for this provider in PEM format
func (a *secureauthServer) getTLSCert() (string, error) {
cert := a.server.Certificate()
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}

pemBuf := new(bytes.Buffer)
if err := pem.Encode(pemBuf, block); err != nil {
return "", err
}

return pemBuf.String(), nil
}

func TestLogin_secureauth_fetchGroups(t *testing.T) {

aServer := newsecureauthServer(t)
aCert, err := aServer.getTLSCert()
require.NoError(t, err)

b, storage := getBackend(t)
ctx := context.Background()

data := map[string]interface{}{
"oidc_discovery_url": aServer.server.URL,
"oidc_discovery_ca_pem": aCert,
"oidc_client_id": "abc",
"oidc_client_secret": "def",
"default_role": "test",
"bound_issuer": "http://vault.example.com/",
"provider_config": map[string]interface{}{
"provider": "secureauth",
},
}

// basic configuration
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: configPath,
Storage: storage,
Data: data,
}

resp, err := b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}

// set up test role
data = map[string]interface{}{
"user_claim": "email",
"groups_claim": "groups",
"allowed_redirect_uris": []string{"https://example.com"},
}

req = &logical.Request{
Operation: logical.CreateOperation,
Path: "role/test",
Storage: storage,
Data: data,
}

resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}

role := &jwtRole{
GroupsClaim: "groups",
}
allClaims := map[string]interface{}{
"groups": "a-group,another-group",
}

// Ensure b.cachedConfig is populated
config, err := b.(*jwtAuthBackend).config(ctx, storage)
if err != nil {
t.Fatal(err)
}

// Initialize the secureauth provider
provider, err := NewProviderConfig(ctx, config, ProviderMap())
if err != nil {
t.Fatal(err)
}

// Ensure groups are as expected
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test.access.token"})
groupsRaw, err := b.(*jwtAuthBackend).fetchGroups(ctx, provider, allClaims, role, tokenSource)
assert.NoError(t, err)

groupsResp, ok := normalizeList(groupsRaw)
assert.True(t, ok)
assert.Equal(t, []interface{}{"a-group", "another-group"}, groupsResp)
}

0 comments on commit c817ca9

Please sign in to comment.