diff --git a/ast/compile.go b/ast/compile.go index 7292926acc..458eed9445 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -873,6 +873,7 @@ func (c *Compiler) checkRuleConflicts() { arities := make(map[int]struct{}, len(node.Values)) name := "" var singleValueConflicts []Ref + var multiValueConflicts []Ref for _, rule := range node.Values { r := rule.(*Rule) @@ -908,12 +909,24 @@ func (c *Compiler) checkRuleConflicts() { singleValueConflicts = node.flattenChildren() } } + + // Multi-value rules may not have any other rules in their extent; e.g.: + // + // data.p[v] { v := ... } + // data.p.q := 42 # In direct conflict with data.p[v], which is constructing a set and cannot have values assigned to a sub-path. + + if r.Head.RuleKind() == MultiValue && len(node.Children) > 0 { + multiValueConflicts = node.flattenChildren() + } } switch { case singleValueConflicts != nil: c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "single-value rule %v conflicts with %v", name, singleValueConflicts)) + case multiValueConflicts != nil: + c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multi-value rule %v conflicts with %v", name, multiValueConflicts)) + case len(kinds) > 1 || len(arities) > 1: c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name)) diff --git a/ast/compile_test.go b/ast/compile_test.go index 14c9da7fd0..b7cbab338f 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -1959,6 +1959,24 @@ func TestCompilerCheckRuleConflictsDotsInRuleHeads(t *testing.T) { p.q.s = "x" { true } `), }, + { + note: "multi-value rule with other rule overlap", + modules: modules( + `package pkg + p[v] { v := ["a", "b"][_] } + p.q := 42 + `), + err: "rego_type_error: multi-value rule data.pkg.p conflicts with [data.pkg.p.q]", + }, + { + note: "multi-value rule with other rule (ref) overlap", + modules: modules( + `package pkg + p[v] { v := ["a", "b"][_] } + p.q.r { true } + `), + err: "rego_type_error: multi-value rule data.pkg.p conflicts with [data.pkg.p.q.r]", + }, } for _, tc := range tests { t.Run(tc.note, func(t *testing.T) {