Skip to content

Commit

Permalink
More validation testing
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Feb 20, 2023
1 parent 84a369b commit f759348
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 28 deletions.
32 changes: 14 additions & 18 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,23 @@ import (
"strings"
)

// Error constants
var (
ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable")

ErrInvalidKey = errors.New("key is invalid")
ErrInvalidKeyType = errors.New("key is of invalid type")
ErrHashUnavailable = errors.New("the requested hash function is unavailable")
ErrTokenMalformed = errors.New("token is malformed")
ErrTokenUnverifiable = errors.New("token is unverifiable")
ErrTokenRequiredClaimMissing = errors.New("a required claim is missing")
ErrTokenSignatureInvalid = errors.New("token signature is invalid")

ErrTokenInvalidAudience = errors.New("token has invalid audience")
ErrTokenExpired = errors.New("token is expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrTokenInvalidIssuer = errors.New("token has invalid issuer")
ErrTokenInvalidSubject = errors.New("token has invalid subject")
ErrTokenNotValidYet = errors.New("token is not valid yet")
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")

ErrInvalidType = errors.New("invalid type for claim")
ErrTokenRequiredClaimMissing = errors.New("token is missing required claim")
ErrTokenInvalidAudience = errors.New("token has invalid audience")
ErrTokenExpired = errors.New("token is expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrTokenInvalidIssuer = errors.New("token has invalid issuer")
ErrTokenInvalidSubject = errors.New("token has invalid subject")
ErrTokenNotValidYet = errors.New("token is not valid yet")
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")
ErrInvalidType = errors.New("invalid type for claim")
)

// joinedError is an error type that works similar to what [errors.Join]
Expand All @@ -46,7 +42,7 @@ func (je joinedError) Error() string {

// joinErrors joins together multiple errors. Useful for scenarios where
// multiple errors next to each other occur, e.g., in claims validation.
func joinErrors(errs []error) error {
func joinErrors(errs ...error) error {
return &joinedError{
errs: errs,
}
Expand Down
11 changes: 9 additions & 2 deletions errors_go1_20.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@ func (je joinedError) Unwrap() []error {
//
// "token is unverifiable: no keyfunc was provided"
func newError(message string, err error, more ...error) error {
format := "%w: %s"
args := []any{err, message}
var format string
var args []any
if message != "" {
format = "%w: %s"
args = []any{err, message}
} else {
format = "%w"
args = []any{err}
}

for _, e := range more {
format += ": %w"
Expand Down
11 changes: 9 additions & 2 deletions errors_go_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,15 @@ func newError(message string, err error, more ...error) error {
// We cannot wrap multiple errors here with %w, so we have to be a little
// bit creative. Basically, we are using %s instead of %w to produce the
// same error message and then throw the result into a custom error struct.
format := "%s: %s"
args := []any{err, message}
var format string
var args []any
if message != "" {
format = "%s: %s"
args = []any{err, message}
} else {
format = "%s"
args = []any{err}
}
errs := []error{err}

for _, e := range more {
Expand Down
14 changes: 13 additions & 1 deletion errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func Test_joinErrors(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := joinErrors(tt.args.errs)
err := joinErrors(tt.args.errs...)
for _, wantErr := range tt.wantErrors {
if !errors.Is(err, wantErr) {
t.Errorf("joinErrors() error = %v, does not contain %v", err, wantErr)
Expand Down Expand Up @@ -65,6 +65,18 @@ func Test_newError(t *testing.T) {
wantMessage: "token is malformed: something is wrong: unexpected EOF",
wantErrors: []error{ErrTokenMalformed},
},
{
name: "two errors, no detail",
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{ErrTokenExpired}},
wantMessage: "token has invalid claims: token is expired",
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired},
},
{
name: "two errors, no detail and join error",
args: args{message: "", err: ErrTokenInvalidClaims, more: []error{joinErrors(ErrTokenExpired, ErrTokenNotValidYet)}},
wantMessage: "token has invalid claims: token is expired, token is not valid yet",
wantErrors: []error{ErrTokenInvalidClaims, ErrTokenExpired, ErrTokenNotValidYet},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
// Perform signature validation
token.Signature = parts[2]
if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil {
return token, newError("could not verify", ErrTokenSignatureInvalid, err)
return token, newError("", ErrTokenSignatureInvalid, err)
}

// Validate Claims
Expand All @@ -93,7 +93,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf
}

if err := p.validator.Validate(claims); err != nil {
return token, err
return token, newError("", ErrTokenInvalidClaims, err)
}
}

Expand Down
6 changes: 3 additions & 3 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (v *validator) Validate(claims Claims) error {
// If we have an expected subject, we also require the subject claim
if v.expectedSub != "" {
if err = v.verifySubject(claims, v.expectedSub, true); err != nil {
errs = append(errs, ErrTokenInvalidSubject)
errs = append(errs, err)
}
}

Expand All @@ -137,7 +137,7 @@ func (v *validator) Validate(claims Claims) error {
return nil
}

return joinErrors(errs)
return joinErrors(errs...)
}

// verifyExpiresAt compares the exp claim in claims against cmp. This function
Expand Down Expand Up @@ -276,7 +276,7 @@ func (v *validator) verifySubject(claims Claims, cmp string, required bool) erro
return errorIfRequired(required, "sub")
}

return errorIfFalse(sub == cmp, ErrTokenInvalidIssuer)
return errorIfFalse(sub == cmp, ErrTokenInvalidSubject)
}

// errorIfFalse returns the error specified in err, if the value is true.
Expand Down
Loading

0 comments on commit f759348

Please sign in to comment.