From 84a369b282223a5486d8938dfd58fb21d3f6a80e Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Mon, 20 Feb 2023 22:22:19 +0100 Subject: [PATCH] Verify functions now return errors instead of bool --- map_claims.go | 7 +- none.go | 6 +- parser_test.go | 1 - validator.go | 202 ++++++++++++++++++++++++++++++------------------- 4 files changed, 132 insertions(+), 84 deletions(-) diff --git a/map_claims.go b/map_claims.go index 014acb94..26ac3fbe 100644 --- a/map_claims.go +++ b/map_claims.go @@ -2,6 +2,7 @@ package jwt import ( "encoding/json" + "fmt" ) // MapClaims is a claims type that uses the map[string]interface{} for JSON decoding. @@ -60,7 +61,7 @@ func (m MapClaims) parseNumericDate(key string) (*NumericDate, error) { return newNumericDateFromSeconds(v), nil } - return nil, ErrInvalidType + return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } // parseClaimsString tries to parse a key in the map claims type as a @@ -76,7 +77,7 @@ func (m MapClaims) parseClaimsString(key string) (ClaimStrings, error) { for _, a := range v { vs, ok := a.(string) if !ok { - return nil, ErrInvalidType + return nil, newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } cs = append(cs, vs) } @@ -101,7 +102,7 @@ func (m MapClaims) parseString(key string) (string, error) { iss, ok = raw.(string) if !ok { - return "", ErrInvalidType + return "", newError(fmt.Sprintf("%s is invalid", key), ErrInvalidType) } return iss, nil diff --git a/none.go b/none.go index 75d2d7e3..a16495ac 100644 --- a/none.go +++ b/none.go @@ -1,7 +1,5 @@ package jwt -import "fmt" - // SigningMethodNone implements the none signing method. This is required by the spec // but you probably should never use it. var SigningMethodNone *signingMethodNone @@ -15,7 +13,7 @@ type unsafeNoneMagicConstant string func init() { SigningMethodNone = &signingMethodNone{} - NoneSignatureTypeDisallowedError = fmt.Errorf("%w: 'none' signature type is not allowed", ErrTokenUnverifiable) + NoneSignatureTypeDisallowedError = newError("'none' signature type is not allowed", ErrTokenUnverifiable) RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod { return SigningMethodNone @@ -35,7 +33,7 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac } // If signing method is none, signature must be an empty string if signature != "" { - return fmt.Errorf("%w: 'none' signing method with non-empty signature", ErrTokenUnverifiable) + return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable) } // Accept 'none' signing method. diff --git a/parser_test.go b/parser_test.go index 78b924f0..90c271fc 100644 --- a/parser_test.go +++ b/parser_test.go @@ -356,7 +356,6 @@ func TestParser_Parse(t *testing.T) { // Parse the token var token *jwt.Token - //var ve *jwt.ValidationError var err error var parser = data.parser if parser == nil { diff --git a/validator.go b/validator.go index 8aca744d..912050ad 100644 --- a/validator.go +++ b/validator.go @@ -2,9 +2,32 @@ package jwt import ( "crypto/subtle" + "fmt" "time" ) +// ClaimsValidator is an interface that can be implemented by custom claims who +// wish to execute any additional claims validation based on +// application-specific logic. The Validate function is then executed in +// addition to the regular claims validation and any error returned is appended +// to the final validation result. +// +// type MyCustomClaims struct { +// Foo string `json:"foo"` +// jwt.RegisteredClaims +// } +// +// func (m MyCustomClaims) Validate() error { +// if m.Foo != "bar" { +// return errors.New("must be foobar") +// } +// return nil +// } +type ClaimsValidator interface { + Claims + Validate() error +} + // validator is the core of the new Validation API. It is automatically used by // a [Parser] during parsing and can be modified with various parser options. // @@ -46,11 +69,12 @@ func newValidator(opts ...ParserOption) *validator { } // Validate validates the given claims. It will also perform any custom -// validation if claims implements the CustomValidator interface. +// validation if claims implements the [ClaimsValidator] interface. func (v *validator) Validate(claims Claims) error { var ( now time.Time - errs []error = make([]error, 0) + errs []error = make([]error, 0, 6) + err error ) // Check, if we have a time func @@ -61,42 +85,48 @@ func (v *validator) Validate(claims Claims) error { } // We always need to check the expiration time, but usage of the claim - // itself is OPTIONAL - if !v.VerifyExpiresAt(claims, now, false) { - errs = append(errs, ErrTokenExpired) + // itself is OPTIONAL. + if err = v.verifyExpiresAt(claims, now, false); err != nil { + errs = append(errs, err) } // We always need to check not-before, but usage of the claim itself is - // OPTIONAL - if !v.VerifyNotBefore(claims, now, false) { - errs = append(errs, ErrTokenNotValidYet) + // OPTIONAL. + if err = v.verifyNotBefore(claims, now, false); err != nil { + errs = append(errs, err) } // Check issued-at if the option is enabled - if v.verifyIat && !v.VerifyIssuedAt(claims, now, false) { - errs = append(errs, ErrTokenUsedBeforeIssued) + if v.verifyIat { + if err = v.verifyIssuedAt(claims, now, false); err != nil { + errs = append(errs, err) + } } // If we have an expected audience, we also require the audience claim - if v.expectedAud != "" && !v.VerifyAudience(claims, v.expectedAud, true) { - errs = append(errs, ErrTokenInvalidAudience) + if v.expectedAud != "" { + if err = v.verifyAudience(claims, v.expectedAud, true); err != nil { + errs = append(errs, err) + } } // If we have an expected issuer, we also require the issuer claim - if v.expectedIss != "" && !v.VerifyIssuer(claims, v.expectedIss, true) { - errs = append(errs, ErrTokenInvalidIssuer) + if v.expectedIss != "" { + if err = v.verifyIssuer(claims, v.expectedIss, true); err != nil { + errs = append(errs, err) + } } // If we have an expected subject, we also require the subject claim - if v.expectedSub != "" && !v.VerifySubject(claims, v.expectedSub, true) { - errs = append(errs, ErrTokenInvalidSubject) + if v.expectedSub != "" { + if err = v.verifySubject(claims, v.expectedSub, true); err != nil { + errs = append(errs, ErrTokenInvalidSubject) + } } // Finally, we want to give the claim itself some possibility to do some // additional custom validation based on a custom Validate function. - cvt, ok := claims.(interface { - Validate() error - }) + cvt, ok := claims.(ClaimsValidator) if ok { if err := cvt.Validate(); err != nil { errs = append(errs, err) @@ -110,84 +140,84 @@ func (v *validator) Validate(claims Claims) error { return joinErrors(errs) } -// VerifyExpiresAt compares the exp claim in claims against cmp. This function -// will return true if cmp < exp. Additional leeway is taken into account. +// verifyExpiresAt compares the exp claim in claims against cmp. This function +// will succeed if cmp < exp. Additional leeway is taken into account. // -// If exp is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If exp is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyExpiresAt(claims Claims, cmp time.Time, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyExpiresAt(claims Claims, cmp time.Time, required bool) error { exp, err := claims.GetExpirationTime() if err != nil { - return false + return err } - if exp != nil { - return cmp.Before((exp.Time).Add(+v.leeway)) - } else { - return !required + if exp == nil { + return errorIfRequired(required, "exp") } + + return errorIfFalse(cmp.Before((exp.Time).Add(+v.leeway)), ErrTokenExpired) } -// VerifyIssuedAt compares the iat claim in claims against cmp. This function -// will return true if cmp >= iat. Additional leeway is taken into account. +// verifyIssuedAt compares the iat claim in claims against cmp. This function +// will succeed if cmp >= iat. Additional leeway is taken into account. // -// If iat is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If iat is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyIssuedAt(claims Claims, cmp time.Time, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyIssuedAt(claims Claims, cmp time.Time, required bool) error { iat, err := claims.GetIssuedAt() if err != nil { - return false + return err } - if iat != nil { - return !cmp.Before(iat.Add(-v.leeway)) - } else { - return !required + if iat == nil { + return errorIfRequired(required, "iat") } + + return errorIfFalse(!cmp.Before(iat.Add(-v.leeway)), ErrTokenUsedBeforeIssued) } -// VerifyNotBefore compares the nbf claim in claims against cmp. This function +// verifyNotBefore compares the nbf claim in claims against cmp. This function // will return true if cmp >= nbf. Additional leeway is taken into account. // -// If nbf is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If nbf is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyNotBefore(claims Claims, cmp time.Time, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyNotBefore(claims Claims, cmp time.Time, required bool) error { nbf, err := claims.GetNotBefore() if err != nil { - return false + return err } - if nbf != nil { - return !cmp.Before(nbf.Add(-v.leeway)) - } else { - return !required + if nbf == nil { + return errorIfRequired(required, "nbf") } + + return errorIfFalse(!cmp.Before(nbf.Add(-v.leeway)), ErrTokenNotValidYet) } -// VerifyAudience compares the aud claim against cmp. +// verifyAudience compares the aud claim against cmp. // -// If aud is not set or an empty list, it will return true if the claim is not -// required, otherwise false will be returned. +// If aud is not set or an empty list, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyAudience(claims Claims, cmp string, required bool) error { aud, err := claims.GetAudience() if err != nil { - return false + return err } if len(aud) == 0 { - return !required + return errorIfRequired(required, "aud") } // use a var here to keep constant time compare when looping over a number of claims @@ -203,48 +233,68 @@ func (v *validator) VerifyAudience(claims Claims, cmp string, required bool) boo // case where "" is sent in one or many aud claims if stringClaims == "" { - return !required + return errorIfRequired(required, "aud") } - return result + return errorIfFalse(result, ErrTokenInvalidAudience) } -// VerifyIssuer compares the iss claim in claims against cmp. +// verifyIssuer compares the iss claim in claims against cmp. // -// If iss is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If iss is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifyIssuer(claims Claims, cmp string, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifyIssuer(claims Claims, cmp string, required bool) error { iss, err := claims.GetIssuer() if err != nil { - return false + return err } if iss == "" { - return !required + return errorIfRequired(required, "iss") } - return iss == cmp + return errorIfFalse(iss == cmp, ErrTokenInvalidIssuer) } -// VerifySubject compares the sub claim against cmp. +// verifySubject compares the sub claim against cmp. // -// If sub is not set, it will return true if the claim is not required, -// otherwise false will be returned. +// If sub is not set, it will succeed if the claim is not required, +// otherwise ErrTokenRequiredClaimMissing will be returned. // // Additionally, if any error occurs while retrieving the claim, e.g., when its -// the wrong type, false will be returned. -func (v *validator) VerifySubject(claims Claims, cmp string, required bool) bool { +// the wrong type, an ErrTokenUnverifiable error will be returned. +func (v *validator) verifySubject(claims Claims, cmp string, required bool) error { sub, err := claims.GetSubject() if err != nil { - return false + return err } if sub == "" { - return !required + return errorIfRequired(required, "sub") } - return sub == cmp + return errorIfFalse(sub == cmp, ErrTokenInvalidIssuer) +} + +// errorIfFalse returns the error specified in err, if the value is true. +// Otherwise, nil is returned. +func errorIfFalse(value bool, err error) error { + if value { + return nil + } else { + return err + } +} + +// errorIfRequired returns an ErrTokenRequiredClaimMissing error if required is +// true. Otherwise, nil is returned. +func errorIfRequired(required bool, claim string) error { + if required { + return newError(fmt.Sprintf("%s claim is required", claim), ErrTokenRequiredClaimMissing) + } else { + return nil + } }