diff --git a/pkg/rules/engine.go b/pkg/rules/engine.go index c37c735..79b91e7 100644 --- a/pkg/rules/engine.go +++ b/pkg/rules/engine.go @@ -34,6 +34,11 @@ type Query struct { SourceASN uint32 } +// match checks if any of the conditions match the given matchFunc. +func match[T any](conditions []T, matchFunc func(T) bool) bool { + return len(conditions) == 0 || utils.Any(conditions, matchFunc) +} + // ruleApplies checks if the given query is allowed or denied by the given // rule. For a rule to be applicable, the query must match all of the rule's // conditions. @@ -43,47 +48,30 @@ type Query struct { // // Domains, methods and countries are case-insensitive. func ruleApplies(rule *schema.AccessControlRule, query *Query) bool { - if len(rule.Domains) > 0 { - if utils.None(rule.Domains, func(domain string) bool { - return glob.Star(domain, query.RequestedDomain) - }) { - return false - } - } + matchDomain := match(rule.Domains, func(domain string) bool { + return glob.Star( + strings.ToLower(domain), + strings.ToLower(query.RequestedDomain), + ) + }) - if len(rule.Methods) > 0 { - if utils.None(rule.Methods, func(method string) bool { - return strings.EqualFold(method, query.RequestedMethod) - }) { - return false - } - } + matchMethod := match(rule.Methods, func(method string) bool { + return strings.EqualFold(method, query.RequestedMethod) + }) - if len(rule.Networks) > 0 { - if utils.None(rule.Networks, func(network schema.CIDR) bool { - return network.Contains(query.SourceIP) - }) { - return false - } - } + matchIP := match(rule.Networks, func(network schema.CIDR) bool { + return network.Contains(query.SourceIP) + }) - if len(rule.Countries) > 0 { - if utils.None(rule.Countries, func(country string) bool { - return strings.EqualFold(country, query.SourceCountry) - }) { - return false - } - } + matchCountry := match(rule.Countries, func(country string) bool { + return strings.EqualFold(country, query.SourceCountry) + }) - if len(rule.AutonomousSystems) > 0 { - if utils.None(rule.AutonomousSystems, func(asn uint32) bool { - return asn == query.SourceASN - }) { - return false - } - } + matchANS := match(rule.AutonomousSystems, func(asn uint32) bool { + return asn == query.SourceASN + }) - return true + return matchDomain && matchMethod && matchIP && matchCountry && matchANS } // UpdateConfig updates the engine's configuration with the given access diff --git a/pkg/rules/engine_test.go b/pkg/rules/engine_test.go index 557df73..d8a6d89 100644 --- a/pkg/rules/engine_test.go +++ b/pkg/rules/engine_test.go @@ -117,6 +117,22 @@ func TestEngineAuthorize(t *testing.T) { }, want: false, }, + { + name: "domains are case-insensitive", + config: &schema.AccessControl{ + Rules: []schema.AccessControlRule{ + { + Domains: []string{"example.org", "example.com"}, + Policy: schema.PolicyAllow, + }, + }, + DefaultPolicy: schema.PolicyDeny, + }, + query: &rules.Query{ + RequestedDomain: "EXAMPLE.ORG", + }, + want: true, + }, { name: "allow by method", config: &schema.AccessControl{ @@ -165,6 +181,22 @@ func TestEngineAuthorize(t *testing.T) { }, want: false, }, + { + name: "methods are case-insensitive", + config: &schema.AccessControl{ + Rules: []schema.AccessControlRule{ + { + Methods: []string{"GET", "POST"}, + Policy: schema.PolicyAllow, + }, + }, + DefaultPolicy: schema.PolicyDeny, + }, + query: &rules.Query{ + RequestedMethod: "get", + }, + want: true, + }, { name: "allow by network", config: &schema.AccessControl{ @@ -263,6 +295,22 @@ func TestEngineAuthorize(t *testing.T) { }, want: false, }, + { + name: "countries are case-insensitive", + config: &schema.AccessControl{ + Rules: []schema.AccessControlRule{ + { + Countries: []string{"FR", "US"}, + Policy: schema.PolicyAllow, + }, + }, + DefaultPolicy: schema.PolicyDeny, + }, + query: &rules.Query{ + SourceCountry: "fr", + }, + want: true, + }, { name: "allow by ASN", config: &schema.AccessControl{ diff --git a/pkg/utils/glob/glob_test.go b/pkg/utils/glob/glob_test.go index 0248a80..577c37c 100644 --- a/pkg/utils/glob/glob_test.go +++ b/pkg/utils/glob/glob_test.go @@ -14,6 +14,7 @@ func TestStar(t *testing.T) { }{ {"", "", true}, {"*", "", true}, + {"a", "", false}, {"", "abc", false}, {"*", "abc", true}, {"a*", "abc", true},