Skip to content

Commit

Permalink
Update rule index's trie node scalars to use custom map
Browse files Browse the repository at this point in the history
Earlier the trie's scalar used map[ast.Value]trieNode.
This resulted in unexpected key comparison results such
as 1 != 1.0 when key types were ast.Number for example.
This change updates the scalar to use a util.HashMap instead
that will utilize ast.Compare to perfom key comparisons.

Fixes: #5585

Signed-off-by: Ashutosh Narkar <[email protected]>
  • Loading branch information
ashutosh-narkar committed Feb 1, 2023
1 parent ea06d89 commit 35b2b6d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
59 changes: 34 additions & 25 deletions ast/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ type trieNode struct {
next *trieNode
any *trieNode
undefined *trieNode
scalars map[Value]*trieNode
scalars *util.HashMap
array *trieNode
rules []*ruleNode
}
Expand All @@ -465,11 +465,14 @@ func (node *trieNode) String() string {
if node.array != nil {
flags = append(flags, fmt.Sprintf("array:%p", node.array))
}
if len(node.scalars) > 0 {
buf := make([]string, 0, len(node.scalars))
for k, v := range node.scalars {
buf = append(buf, fmt.Sprintf("scalar(%v):%p", k, v))
}
if node.scalars.Len() > 0 {
buf := make([]string, 0, node.scalars.Len())
node.scalars.Iter(func(k, v util.T) bool {
key := k.(Value)
val := v.(*trieNode)
buf = append(buf, fmt.Sprintf("scalar(%v):%p", key, val))
return false
})
sort.Strings(buf)
flags = append(flags, strings.Join(buf, " "))
}
Expand Down Expand Up @@ -505,7 +508,7 @@ type ruleNode struct {

func newTrieNodeImpl() *trieNode {
return &trieNode{
scalars: map[Value]*trieNode{},
scalars: util.NewHashMap(valueEq, valueHash),
}
}

Expand All @@ -520,9 +523,13 @@ func (node *trieNode) Do(walker trieWalker) {
if node.undefined != nil {
node.undefined.Do(next)
}
for _, child := range node.scalars {

node.scalars.Iter(func(_, v util.T) bool {
child := v.(*trieNode)
child.Do(next)
}
return false
})

if node.array != nil {
node.array.Do(next)
}
Expand Down Expand Up @@ -579,12 +586,12 @@ func (node *trieNode) insertValue(value Value) *trieNode {
}
return node.any
case Null, Boolean, Number, String:
child, ok := node.scalars[value]
child, ok := node.scalars.Get(value)
if !ok {
child = newTrieNodeImpl()
node.scalars[value] = child
node.scalars.Put(value, child)
}
return child
return child.(*trieNode)
case *Array:
if node.array == nil {
node.array = newTrieNodeImpl()
Expand All @@ -608,12 +615,12 @@ func (node *trieNode) insertArray(arr *Array) *trieNode {
}
return node.any.insertArray(arr.Slice(1, -1))
case Null, Boolean, Number, String:
child, ok := node.scalars[head]
child, ok := node.scalars.Get(head)
if !ok {
child = newTrieNodeImpl()
node.scalars[head] = child
node.scalars.Put(head, child)
}
return child.insertArray(arr.Slice(1, -1))
return child.(*trieNode).insertArray(arr.Slice(1, -1))
}

panic("illegal value")
Expand Down Expand Up @@ -674,11 +681,11 @@ func (node *trieNode) traverseValue(resolver ValueResolver, tr *trieTraversalRes
return node.array.traverseArray(resolver, tr, value)

case Null, Boolean, Number, String:
child, ok := node.scalars[value]
child, ok := node.scalars.Get(value)
if !ok {
return nil
}
return child.Traverse(resolver, tr)
return child.(*trieNode).Traverse(resolver, tr)
}

return nil
Expand All @@ -703,12 +710,11 @@ func (node *trieNode) traverseArray(resolver ValueResolver, tr *trieTraversalRes
}
}

child, ok := node.scalars[head]
child, ok := node.scalars.Get(head)
if !ok {
return nil
}

return child.traverseArray(resolver, tr, arr.Slice(1, -1))
return child.(*trieNode).traverseArray(resolver, tr, arr.Slice(1, -1))
}

func (node *trieNode) traverseUnknown(resolver ValueResolver, tr *trieTraversalResult) error {
Expand All @@ -733,13 +739,16 @@ func (node *trieNode) traverseUnknown(resolver ValueResolver, tr *trieTraversalR
return err
}

for _, child := range node.scalars {
if err := child.traverseUnknown(resolver, tr); err != nil {
return err
var iterErr error
node.scalars.Iter(func(_, v util.T) bool {
child := v.(*trieNode)
if iterErr = child.traverseUnknown(resolver, tr); iterErr != nil {
return true
}
}
return false
})

return nil
return iterErr
}

// If term `a` is one of the function's operands, we store a Ref: `args[0]`
Expand Down
19 changes: 19 additions & 0 deletions test/cases/testdata/eqexpr/test-eqexpr-0599.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
cases:
- input_term: "1.0"
modules:
- |
package test
p { input == 1.0 }
note: "eqexpr/indexing: input is 1.0"
query: data.test.p = x
want_result:
- x: true
- input_term: "1"
modules:
- |
package test
p { input == 1.0 }
note: "eqexpr/indexing: input is 1"
query: data.test.p = x
want_result:
- x: true

0 comments on commit 35b2b6d

Please sign in to comment.