Skip to content

Commit

Permalink
refactor: refactor rule-matching function (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
danroc authored Nov 21, 2024
1 parent f2c1a45 commit a995cb5
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 36 deletions.
60 changes: 24 additions & 36 deletions pkg/rules/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ type Query struct {
SourceASN uint32
}

// match checks if any of the conditions match the given matchFunc.
func match[T any](conditions []T, matchFunc func(T) bool) bool {
return len(conditions) == 0 || utils.Any(conditions, matchFunc)
}

// ruleApplies checks if the given query is allowed or denied by the given
// rule. For a rule to be applicable, the query must match all of the rule's
// conditions.
Expand All @@ -43,47 +48,30 @@ type Query struct {
//
// Domains, methods and countries are case-insensitive.
func ruleApplies(rule *schema.AccessControlRule, query *Query) bool {
if len(rule.Domains) > 0 {
if utils.None(rule.Domains, func(domain string) bool {
return glob.Star(domain, query.RequestedDomain)
}) {
return false
}
}
matchDomain := match(rule.Domains, func(domain string) bool {
return glob.Star(
strings.ToLower(domain),
strings.ToLower(query.RequestedDomain),
)
})

if len(rule.Methods) > 0 {
if utils.None(rule.Methods, func(method string) bool {
return strings.EqualFold(method, query.RequestedMethod)
}) {
return false
}
}
matchMethod := match(rule.Methods, func(method string) bool {
return strings.EqualFold(method, query.RequestedMethod)
})

if len(rule.Networks) > 0 {
if utils.None(rule.Networks, func(network schema.CIDR) bool {
return network.Contains(query.SourceIP)
}) {
return false
}
}
matchIP := match(rule.Networks, func(network schema.CIDR) bool {
return network.Contains(query.SourceIP)
})

if len(rule.Countries) > 0 {
if utils.None(rule.Countries, func(country string) bool {
return strings.EqualFold(country, query.SourceCountry)
}) {
return false
}
}
matchCountry := match(rule.Countries, func(country string) bool {
return strings.EqualFold(country, query.SourceCountry)
})

if len(rule.AutonomousSystems) > 0 {
if utils.None(rule.AutonomousSystems, func(asn uint32) bool {
return asn == query.SourceASN
}) {
return false
}
}
matchANS := match(rule.AutonomousSystems, func(asn uint32) bool {
return asn == query.SourceASN
})

return true
return matchDomain && matchMethod && matchIP && matchCountry && matchANS
}

// UpdateConfig updates the engine's configuration with the given access
Expand Down
48 changes: 48 additions & 0 deletions pkg/rules/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ func TestEngineAuthorize(t *testing.T) {
},
want: false,
},
{
name: "domains are case-insensitive",
config: &schema.AccessControl{
Rules: []schema.AccessControlRule{
{
Domains: []string{"example.org", "example.com"},
Policy: schema.PolicyAllow,
},
},
DefaultPolicy: schema.PolicyDeny,
},
query: &rules.Query{
RequestedDomain: "EXAMPLE.ORG",
},
want: true,
},
{
name: "allow by method",
config: &schema.AccessControl{
Expand Down Expand Up @@ -165,6 +181,22 @@ func TestEngineAuthorize(t *testing.T) {
},
want: false,
},
{
name: "methods are case-insensitive",
config: &schema.AccessControl{
Rules: []schema.AccessControlRule{
{
Methods: []string{"GET", "POST"},
Policy: schema.PolicyAllow,
},
},
DefaultPolicy: schema.PolicyDeny,
},
query: &rules.Query{
RequestedMethod: "get",
},
want: true,
},
{
name: "allow by network",
config: &schema.AccessControl{
Expand Down Expand Up @@ -263,6 +295,22 @@ func TestEngineAuthorize(t *testing.T) {
},
want: false,
},
{
name: "countries are case-insensitive",
config: &schema.AccessControl{
Rules: []schema.AccessControlRule{
{
Countries: []string{"FR", "US"},
Policy: schema.PolicyAllow,
},
},
DefaultPolicy: schema.PolicyDeny,
},
query: &rules.Query{
SourceCountry: "fr",
},
want: true,
},
{
name: "allow by ASN",
config: &schema.AccessControl{
Expand Down
1 change: 1 addition & 0 deletions pkg/utils/glob/glob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func TestStar(t *testing.T) {
}{
{"", "", true},
{"*", "", true},
{"a", "", false},
{"", "abc", false},
{"*", "abc", true},
{"a*", "abc", true},
Expand Down

0 comments on commit a995cb5

Please sign in to comment.