Skip to content

Commit

Permalink
Refactor deref (#398)
Browse files Browse the repository at this point in the history
* Refactor deref types

* Refactor deref opcode inserts

* Add multiple pointers deref
  • Loading branch information
antonmedv authored Aug 10, 2023
1 parent d2100ec commit 0dd3702
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 93 deletions.
4 changes: 1 addition & 3 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ type NilNode struct {
type IdentifierNode struct {
base
Value string
Deref bool
FieldIndex []int
Method bool // true if method, false if field
MethodIndex int // index of method, set only if Method is true
Expand Down Expand Up @@ -106,10 +105,9 @@ type MemberNode struct {
Property Node
Name string // Name of the filed or method. Used for error reporting.
Optional bool
Deref bool
FieldIndex []int

// TODO: Replace with a single MethodIndex field of &int type.
// TODO: Combine Method and MethodIndex into a single MethodIndex field of &int type.
Method bool
MethodIndex int
}
Expand Down
26 changes: 10 additions & 16 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,14 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
// when the arguments are known in CallNode.
return anyType, info{fn: fn}
}
if v.config.Types == nil {
node.Deref = true
} else if t, ok := v.config.Types[node.Value]; ok {
if t, ok := v.config.Types[node.Value]; ok {
if t.Ambiguous {
return v.error(node, "ambiguous identifier %v", node.Value)
}
d, c := deref(t.Type)
node.Deref = c
node.Method = t.Method
node.MethodIndex = t.MethodIndex
node.FieldIndex = t.FieldIndex
return d, info{method: t.Method}
return t.Type, info{method: t.Method}
}
if v.config.Strict {
return v.error(node, "unknown name %v", node.Value)
Expand Down Expand Up @@ -180,6 +176,8 @@ func (v *visitor) ConstantNode(node *ast.ConstantNode) (reflect.Type, info) {
func (v *visitor) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) {
t, _ := v.visit(node.Node)

t = deref(t)

switch node.Operator {

case "!", "not":
Expand Down Expand Up @@ -209,6 +207,9 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
l, _ := v.visit(node.Left)
r, _ := v.visit(node.Right)

l = deref(l)
r = deref(r)

// check operator overloading
if fns, ok := v.config.Operators[node.Operator]; ok {
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, l, r)
Expand Down Expand Up @@ -427,34 +428,27 @@ func (v *visitor) MemberNode(node *ast.MemberNode) (reflect.Type, info) {

switch base.Kind() {
case reflect.Interface:
node.Deref = true
return anyType, info{}

case reflect.Map:
if prop != nil && !prop.AssignableTo(base.Key()) && !isAny(prop) {
return v.error(node.Property, "cannot use %v to get an element from %v", prop, base)
}
t, c := deref(base.Elem())
node.Deref = c
return t, info{}
return base.Elem(), info{}

case reflect.Array, reflect.Slice:
if !isInteger(prop) && !isAny(prop) {
return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop)
}
t, c := deref(base.Elem())
node.Deref = c
return t, info{}
return base.Elem(), info{}

case reflect.Struct:
if name, ok := node.Property.(*ast.StringNode); ok {
propertyName := name.Value
if field, ok := fetchField(base, propertyName); ok {
t, c := deref(field.Type)
node.Deref = c
node.FieldIndex = field.Index
node.Name = propertyName
return t, info{}
return field.Type, info{}
}
if len(v.parents) > 1 {
if _, ok := v.parents[len(v.parents)-2].(*ast.CallNode); ok {
Expand Down
12 changes: 5 additions & 7 deletions checker/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,23 @@ func fetchField(t reflect.Type, name string) (reflect.StructField, bool) {
return reflect.StructField{}, false
}

func deref(t reflect.Type) (reflect.Type, bool) {
func deref(t reflect.Type) reflect.Type {
if t == nil {
return nil, false
return nil
}
if t.Kind() == reflect.Interface {
return t, true
return t
}
found := false
for t != nil && t.Kind() == reflect.Ptr {
e := t.Elem()
switch e.Kind() {
case reflect.Struct, reflect.Map, reflect.Array, reflect.Slice:
return t, false
return t
default:
found = true
t = e
}
}
return t, found
return t
}

func isIntegerOrArithmeticOperation(node ast.Node) bool {
Expand Down
72 changes: 52 additions & 20 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,6 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) {
} else {
c.emit(OpLoadConst, c.addConstant(node.Value))
}
if node.Deref {
c.emit(OpDeref)
} else if node.Type() == nil {
c.emit(OpDeref)
}
}

func (c *compiler) IntegerNode(node *ast.IntegerNode) {
Expand Down Expand Up @@ -289,6 +284,7 @@ func (c *compiler) ConstantNode(node *ast.ConstantNode) {

func (c *compiler) UnaryNode(node *ast.UnaryNode) {
c.compile(node.Node)
c.derefInNeeded(node.Node)

switch node.Operator {

Expand All @@ -313,7 +309,9 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
switch node.Operator {
case "==":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Left)

if l == r && l == reflect.Int {
c.emit(OpEqualInt)
Expand All @@ -325,114 +323,155 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {

case "!=":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Left)
c.emit(OpEqual)
c.emit(OpNot)

case "or", "||":
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfTrue, placeholder)
c.emit(OpPop)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.patchJump(end)

case "and", "&&":
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfFalse, placeholder)
c.emit(OpPop)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.patchJump(end)

case "<":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpLess)

case ">":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMore)

case "<=":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpLessOrEqual)

case ">=":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMoreOrEqual)

case "+":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpAdd)

case "-":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpSubtract)

case "*":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMultiply)

case "/":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpDivide)

case "%":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpModulo)

case "**", "^":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpExponent)

case "in":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpIn)

case "matches":
if node.Regexp != nil {
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.emit(OpMatchesConst, c.addConstant(node.Regexp))
} else {
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpMatches)
}

case "contains":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpContains)

case "startsWith":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpStartsWith)

case "endsWith":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpEndsWith)

case "..":
c.compile(node.Left)
c.derefInNeeded(node.Left)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.emit(OpRange)

case "??":
c.compile(node.Left)
c.derefInNeeded(node.Left)
end := c.emit(OpJumpIfNotNil, placeholder)
c.emit(OpPop)
c.compile(node.Right)
c.derefInNeeded(node.Right)
c.patchJump(end)

default:
Expand Down Expand Up @@ -461,7 +500,6 @@ func (c *compiler) MemberNode(node *ast.MemberNode) {
return
}
op := OpFetch
original := node
index := node.FieldIndex
path := []string{node.Name}
base := node.Node
Expand All @@ -470,21 +508,15 @@ func (c *compiler) MemberNode(node *ast.MemberNode) {
for !node.Optional {
ident, ok := base.(*ast.IdentifierNode)
if ok && len(ident.FieldIndex) > 0 {
if ident.Deref {
panic("IdentifierNode should not be dereferenced")
}
index = append(ident.FieldIndex, index...)
path = append([]string{ident.Value}, path...)
c.emitLocation(ident.Location(), OpLoadField, c.addConstant(
&runtime.Field{Index: index, Path: path},
))
goto deref
return
}
member, ok := base.(*ast.MemberNode)
if ok && len(member.FieldIndex) > 0 {
if member.Deref {
panic("MemberNode should not be dereferenced")
}
index = append(member.FieldIndex, index...)
path = append([]string{member.Name}, path...)
node = member
Expand All @@ -509,13 +541,6 @@ func (c *compiler) MemberNode(node *ast.MemberNode) {
&runtime.Field{Index: index, Path: path},
))
}

deref:
if original.Deref {
c.emit(OpDeref)
} else if original.Type() == nil {
c.emit(OpDeref)
}
}

func (c *compiler) SliceNode(node *ast.SliceNode) {
Expand Down Expand Up @@ -734,6 +759,13 @@ func (c *compiler) PairNode(node *ast.PairNode) {
c.compile(node.Value)
}

func (c *compiler) derefInNeeded(node ast.Node) {
switch kind(node) {
case reflect.Ptr, reflect.Interface:
c.emit(OpDeref)
}
}

func kind(node ast.Node) reflect.Kind {
t := node.Type()
if t == nil {
Expand Down
Loading

0 comments on commit 0dd3702

Please sign in to comment.