diff --git a/pkg/plugin/oauth2/middleware_access_rules.go b/pkg/plugin/oauth2/middleware_access_rules.go index 9d2076103..81d4e702b 100644 --- a/pkg/plugin/oauth2/middleware_access_rules.go +++ b/pkg/plugin/oauth2/middleware_access_rules.go @@ -30,15 +30,14 @@ func NewRevokeRulesMiddleware(parser *jwt.Parser, accessRules []*AccessRule) fun for _, rule := range accessRules { allowed, err := rule.IsAllowed(claims) if err != nil { - log.WithError(err).Debug("Rule is not allowed") - continue - } - - if allowed { - handler.ServeHTTP(w, r) - } else { - w.WriteHeader(http.StatusUnauthorized) - return + log.WithError(err).Debug("Rule is invalid") + } else if rule.matched { + if allowed { + break + } else { + w.WriteHeader(http.StatusUnauthorized) + return + } } } } diff --git a/pkg/plugin/oauth2/middleware_access_rules_test.go b/pkg/plugin/oauth2/middleware_access_rules_test.go index c4aec098a..c91ddd9fb 100644 --- a/pkg/plugin/oauth2/middleware_access_rules_test.go +++ b/pkg/plugin/oauth2/middleware_access_rules_test.go @@ -15,137 +15,118 @@ import ( const signingAlg = "HS256" -func TestBlockJWTByCountry(t *testing.T) { - secret := "secret" +func generateToken(alg, key string) (string, error) { + token := basejwt.NewWithClaims(basejwt.GetSigningMethod(alg), basejwt.MapClaims{ + "country": "de", + "username": "test@hellofresh.com", + "iat": time.Now().Unix(), + }) - revokeRules := []*AccessRule{ - {Predicate: "country == 'de'", Action: "deny"}, - } + return token.SignedString([]byte(key)) +} + +func expectRulesToProduceStatus(t *testing.T, statusCode int, rules []*AccessRule) { + secret := "secret" parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret})) - mw := NewRevokeRulesMiddleware(parser, revokeRules) + mw := NewRevokeRulesMiddleware(parser, rules) token, err := generateToken(signingAlg, secret) require.NoError(t, err) - w, err := test.Record( - "GET", - "/", - map[string]string{ - "Content-Type": "application/json", - "Authorization": fmt.Sprintf("Bearer %s", token), - }, - mw(http.HandlerFunc(test.Ping)), - ) - assert.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, w.Code) + for i := 1; i <= 3; i++ { // middleware caches predicate and should return the same response every time + hits := 0 + w, err := test.Record( + "GET", + "/", + map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", token), + }, + mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits++ + test.Ping(w, r) + })), + ) + + assert.NoError(t, err, "%d. pass", i) + assert.Equal(t, statusCode, w.Code, "%d. pass", i) + if statusCode == http.StatusOK { + assert.Equal(t, 1, hits, "%d. pass", i) + } else { + assert.Equal(t, 0, hits, "%d. pass", i) + } + } } -func TestBlockJWTByUsername(t *testing.T) { - secret := "secret" +func TestBlockJWTByCountry(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{ + {Predicate: "country == 'de'", Action: "deny"}, + }) +} - revokeRules := []*AccessRule{ +func TestBlockJWTByUsername(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{ {Predicate: "username == 'test@hellofresh.com'", Action: "deny"}, - } - - parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret})) - - mw := NewRevokeRulesMiddleware(parser, revokeRules) - token, err := generateToken(signingAlg, secret) - require.NoError(t, err) - - w, err := test.Record( - "GET", - "/", - map[string]string{ - "Content-Type": "application/json", - "Authorization": fmt.Sprintf("Bearer %s", token), - }, - mw(http.HandlerFunc(test.Ping)), - ) - assert.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, w.Code) + }) } func TestBlockJWTByIssueDate(t *testing.T) { - secret := "secret" - - revokeRules := []*AccessRule{ + expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{ {Predicate: fmt.Sprintf("iat < %d", time.Now().Add(1*time.Hour).Unix()), Action: "deny"}, - } - - parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret})) - - mw := NewRevokeRulesMiddleware(parser, revokeRules) - token, err := generateToken(signingAlg, secret) - require.NoError(t, err) - - w, err := test.Record( - "GET", - "/", - map[string]string{ - "Content-Type": "application/json", - "Authorization": fmt.Sprintf("Bearer %s", token), - }, - mw(http.HandlerFunc(test.Ping)), - ) - assert.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, w.Code) + }) } func TestBlockJWTByCountryAndIssueDate(t *testing.T) { - secret := "secret" - - revokeRules := []*AccessRule{ + expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{ {Predicate: fmt.Sprintf("country == 'de' && iat < %d", time.Now().Add(1*time.Hour).Unix()), Action: "deny"}, - } - - parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret})) - - mw := NewRevokeRulesMiddleware(parser, revokeRules) - token, err := generateToken(signingAlg, secret) - require.NoError(t, err) + }) +} - w, err := test.Record( - "GET", - "/", - map[string]string{ - "Content-Type": "application/json", - "Authorization": fmt.Sprintf("Bearer %s", token), - }, - mw(http.HandlerFunc(test.Ping)), - ) - assert.NoError(t, err) - assert.Equal(t, http.StatusUnauthorized, w.Code) +func TestEmptyAccessRules(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{}) } -func generateToken(alg, key string) (string, error) { - token := basejwt.NewWithClaims(basejwt.GetSigningMethod(alg), basejwt.MapClaims{ - "country": "de", - "username": "test@hellofresh.com", - "iat": time.Now().Unix(), +func TestWrongRule(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{ + {Predicate: "country == 'wrong'", Action: "deny"}, }) - - return token.SignedString([]byte(key)) } -func TestEmptyAccessRules(t *testing.T) { - secret := "secret" - - revokeRules := []*AccessRule{} +func TestMultipleRulesSecondMatchesAndDenies(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{ + {Predicate: "country == 'us'", Action: "deny"}, + {Predicate: "country == 'de'", Action: "deny"}, + }) +} - parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret})) +func TestMultipleRulesSecondMatchesAndAllows(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{ + {Predicate: "country == 'us'", Action: "allow"}, + {Predicate: "country == 'de'", Action: "allow"}, + {Predicate: "true", Action: "deny"}, + }) +} - mw := NewRevokeRulesMiddleware(parser, revokeRules) +func TestMultipleRulesLastMatchesAndDenies(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusUnauthorized, []*AccessRule{ + {Predicate: "country == 'us'", Action: "allow"}, + {Predicate: "country == 'gb'", Action: "allow"}, + {Predicate: "true", Action: "deny"}, + }) +} - w, err := test.Record( - "GET", - "/", - nil, - mw(http.HandlerFunc(test.Ping)), - ) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, w.Code) +func TestMultipleRulesNoneMatch(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{ + {Predicate: "country == 'us'", Action: "deny"}, + {Predicate: "country == 'gb'", Action: "deny"}, + }) +} +func TestMultipleRulesMatchAndAllow(t *testing.T) { + expectRulesToProduceStatus(t, http.StatusOK, []*AccessRule{ + {Predicate: "country == 'de'", Action: "allow"}, + {Predicate: "true", Action: "allow"}, + }) } func TestWrongJWT(t *testing.T) { @@ -153,36 +134,10 @@ func TestWrongJWT(t *testing.T) { {Predicate: fmt.Sprintf("country == 'de' && iat < %d", time.Now().Add(1*time.Hour).Unix()), Action: "deny"}, } - parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: "wrong secret"})) + parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: "secret"})) mw := NewRevokeRulesMiddleware(parser, revokeRules) - token, err := generateToken(signingAlg, "secret") - require.NoError(t, err) - - w, err := test.Record( - "GET", - "/", - map[string]string{ - "Content-Type": "application/json", - "Authorization": fmt.Sprintf("Bearer %s", token), - }, - mw(http.HandlerFunc(test.Ping)), - ) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, w.Code) -} - -func TestWrongRule(t *testing.T) { - secret := "secret" - - revokeRules := []*AccessRule{ - {Predicate: "country == 'wrong'", Action: "deny"}, - } - - parser := jwt.NewParser(jwt.NewParserConfig(0, jwt.SigningMethod{Alg: signingAlg, Key: secret})) - - mw := NewRevokeRulesMiddleware(parser, revokeRules) - token, err := generateToken(signingAlg, secret) + token, err := generateToken(signingAlg, "wrong secret") require.NoError(t, err) w, err := test.Record( diff --git a/pkg/plugin/oauth2/oauth.go b/pkg/plugin/oauth2/oauth.go index 6a10bb216..c071efa79 100644 --- a/pkg/plugin/oauth2/oauth.go +++ b/pkg/plugin/oauth2/oauth.go @@ -139,6 +139,7 @@ type AccessRule struct { Predicate string `bson:"predicate" json:"predicate"` Action string `bson:"action" json:"action"` parsed bool + matched bool } // IsAllowed checks if the rule is allowed to @@ -146,14 +147,14 @@ func (r *AccessRule) IsAllowed(claims map[string]interface{}) (bool, error) { var err error if !r.parsed { - matched, err := r.parse(claims) + r.matched, err = r.parse(claims) if err != nil { return false, err } + } - if !matched { - return true, nil - } + if !r.matched { + return true, nil } return r.Action == "allow", err