Skip to content

Commit

Permalink
Verify functions now return errors instead of bool
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Feb 20, 2023
1 parent e237304 commit 84a369b
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 84 deletions.
7 changes: 4 additions & 3 deletions map_claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt

import (
"encoding/json"
"fmt"
)

// MapClaims is a claims type that uses the map[string]interface{} for JSON decoding.
Expand Down Expand Up @@ -60,7 +61,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) {
return newNumericDateFromSeconds(v), nil
}

return nil, ErrInvalidType
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
}

// parseClaimsString tries to parse a key in the map claims type as a
Expand All @@ -76,7 +77,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) {
for _, a := range v {
vs, ok := a.(string)
if !ok {
return nil, ErrInvalidType
return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
}
cs = append(cs, vs)
}
Expand All @@ -101,7 +102,7 @@ func (m MapClaims) parseString(key string) (string, error) {

iss, ok = raw.(string)
if !ok {
return "", ErrInvalidType
return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType)
}

return iss, nil
Expand Down
6 changes: 2 additions & 4 deletions none.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package jwt

import "fmt"

// SigningMethodNone implements the none signing method. This is required by the spec
// but you probably should never use it.
var SigningMethodNone *signingMethodNone
Expand All @@ -15,7 +13,7 @@ type unsafeNoneMagicConstant string

func init() {
SigningMethodNone = &signingMethodNone{}
NoneSignatureTypeDisallowedError = fmt.Errorf("%w: 'none' signature type is not allowed", ErrTokenUnverifiable)
NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable)

RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod {
return SigningMethodNone
Expand All @@ -35,7 +33,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac
}
// If signing method is none, signature must be an empty string
if signature != "" {
return fmt.Errorf("%w: 'none' signing method with non-empty signature", ErrTokenUnverifiable)
return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
}

// Accept 'none' signing method.
Expand Down
1 change: 0 additions & 1 deletion parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ func TestParser_Parse(t *testing.T) {

// Parse the token
var token *jwt.Token
//var ve *jwt.ValidationError
var err error
var parser = data.parser
if parser == nil {
Expand Down
202 changes: 126 additions & 76 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,32 @@ package jwt

import (
"crypto/subtle"
"fmt"
"time"
)

// ClaimsValidator is an interface that can be implemented by custom claims who
// wish to execute any additional claims validation based on
// application-specific logic. The Validate function is then executed in
// addition to the regular claims validation and any error returned is appended
// to the final validation result.
//
// type MyCustomClaims struct {
// Foo string `json:"foo"`
// jwt.RegisteredClaims
// }
//
// func (m MyCustomClaims) Validate() error {
// if m.Foo != "bar" {
// return errors.New("must be foobar")
// }
// return nil
// }
type ClaimsValidator interface {
Claims
Validate() error
}

// validator is the core of the new Validation API. It is automatically used by
// a [Parser] during parsing and can be modified with various parser options.
//
Expand Down Expand Up @@ -46,11 +69,12 @@ func newValidator(opts ...ParserOption) *validator {
}

// Validate validates the given claims. It will also perform any custom
// validation if claims implements the CustomValidator interface.
// validation if claims implements the [ClaimsValidator] interface.
func (v *validator) Validate(claims Claims) error {
var (
now time.Time
errs []error = make([]error, 0)
errs []error = make([]error, 0, 6)
err error
)

// Check, if we have a time func
Expand All @@ -61,42 +85,48 @@ func (v *validator) Validate(claims Claims) error {
}

// We always need to check the expiration time, but usage of the claim
// itself is OPTIONAL
if !v.VerifyExpiresAt(claims, now, false) {
errs = append(errs, ErrTokenExpired)
// itself is OPTIONAL.
if err = v.verifyExpiresAt(claims, now, false); err != nil {
errs = append(errs, err)
}

// We always need to check not-before, but usage of the claim itself is
// OPTIONAL
if !v.VerifyNotBefore(claims, now, false) {
errs = append(errs, ErrTokenNotValidYet)
// OPTIONAL.
if err = v.verifyNotBefore(claims, now, false); err != nil {
errs = append(errs, err)
}

// Check issued-at if the option is enabled
if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) {
errs = append(errs, ErrTokenUsedBeforeIssued)
if v.verifyIat {
if err = v.verifyIssuedAt(claims, now, false); err != nil {
errs = append(errs, err)
}
}

// If we have an expected audience, we also require the audience claim
if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) {
errs = append(errs, ErrTokenInvalidAudience)
if v.expectedAud != "" {
if err = v.verifyAudience(claims, v.expectedAud, true); err != nil {
errs = append(errs, err)
}
}

// If we have an expected issuer, we also require the issuer claim
if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) {
errs = append(errs, ErrTokenInvalidIssuer)
if v.expectedIss != "" {
if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil {
errs = append(errs, err)
}
}

// If we have an expected subject, we also require the subject claim
if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) {
errs = append(errs, ErrTokenInvalidSubject)
if v.expectedSub != "" {
if err = v.verifySubject(claims, v.expectedSub, true); err != nil {
errs = append(errs, ErrTokenInvalidSubject)
}
}

// Finally, we want to give the claim itself some possibility to do some
// additional custom validation based on a custom Validate function.
cvt, ok := claims.(interface {
Validate() error
})
cvt, ok := claims.(ClaimsValidator)
if ok {
if err := cvt.Validate(); err != nil {
errs = append(errs, err)
Expand All @@ -110,84 +140,84 @@ func (v *validator) Validate(claims Claims) error {
return joinErrors(errs)
}

// VerifyExpiresAt compares the exp claim in claims against cmp. This function
// will return true if cmp < exp. Additional leeway is taken into account.
// verifyExpiresAt compares the exp claim in claims against cmp. This function
// will succeed if cmp < exp. Additional leeway is taken into account.
//
// If exp is not set, it will return true if the claim is not required,
// otherwise false will be returned.
// If exp is not set, it will succeed if the claim is not required,
// otherwise ErrTokenRequiredClaimMissing will be returned.
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned.
func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, required bool) bool {
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error {
exp, err := claims.GetExpirationTime()
if err != nil {
return false
return err
}

if exp != nil {
return cmp.Before((exp.Time).Add(+v.leeway))
} else {
return !required
if exp == nil {
return errorIfRequired(required, "exp")
}

return errorIfFalse(cmp.Before((exp.Time).Add(+v.leeway)), ErrTokenExpired)
}

// VerifyIssuedAt compares the iat claim in claims against cmp. This function
// will return true if cmp >= iat. Additional leeway is taken into account.
// verifyIssuedAt compares the iat claim in claims against cmp. This function
// will succeed if cmp >= iat. Additional leeway is taken into account.
//
// If iat is not set, it will return true if the claim is not required,
// otherwise false will be returned.
// If iat is not set, it will succeed if the claim is not required,
// otherwise ErrTokenRequiredClaimMissing will be returned.
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned.
func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, required bool) bool {
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error {
iat, err := claims.GetIssuedAt()
if err != nil {
return false
return err
}

if iat != nil {
return !cmp.Before(iat.Add(-v.leeway))
} else {
return !required
if iat == nil {
return errorIfRequired(required, "iat")
}

return errorIfFalse(!cmp.Before(iat.Add(-v.leeway)), ErrTokenUsedBeforeIssued)
}

// VerifyNotBefore compares the nbf claim in claims against cmp. This function
// verifyNotBefore compares the nbf claim in claims against cmp. This function
// will return true if cmp >= nbf. Additional leeway is taken into account.
//
// If nbf is not set, it will return true if the claim is not required,
// otherwise false will be returned.
// If nbf is not set, it will succeed if the claim is not required,
// otherwise ErrTokenRequiredClaimMissing will be returned.
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned.
func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, required bool) bool {
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error {
nbf, err := claims.GetNotBefore()
if err != nil {
return false
return err
}

if nbf != nil {
return !cmp.Before(nbf.Add(-v.leeway))
} else {
return !required
if nbf == nil {
return errorIfRequired(required, "nbf")
}

return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet)
}

// VerifyAudience compares the aud claim against cmp.
// verifyAudience compares the aud claim against cmp.
//
// If aud is not set or an empty list, it will return true if the claim is not
// required, otherwise false will be returned.
// If aud is not set or an empty list, it will succeed if the claim is not required,
// otherwise ErrTokenRequiredClaimMissing will be returned.
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned.
func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) bool {
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error {
aud, err := claims.GetAudience()
if err != nil {
return false
return err
}

if len(aud) == 0 {
return !required
return errorIfRequired(required, "aud")
}

// use a var here to keep constant time compare when looping over a number of claims
Expand All @@ -203,48 +233,68 @@ func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) boo

// case where "" is sent in one or many aud claims
if stringClaims == "" {
return !required
return errorIfRequired(required, "aud")
}

return result
return errorIfFalse(result, ErrTokenInvalidAudience)
}

// VerifyIssuer compares the iss claim in claims against cmp.
// verifyIssuer compares the iss claim in claims against cmp.
//
// If iss is not set, it will return true if the claim is not required,
// otherwise false will be returned.
// If iss is not set, it will succeed if the claim is not required,
// otherwise ErrTokenRequiredClaimMissing will be returned.
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned.
func (v *validator) VerifyIssuer(claims Claims, cmp string, required bool) bool {
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error {
iss, err := claims.GetIssuer()
if err != nil {
return false
return err
}

if iss == "" {
return !required
return errorIfRequired(required, "iss")
}

return iss == cmp
return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer)
}

// VerifySubject compares the sub claim against cmp.
// verifySubject compares the sub claim against cmp.
//
// If sub is not set, it will return true if the claim is not required,
// otherwise false will be returned.
// If sub is not set, it will succeed if the claim is not required,
// otherwise ErrTokenRequiredClaimMissing will be returned.
//
// Additionally, if any error occurs while retrieving the claim, e.g., when its
// the wrong type, false will be returned.
func (v *validator) VerifySubject(claims Claims, cmp string, required bool) bool {
// the wrong type, an ErrTokenUnverifiable error will be returned.
func (v *validator) verifySubject(claims Claims, cmp string, required bool) error {
sub, err := claims.GetSubject()
if err != nil {
return false
return err
}

if sub == "" {
return !required
return errorIfRequired(required, "sub")
}

return sub == cmp
return errorIfFalse(sub == cmp, ErrTokenInvalidIssuer)
}

// errorIfFalse returns the error specified in err, if the value is true.
// Otherwise, nil is returned.
func errorIfFalse(value bool, err error) error {
if value {
return nil
} else {
return err
}
}

// errorIfRequired returns an ErrTokenRequiredClaimMissing error if required is
// true. Otherwise, nil is returned.
func errorIfRequired(required bool, claim string) error {
if required {
return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing)
} else {
return nil
}
}

0 comments on commit 84a369b

Please sign in to comment.