Skip to content

Commit

Permalink
ast: expanding nested expressions in every domain (#6832)
Browse files Browse the repository at this point in the history
Fixes: #6790
Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling authored Jun 26, 2024
1 parent cb77956 commit c2cede7
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 12 deletions.
18 changes: 8 additions & 10 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -4809,16 +4809,14 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
result = append(result, expr)
case *Every:
var extras []*Expr
if _, ok := terms.Domain.Value.(Call); ok {
extras, terms.Domain = expandExprTerm(gen, terms.Domain)
} else {
term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location)
eq.Generated = true
eq.With = expr.With
extras = append(extras, eq)
terms.Domain = term
}

term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location)
eq.Generated = true
eq.With = expr.With
extras = expandExpr(gen, eq)
terms.Domain = term

terms.Body = rewriteExprTermsInBody(gen, terms.Body)
result = append(result, extras...)
result = append(result, expr)
Expand Down
94 changes: 92 additions & 2 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,21 @@ func TestCompilerCheckTypes(t *testing.T) {
assertNotFailed(t, c)
}

// Regression test for GH issue #6790
func TestCompilerCheckEveryWithNestedDomainCalls(t *testing.T) {
c := NewCompiler()
c.Modules = map[string]*Module{"test": MustParseModule(`package test
import rego.v1
x if {
every p in [1 / 2] {
p == true
}
}`)}
compileStages(c, c.checkTypes)
assertNotFailed(t, c)
}

func TestCompilerCheckRuleConflicts(t *testing.T) {

c := getCompilerWithParsedModules(map[string]string{
Expand Down Expand Up @@ -2567,7 +2582,7 @@ func TestCompilerRewriteExprTerms(t *testing.T) {
f(__local0__[0]) { true; __local0__ = [1] }`,
},
{
note: "every: domain",
note: "every: domain (array)",
module: `
package test
Expand All @@ -2577,6 +2592,79 @@ func TestCompilerRewriteExprTerms(t *testing.T) {
p { __local2__ = [1, 2]; every __local0__, __local1__ in __local2__ { __local1__ } }`,
},
{
note: "every: domain (call)",
module: `
package test
p { every x in numbers.range(1, 3) { x } }`,
expected: `
package test
p = true {
numbers.range(1, 3, __local3__)
__local2__ = __local3__
every __local0__, __local1__ in __local2__ {
__local1__
}
}`,
},
{
note: "every: domain (nested calls)",
module: `
package test
p { every x in numbers.range(1 + 2, 3 * 4) { x } }`,
expected: `
package test
p = true {
plus(1, 2, __local3__)
mul(3, 4, __local4__)
numbers.range(__local3__, __local4__, __local5__)
__local2__ = __local5__
every __local0__, __local1__ in __local2__ {
__local1__
}
}`,
},
// Regression test for GH issue #6790
{
note: "every: domain (array with call)",
module: `
package test
p { every x in [1 / 2, "foo", abs(-1)] { x } }`,
expected: `
package test
p = true {
div(1, 2, __local3__)
abs(-1, __local4__)
__local2__ = [__local3__, "foo", __local4__]
every __local0__, __local1__ in __local2__ {
__local1__
}
}`,
},
{
note: "every: domain (nested array with call)",
module: `
package test
p { every x in [1 / 2, ["foo", abs(-1)]] { x } }`,
expected: `
package test
p = true {
div(1, 2, __local3__)
abs(-1, __local4__)
__local2__ = [__local3__, ["foo", __local4__]]
every __local0__, __local1__ in __local2__ {
__local1__
}
}`,
},
}

for _, tc := range cases {
Expand Down Expand Up @@ -6634,7 +6722,9 @@ func TestCompilerRewriteDynamicTerms(t *testing.T) {
{`call_func { f(input, "foo") } f(x,y) { x[y] }`, `__local2__ = input; data.test.f(__local2__, "foo")`},
{`call_func2 { f(input.foo, "foo") } f(x,y) { x[y] }`, `__local2__ = input.foo; data.test.f(__local2__, "foo")`},
{`every_domain { every _ in str { true } }`, `__local1__ = data.test.str; every __local0__, _ in __local1__ { true }`},
{`every_domain_call { every _ in numbers.range(1, 10) { true } }`, `numbers.range(1, 10, __local1__); every __local0__, _ in __local1__ { true }`},
{`every_domain_array { every _ in [1, 2, 3] { true } }`, `__local1__ = [1, 2, 3]; every __local0__, _ in __local1__ { true }`},
{`every_domain_call { every _ in numbers.range(1, 10) { true } }`, `numbers.range(1, 10, __local2__); __local1__ = __local2__; every __local0__, _ in __local1__ { true }`},
{`every_domain_array_w_calls { every _ in [1 / 2, "foo", abs(-1)] { true } }`, `div(1, 2, __local2__); abs(-1, __local3__); __local1__ = [__local2__, "foo", __local3__]; every __local0__, _ in __local1__ { true }`},
{`every_body { every _ in [] { [str] } }`,
`__local1__ = []; every __local0__, _ in __local1__ { __local2__ = data.test.str; [__local2__] }`},
}
Expand Down
28 changes: 28 additions & 0 deletions test/cases/testdata/every/every.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,31 @@ cases:
query: data.test.p = x
want_result:
- x: [10]
- note: "every/array with calls"
modules:
- |
package test
import future.keywords.every
p {
every v in [1 / 2, 3, 4 + 5] { v < 10 }
}
query: data.test.p = x
want_result:
- x: true
- note: "every/array with calls (fail)"
modules:
- |
package test
import future.keywords.every
p {
every v in [1 / 2, 3, 4 + 5] { v < 9 }
}
q {
not p
}
query: data.test.q = x
want_result:
- x: true

0 comments on commit c2cede7

Please sign in to comment.