diff --git a/ast/compile.go b/ast/compile.go index d54cf93dcb..c59cfede62 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -4809,16 +4809,14 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) { result = append(result, expr) case *Every: var extras []*Expr - if _, ok := terms.Domain.Value.(Call); ok { - extras, terms.Domain = expandExprTerm(gen, terms.Domain) - } else { - term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location) - eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location) - eq.Generated = true - eq.With = expr.With - extras = append(extras, eq) - terms.Domain = term - } + + term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location) + eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location) + eq.Generated = true + eq.With = expr.With + extras = expandExpr(gen, eq) + terms.Domain = term + terms.Body = rewriteExprTermsInBody(gen, terms.Body) result = append(result, extras...) result = append(result, expr) diff --git a/ast/compile_test.go b/ast/compile_test.go index 0f87795305..6c8e8dd1b8 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -1775,6 +1775,21 @@ func TestCompilerCheckTypes(t *testing.T) { assertNotFailed(t, c) } +// Regression test for GH issue #6790 +func TestCompilerCheckEveryWithNestedDomainCalls(t *testing.T) { + c := NewCompiler() + c.Modules = map[string]*Module{"test": MustParseModule(`package test +import rego.v1 + +x if { + every p in [1 / 2] { + p == true + } +}`)} + compileStages(c, c.checkTypes) + assertNotFailed(t, c) +} + func TestCompilerCheckRuleConflicts(t *testing.T) { c := getCompilerWithParsedModules(map[string]string{ @@ -2567,7 +2582,7 @@ func TestCompilerRewriteExprTerms(t *testing.T) { f(__local0__[0]) { true; __local0__ = [1] }`, }, { - note: "every: domain", + note: "every: domain (array)", module: ` package test @@ -2577,6 +2592,79 @@ func TestCompilerRewriteExprTerms(t *testing.T) { p { __local2__ = [1, 2]; every __local0__, __local1__ in __local2__ { __local1__ } }`, }, + { + note: "every: domain (call)", + module: ` + package test + + p { every x in numbers.range(1, 3) { x } }`, + expected: ` + package test + + p = true { + numbers.range(1, 3, __local3__) + __local2__ = __local3__ + every __local0__, __local1__ in __local2__ { + __local1__ + } + }`, + }, + { + note: "every: domain (nested calls)", + module: ` + package test + + p { every x in numbers.range(1 + 2, 3 * 4) { x } }`, + expected: ` + package test + + p = true { + plus(1, 2, __local3__) + mul(3, 4, __local4__) + numbers.range(__local3__, __local4__, __local5__) + __local2__ = __local5__ + every __local0__, __local1__ in __local2__ { + __local1__ + } + }`, + }, + // Regression test for GH issue #6790 + { + note: "every: domain (array with call)", + module: ` + package test + + p { every x in [1 / 2, "foo", abs(-1)] { x } }`, + expected: ` + package test + + p = true { + div(1, 2, __local3__) + abs(-1, __local4__) + __local2__ = [__local3__, "foo", __local4__] + every __local0__, __local1__ in __local2__ { + __local1__ + } + }`, + }, + { + note: "every: domain (nested array with call)", + module: ` + package test + + p { every x in [1 / 2, ["foo", abs(-1)]] { x } }`, + expected: ` + package test + + p = true { + div(1, 2, __local3__) + abs(-1, __local4__) + __local2__ = [__local3__, ["foo", __local4__]] + every __local0__, __local1__ in __local2__ { + __local1__ + } + }`, + }, } for _, tc := range cases { @@ -6634,7 +6722,9 @@ func TestCompilerRewriteDynamicTerms(t *testing.T) { {`call_func { f(input, "foo") } f(x,y) { x[y] }`, `__local2__ = input; data.test.f(__local2__, "foo")`}, {`call_func2 { f(input.foo, "foo") } f(x,y) { x[y] }`, `__local2__ = input.foo; data.test.f(__local2__, "foo")`}, {`every_domain { every _ in str { true } }`, `__local1__ = data.test.str; every __local0__, _ in __local1__ { true }`}, - {`every_domain_call { every _ in numbers.range(1, 10) { true } }`, `numbers.range(1, 10, __local1__); every __local0__, _ in __local1__ { true }`}, + {`every_domain_array { every _ in [1, 2, 3] { true } }`, `__local1__ = [1, 2, 3]; every __local0__, _ in __local1__ { true }`}, + {`every_domain_call { every _ in numbers.range(1, 10) { true } }`, `numbers.range(1, 10, __local2__); __local1__ = __local2__; every __local0__, _ in __local1__ { true }`}, + {`every_domain_array_w_calls { every _ in [1 / 2, "foo", abs(-1)] { true } }`, `div(1, 2, __local2__); abs(-1, __local3__); __local1__ = [__local2__, "foo", __local3__]; every __local0__, _ in __local1__ { true }`}, {`every_body { every _ in [] { [str] } }`, `__local1__ = []; every __local0__, _ in __local1__ { __local2__ = data.test.str; [__local2__] }`}, } diff --git a/test/cases/testdata/every/every.yaml b/test/cases/testdata/every/every.yaml index d5419e8145..37e485c8f3 100644 --- a/test/cases/testdata/every/every.yaml +++ b/test/cases/testdata/every/every.yaml @@ -219,3 +219,31 @@ cases: query: data.test.p = x want_result: - x: [10] + - note: "every/array with calls" + modules: + - | + package test + import future.keywords.every + + p { + every v in [1 / 2, 3, 4 + 5] { v < 10 } + } + query: data.test.p = x + want_result: + - x: true + - note: "every/array with calls (fail)" + modules: + - | + package test + import future.keywords.every + + p { + every v in [1 / 2, 3, 4 + 5] { v < 9 } + } + + q { + not p + } + query: data.test.q = x + want_result: + - x: true