diff --git a/driver/configuration/config_keys.go b/driver/configuration/config_keys.go
index 6fac519d5c..3a1b532c0c 100644
--- a/driver/configuration/config_keys.go
+++ b/driver/configuration/config_keys.go
@@ -24,6 +24,7 @@ const (
 	PrometheusServeCollapseRequestPaths Key = "serve.prometheus.collapse_request_paths"
 	AccessRuleRepositories              Key = "access_rules.repositories"
 	AccessRuleMatchingStrategy          Key = "access_rules.matching_strategy"
+	AcccessRulePrefixMatchingEnabled    Key = "access_rules.prefix_matching_enabled"
 )
 
 // Authorizers
diff --git a/driver/configuration/provider.go b/driver/configuration/provider.go
index 2b66517b97..68f221bfa5 100644
--- a/driver/configuration/provider.go
+++ b/driver/configuration/provider.go
@@ -58,6 +58,7 @@ type Provider interface {
 
 	AccessRuleRepositories() []url.URL
 	AccessRuleMatchingStrategy() MatchingStrategy
+	AcccessRulePrefixMatchingEnabled() bool
 
 	ProxyServeAddress() string
 	APIServeAddress() string
diff --git a/driver/configuration/provider_koanf.go b/driver/configuration/provider_koanf.go
index 2589b3a1ca..23674f267d 100644
--- a/driver/configuration/provider_koanf.go
+++ b/driver/configuration/provider_koanf.go
@@ -172,6 +172,11 @@ func (v *KoanfProvider) AccessRuleMatchingStrategy() MatchingStrategy {
 	return MatchingStrategy(v.source.String(AccessRuleMatchingStrategy))
 }
 
+// AcccessRulePrefixMatchingEnabled reports whether access rules are matched by path prefix.
+func (v *KoanfProvider) AcccessRulePrefixMatchingEnabled() bool {
+	return v.source.Bool(AcccessRulePrefixMatchingEnabled)
+}
+
 func (v *KoanfProvider) CORSEnabled(iface string) bool {
 	_, enabled := v.CORS(iface)
 	return enabled
diff --git a/internal/config/.oathkeeper.yaml b/internal/config/.oathkeeper.yaml
index 39a02cb7fc..88426cb6a7 100644
--- a/internal/config/.oathkeeper.yaml
+++ b/internal/config/.oathkeeper.yaml
@@ -103,6 +103,8 @@ access_rules:
     - https://path-to-my-rules/rules.json
   # Optional fields describing matching strategy, defaults to "regexp".
   matching_strategy: glob
+  # Optional field describing whether rules are matched by path prefix, defaults to false.
+  prefix_matching_enabled: false
 
 errors:
   fallback:
diff --git a/rule/matcher.go b/rule/matcher.go
index 69d7b2e75e..5eca6619d6 100644
--- a/rule/matcher.go
+++ b/rule/matcher.go
@@ -9,7 +9,7 @@ import (
 )
 
 type (
-	Protocol int
+	Protocol string
 
 	Matcher interface {
 		Match(ctx context.Context, method string, u *url.URL, protocol Protocol) (*Rule, error)
@@ -17,6 +17,6 @@ type (
 )
 
 const (
-	ProtocolHTTP Protocol = iota
-	ProtocolGRPC
+	ProtocolHTTP Protocol = "http"
+	ProtocolGRPC Protocol = "grpc"
 )
diff --git a/rule/matcher_test.go b/rule/matcher_test.go
index 28ee2d051e..2a1561dd8d 100644
--- a/rule/matcher_test.go
+++ b/rule/matcher_test.go
@@ -99,6 +99,72 @@ var testRulesGlob = []Rule{
 	},
 }
 
+var testRulesPrefix = []Rule{
+	{
+		ID:             "foo1",
+		Match:          &Match{URL: "https://localhost:1234/", Methods: []string{"POST"}},
+		Description:    "Create users rule",
+		Authorizer:     Handler{Handler: "allow", Config: []byte(`{"type":"any"}`)},
+		Authenticators: []Handler{{Handler: "anonymous", Config: []byte(`{"name":"anonymous1"}`)}},
+		Mutators:       []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
+		Upstream:       Upstream{URL: "http://localhost:1235/", StripPath: "/bar", PreserveHost: true},
+	},
+	{
+		ID:             "foo2",
+		Match:          &Match{URL: "https://localhost:1234/foo/", Methods: []string{"POST"}},
+		Description:    "Create users rule",
+		Authorizer:     Handler{Handler: "allow", Config: []byte(`{"type":"any"}`)},
+		Authenticators: []Handler{{Handler: "anonymous", Config: []byte(`{"name":"anonymous1"}`)}},
+		Mutators:       []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
+		Upstream:       Upstream{URL: "http://localhost:1235/", StripPath: "/bar", PreserveHost: true},
+	},
+	{
+		ID:             "foo3",
+		Match:          &Match{URL: "https://localhost:1234/foo/something/", Methods: []string{"POST"}},
+		Description:    "Create users rule",
+		Authorizer:     Handler{Handler: "allow", Config: []byte(`{"type":"any"}`)},
+		Authenticators: []Handler{{Handler: "anonymous", Config: []byte(`{"name":"anonymous1"}`)}},
+		Mutators:       []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
+		Upstream:       Upstream{URL: "http://localhost:1235/", StripPath: "/bar", PreserveHost: true},
+	},
+	{
+		ID:             "foo4",
+		Match:          &Match{URL: "https://localhost:34/", Methods: []string{"GET"}},
+		Description:    "Get users rule",
+		Authorizer:     Handler{Handler: "deny", Config: []byte(`{"type":"any"}`)},
+		Authenticators: []Handler{{Handler: "oauth2_introspection", Config: []byte(`{"name":"anonymous1"}`)}},
+		Mutators:       []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
+		Upstream:       Upstream{URL: "http://localhost:333/", StripPath: "/foo", PreserveHost: false},
+	},
+	{
+		ID:             "foo5",
+		Match:          &Match{URL: "https://localhost:343/", Methods: []string{"GET"}},
+		Description:    "Get users rule",
+		Authorizer:     Handler{Handler: "deny"},
+		Authenticators: []Handler{{Handler: "oauth2_introspection"}},
+		Mutators:       []Handler{{Handler: "id_token"}},
+		Upstream:       Upstream{URL: "http://localhost:3333/", StripPath: "/foo", PreserveHost: false},
+	},
+	{
+		ID:             "grpc1",
+		Match:          &MatchGRPC{Authority: "bar.example.com", FullMethod: "grpc.api/Call"},
+		Description:    "gRPC Rule",
+		Authorizer:     Handler{Handler: "allow", Config: []byte(`{"type":"any"}`)},
+		Authenticators: []Handler{{Handler: "anonymous", Config: []byte(`{"name":"anonymous1"}`)}},
+		Mutators:       []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
+		Upstream:       Upstream{URL: "http://bar.example.com/", PreserveHost: false},
+	},
+	{
+		ID:             "foo6",
+		Match:          &Match{URL: "https://localhost:1234/deep/nested/", Methods: []string{"POST"}},
+		Description:    "Nested rule whose parent path has no rule of its own",
+		Authorizer:     Handler{Handler: "allow", Config: []byte(`{"type":"any"}`)},
+		Authenticators: []Handler{{Handler: "anonymous", Config: []byte(`{"name":"anonymous1"}`)}},
+		Mutators:       []Handler{{Handler: "id_token", Config: []byte(`{"issuer":"anything"}`)}},
+		Upstream:       Upstream{URL: "http://localhost:1235/", StripPath: "/bar", PreserveHost: true},
+	},
+}
+
 func TestMatcher(t *testing.T) {
 	type m interface {
 		Matcher
@@ -192,5 +258,45 @@ func TestMatcher(t *testing.T) {
 				testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
 			})
 		})
+		t.Run(fmt.Sprintf("prefix matcher=%s", name), func(t *testing.T) {
+			require.NoError(t, matcher.SetMatchingStrategy(context.Background(), configuration.Regexp))
+			require.NoError(t, matcher.SetPrefixMatching(context.Background(), true))
+			require.NoError(t, matcher.Set(context.Background(), []Rule{}))
+			t.Run("case=empty", func(t *testing.T) {
+				testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, true, nil)
+				testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
+				testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
+			})
+
+			require.NoError(t, matcher.Set(context.Background(), testRulesPrefix))
+
+			t.Run("case=created", func(t *testing.T) {
+				testMatcher(t, matcher, "POST", "https://localhost:1234/", ProtocolHTTP, false, &testRulesPrefix[0])
+				testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, false, &testRulesPrefix[1])
+				testMatcher(t, matcher, "POST", "https://localhost:1234/foo/something/very/long", ProtocolHTTP, false, &testRulesPrefix[2])
+				testMatcher(t, matcher, "POST", "https://localhost:1234/foo/baz/something/very/long", ProtocolHTTP, false, &testRulesPrefix[1])
+				testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, false, &testRulesPrefix[3])
+				testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
+				testMatcher(t, matcher, "POST", "grpc://bar.example.com/grpc.api/Call", ProtocolGRPC, false, &testRulesPrefix[5])
+				// "/deep" exists only as a step towards foo6, so the request falls back to the rule at "/".
+				testMatcher(t, matcher, "POST", "https://localhost:1234/deep/other", ProtocolHTTP, false, &testRulesPrefix[0])
+			})
+
+			t.Run("case=cache", func(t *testing.T) {
+				r, err := matcher.Match(context.Background(), "GET", mustParseURL(t, "https://localhost:34/baz"), ProtocolHTTP)
+				require.NoError(t, err)
+				_, err = matcher.Get(context.Background(), r.ID)
+				require.NoError(t, err)
+				// assert.NotEmpty(t, got.matchingEngine.Checksum())
+			})
+
+			require.NoError(t, matcher.Set(context.Background(), testRulesPrefix[3:]))
+
+			t.Run("case=updated", func(t *testing.T) {
+				testMatcher(t, matcher, "GET", "https://localhost:34/baz", ProtocolHTTP, false, &testRulesPrefix[3])
+				testMatcher(t, matcher, "POST", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
+				testMatcher(t, matcher, "DELETE", "https://localhost:1234/foo", ProtocolHTTP, true, nil)
+			})
+		})
 	}
 }
diff --git a/rule/repository.go b/rule/repository.go
index de1462399f..9037ac9c50 100644
--- a/rule/repository.go
+++ b/rule/repository.go
@@ -17,5 +17,7 @@ type Repository interface {
 	Count(context.Context) (int, error)
 	MatchingStrategy(context.Context) (configuration.MatchingStrategy, error)
 	SetMatchingStrategy(context.Context, configuration.MatchingStrategy) error
+	PrefixMatching(context.Context) (bool, error)
+	SetPrefixMatching(context.Context, bool) error
 	ReadyChecker(*http.Request) error
 }
diff --git a/rule/repository_memory.go b/rule/repository_memory.go
index 3dc90c1b03..a529a6bba2 100644
--- a/rule/repository_memory.go
+++ b/rule/repository_memory.go
@@ -30,7 +30,9 @@ type RepositoryMemory struct {
 	rules            []Rule
 	invalidRules     []Rule
 	matchingStrategy configuration.MatchingStrategy
+	prefixMatching   bool
 	r                repositoryMemoryRegistry
+	trie             *Trie
 }
 
 // MatchingStrategy returns current MatchingStrategy.
@@ -48,10 +50,43 @@ func (m *RepositoryMemory) SetMatchingStrategy(_ context.Context, ms configurati
 	return nil
 }
 
+// PrefixMatching reports whether prefix matching is enabled.
+func (m *RepositoryMemory) PrefixMatching(_ context.Context) (bool, error) {
+	m.RLock()
+	defer m.RUnlock()
+	return m.prefixMatching, nil
+}
+
+// SetPrefixMatching enables or disables prefix matching. Set only fills the
+// trie while prefix matching is enabled, so rules that were loaded before the
+// switch are indexed here when it is turned on.
+func (m *RepositoryMemory) SetPrefixMatching(_ context.Context, enabled bool) error {
+	m.Lock()
+	defer m.Unlock()
+	if enabled && !m.prefixMatching {
+		m.trie = NewTrie()
+		for k := range m.rules {
+			m.indexRule(m.rules[k])
+		}
+	}
+	m.prefixMatching = enabled
+	return nil
+}
+
+// indexRule adds rule to the prefix trie and logs rules that cannot be
+// indexed. The caller must hold the write lock.
+func (m *RepositoryMemory) indexRule(rule Rule) {
+	if err := m.trie.InsertRule(rule); err != nil {
+		m.r.Logger().WithError(err).WithField("rule_id", rule.ID).
+			Errorf("A Prefix Rule could not be loaded into the trie so all requests will be sent to the closest matching prefix. You should resolve this issue now.")
+	}
+}
+
 func NewRepositoryMemory(r repositoryMemoryRegistry) *RepositoryMemory {
 	return &RepositoryMemory{
 		r:     r,
 		rules: make([]Rule, 0),
+		trie:  NewTrie(),
 	}
 }
 
@@ -59,6 +94,13 @@ func NewRepositoryMemory(r repositoryMemoryRegistry) *RepositoryMemory {
 func (m *RepositoryMemory) WithRules(rules []Rule) {
 	m.Lock()
 	m.rules = rules
+	// Start from an empty trie so rules from an earlier call do not linger.
+	m.trie = NewTrie()
+	for k := range rules {
+		// WithRules cannot report errors; a rule that cannot be indexed is
+		// only unreachable through prefix matching.
+		_ = m.trie.InsertRule(rules[k])
+	}
 	m.Unlock()
 }
 
@@ -97,6 +139,11 @@ func (m *RepositoryMemory) Set(ctx context.Context, rules []Rule) error {
 	m.rules = make([]Rule, 0, len(rules))
 	m.invalidRules = make([]Rule, 0)
 
+	// Reset the trie if we are using prefix matching and the rules have changed.
+	if m.prefixMatching {
+		m.trie = NewTrie()
+	}
+
 	for _, check := range rules {
 		if err := m.r.RuleValidator().Validate(&check); err != nil {
 			m.r.Logger().WithError(err).WithField("rule_id", check.ID).
@@ -104,6 +151,9 @@ func (m *RepositoryMemory) Set(ctx context.Context, rules []Rule) error {
 			m.invalidRules = append(m.invalidRules, check)
 		} else {
 			m.rules = append(m.rules, check)
+			if m.prefixMatching {
+				m.indexRule(check)
+			}
 		}
 	}
 
@@ -119,20 +169,46 @@ func (m *RepositoryMemory) Match(ctx context.Context, method string, u *url.URL,
 	defer m.Unlock()
 
 	var rules []*Rule
-	for k := range m.rules {
-		r := &m.rules[k]
-		if matched, err := r.IsMatching(m.matchingStrategy, method, u, protocol); err != nil {
-			return nil, errors.WithStack(err)
-		} else if matched {
-			rules = append(rules, r)
-		}
-	}
-	for k := range m.invalidRules {
-		r := &m.invalidRules[k]
-		if matched, err := r.IsMatching(m.matchingStrategy, method, u, protocol); err != nil {
-			return nil, errors.WithStack(err)
-		} else if matched {
-			rules = append(rules, r)
-		}
+
+	if m.prefixMatching {
+		if m.trie == nil || m.trie.root == nil {
+			return nil, errors.WithStack(errors.New("prefix trie is nil"))
+		}
+
+		candidates := m.trie.Match(method, u, protocol)
+		for k := range candidates {
+			// Take the address of the slice element, not of a range variable:
+			// before Go 1.22 the latter is the same address on every iteration.
+			r := &candidates[k]
+			// A single candidate is accepted as is; several candidates are
+			// narrowed down with the matching strategy.
+			if len(candidates) > 1 {
+				matched, err := r.IsMatching(m.matchingStrategy, method, u, protocol)
+				if err != nil {
+					return nil, errors.WithStack(err)
+				}
+				if !matched {
+					continue
+				}
+			}
+			rules = append(rules, r)
+		}
+	} else {
+		for k := range m.rules {
+			r := &m.rules[k]
+			if matched, err := r.IsMatching(m.matchingStrategy, method, u, protocol); err != nil {
+				return nil, errors.WithStack(err)
+			} else if matched {
+				rules = append(rules, r)
+			}
+		}
+		for k := range m.invalidRules {
+			r := &m.invalidRules[k]
+			if matched, err := r.IsMatching(m.matchingStrategy, method, u, protocol); err != nil {
+				return nil, errors.WithStack(err)
+			} else if matched {
+				rules = append(rules, r)
+			}
+		}
 	}
 
diff --git a/rule/trie.go b/rule/trie.go
new file mode 100644
index 0000000000..3e1f2fc408
--- /dev/null
+++ b/rule/trie.go
@@ -0,0 +1,167 @@
+// Copyright © 2023 Ory Corp
+// SPDX-License-Identifier: Apache-2.0
+
+package rule
+
+import (
+	"net/url"
+	"strings"
+
+	"github.com/pkg/errors"
+)
+
+// TrieNode is a single node of a Trie. The edge leading to a node is keyed by
+// one component of a rule: protocol, method, scheme, host, or path segment.
+type TrieNode struct {
+	children map[string]*TrieNode
+	// rules holds the rules whose literal prefix ends at this node.
+	rules []Rule
+}
+
+// Trie indexes rules by protocol, method, scheme, host, and literal path
+// segments, in that order, so that the rules with the longest matching path
+// prefix can be looked up without evaluating every rule.
+type Trie struct {
+	root *TrieNode
+}
+
+// NewTrie creates an empty trie.
+func NewTrie() *Trie {
+	return &Trie{root: newTrieNode()}
+}
+
+// newTrieNode creates a node without children or rules.
+func newTrieNode() *TrieNode {
+	return &TrieNode{children: make(map[string]*TrieNode)}
+}
+
+// child returns the child stored under key and creates it when it is missing.
+func (n *TrieNode) child(key string) *TrieNode {
+	next, ok := n.children[key]
+	if !ok {
+		next = newTrieNode()
+		n.children[key] = next
+	}
+	return next
+}
+
+// pathSegments splits p into its segments, ignoring leading and trailing
+// slashes. It returns nil for an empty path.
+func pathSegments(p string) []string {
+	p = strings.Trim(p, "/")
+	if p == "" {
+		return nil
+	}
+	return strings.Split(p, "/")
+}
+
+// literalSegments returns the segments of the rule path p that precede its
+// first pattern. Patterns are delimited by "<" and ">", so everything from the
+// first "<" on is dropped, together with the segment the pattern starts in:
+// that segment is only partly literal and cannot be looked up as a whole.
+func literalSegments(p string) []string {
+	if i := strings.IndexByte(p, '<'); i >= 0 {
+		p = p[:i]
+		p = p[:strings.LastIndexByte(p, '/')+1]
+	}
+	return pathSegments(p)
+}
+
+// InsertRule indexes r once for every method it matches, keyed by protocol,
+// method, scheme, host, and the literal path segments of its match URL. It
+// returns an error, and leaves the trie untouched, when r has no match clause
+// or when its match URL cannot be parsed.
+func (t *Trie) InsertRule(r Rule) error {
+	if r.Match == nil {
+		return errors.New("rule has no match clause")
+	}
+
+	matchURL, err := url.Parse(r.Match.GetURL())
+	if err != nil {
+		return errors.WithStack(err)
+	}
+
+	// The path is the same for every method, so split it only once.
+	segments := literalSegments(matchURL.Path)
+	protocol := t.root.child(string(r.Match.Protocol()))
+	for _, method := range r.Match.GetMethods() {
+		node := protocol.child(method).child(matchURL.Scheme).child(matchURL.Host)
+		for _, segment := range segments {
+			node = node.child(segment)
+		}
+		node.rules = append(node.rules, r)
+	}
+	return nil
+}
+
+// LongestPrefix returns the longest prefix of u that is indexed in the trie,
+// formatted as "scheme://host/segment/...". Rules are indexed below their
+// protocol and method, so the subtree of every protocol and method is
+// searched. It returns "" when the scheme of u is not indexed and the bare
+// scheme when only the host is missing.
+func (t *Trie) LongestPrefix(u *url.URL) string {
+	segments := pathSegments(u.Path)
+	var longest string
+	for _, protocol := range t.root.children {
+		for _, method := range protocol.children {
+			if prefix := method.longestPrefix(u, segments); len(prefix) > len(longest) {
+				longest = prefix
+			}
+		}
+	}
+	return longest
+}
+
+// longestPrefix walks scheme, host, and segments below n, which must be a
+// method node, and returns the part of u that was found.
+func (n *TrieNode) longestPrefix(u *url.URL, segments []string) string {
+	node, ok := n.children[u.Scheme]
+	if !ok {
+		return ""
+	}
+	if node, ok = node.children[u.Host]; !ok {
+		return u.Scheme
+	}
+
+	var prefix strings.Builder
+	prefix.WriteString(u.Scheme + "://" + u.Host)
+	for _, segment := range segments {
+		next, found := node.children[segment]
+		if !found {
+			break
+		}
+		prefix.WriteString("/" + segment)
+		node = next
+	}
+	return prefix.String()
+}
+
+// Match returns the rules stored at the deepest node along the path of u that
+// holds any rules, i.e. the rules with the longest matching prefix. It returns
+// nil when protocol, method, scheme, or host are not indexed, or when no
+// indexed rule is a prefix of u.
+func (t *Trie) Match(method string, u *url.URL, protocol Protocol) []Rule {
+	node := t.root
+	for _, key := range []string{string(protocol), method, u.Scheme, u.Host} {
+		next, ok := node.children[key]
+		if !ok {
+			return nil
+		}
+		node = next
+	}
+
+	// Remember the last node that holds rules: a deeper node may exist only
+	// as a step towards a longer rule and hold no rules itself.
+	best := node.rules
+	for _, segment := range pathSegments(u.Path) {
+		next, ok := node.children[segment]
+		if !ok {
+			break
+		}
+		node = next
+		if len(node.rules) > 0 {
+			best = node.rules
+		}
+	}
+	return best
+}
diff --git a/spec/config.schema.json b/spec/config.schema.json
index 3a90a77a6b..45d84325ce 100644
--- a/spec/config.schema.json
+++ b/spec/config.schema.json
@@ -1267,6 +1267,12 @@
           "default": "regexp",
           "enum": ["glob", "regexp"],
           "examples": ["glob"]
+        },
+        "prefix_matching_enabled": {
+          "title": "Enable prefix matching",
+          "description": "This is an optional field describing if rules should be matched using path prefixes, defaults to false.",
+          "type": "boolean",
+          "default": false
         }
       }
     },