Skip to content

Commit

Permalink
Add some missing tests
Browse files Browse the repository at this point in the history
Signed-off-by: Romain Caire <[email protected]>
  • Loading branch information
supercairos committed Mar 8, 2024
1 parent 3c5a336 commit aaa53da
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
6 changes: 3 additions & 3 deletions server/introspectionhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Intro
return nil, newIntrospectInternalServerError()
}

subjectString, sErr := s.genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID)
subjectString, sErr := genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID)
if sErr != nil {
s.logger.Errorf("failed to marshal offline session ID: %v", err)
return nil, newIntrospectInternalServerError()
Expand All @@ -96,7 +96,7 @@ func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Intro
Expiry: rCtx.storageToken.CreatedAt.Add(s.refreshTokenPolicy.absoluteLifetime).Unix(),
Subject: subjectString,
Username: rCtx.storageToken.Claims.PreferredUsername,
Audience: s.getAudience(rCtx.storageToken.ClientID, rCtx.scopes),
Audience: getAudience(rCtx.storageToken.ClientID, rCtx.scopes),
Issuer: s.issuerURL.String(),

Extra: IntrospectionExtra{
Expand All @@ -123,7 +123,7 @@ func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Intr
return nil, newIntrospectInternalServerError()
}

clientID, err := s.getClientID(idToken.Audience, claims.AuthorizingParty)
clientID, err := getClientID(idToken.Audience, claims.AuthorizingParty)
if err != nil {
return nil, newIntrospectInternalServerError()
}
Expand Down
10 changes: 5 additions & 5 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes [
return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID)
}

func (s *Server) getClientID(aud audience, azp string) (string, error) {
func getClientID(aud audience, azp string) (string, error) {
switch len(aud) {
case 0:
return "", fmt.Errorf("no audience is set, could not find ClientID")
Expand All @@ -318,7 +318,7 @@ func (s *Server) getClientID(aud audience, azp string) (string, error) {
}
}

func (s *Server) getAudience(clientID string, scopes []string) audience {
func getAudience(clientID string, scopes []string) audience {
var aud audience

for _, scope := range scopes {
Expand All @@ -341,7 +341,7 @@ func (s *Server) getAudience(clientID string, scopes []string) audience {
return aud
}

func (s *Server) genSubject(userID string, connID string) (string, error) {
func genSubject(userID string, connID string) (string, error) {
sub := &internal.IDTokenSubject{
UserId: userID,
ConnId: connID,
Expand Down Expand Up @@ -369,7 +369,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
issuedAt := s.now()
expiry = issuedAt.Add(s.idTokensValidFor)

subjectString, err := s.genSubject(claims.UserID, connID)
subjectString, err := genSubject(claims.UserID, connID)
if err != nil {
s.logger.Errorf("failed to marshal offline session ID: %v", err)
return "", expiry, fmt.Errorf("failed to marshal offline session ID: %v", err)
Expand Down Expand Up @@ -434,7 +434,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
}
}

tok.Audience = s.getAudience(clientID, scopes)
tok.Audience = getAudience(clientID, scopes)
if len(tok.Audience) > 1 {
// The current client becomes the authorizing party.
tok.AuthorizingParty = clientID
Expand Down
32 changes: 32 additions & 0 deletions server/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,40 @@ import (

"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
"github.com/stretchr/testify/require"
)

func TestGetClientID(t *testing.T) {
cid, err := getClientID(audience{}, "")
require.Equal(t, "", cid)
require.Equal(t, "no audience is set, could not find ClientID", err.Error())

cid, err = getClientID(audience{"a"}, "")
require.Equal(t, "a", cid)
require.NoError(t, err)

cid, err = getClientID(audience{"a", "b"}, "azp")
require.Equal(t, "azp", cid)
require.NoError(t, err)
}

func TestGetAudience(t *testing.T) {
aud := getAudience("client-id", []string{})
require.Equal(t, aud, audience{"client-id"})

aud = getAudience("client-id", []string{"ascope"})
require.Equal(t, aud, audience{"client-id"})

aud = getAudience("client-id", []string{"ascope", "audience:server:client_id:aa", "audience:server:client_id:bb"})
require.Equal(t, aud, audience{"aa", "bb", "client-id"})
}

func TestGetSubject(t *testing.T) {
sub, err := genSubject("foo", "bar")
require.Equal(t, "CgNmb28SA2Jhcg", sub)
require.NoError(t, err)
}

func TestParseAuthorizationRequest(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit aaa53da

Please sign in to comment.