Skip to content

Commit

Permalink
interp: improve support of composed interfaces
Browse files Browse the repository at this point in the history
Fixes #1260.
  • Loading branch information
mvertes authored Oct 7, 2021
1 parent 286d6c6 commit d3bbe01
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 26 deletions.
68 changes: 68 additions & 0 deletions _test/issue-1260.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package main

import (
"fmt"
"io"
"io/ioutil"
"os"
)

type WriteSyncer interface {
io.Writer
Sync() error
}

type Sink interface {
WriteSyncer
io.Closer
}

func newFileSink(path string) (Sink, error) {
return os.OpenFile(path, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
}

type Sink1 struct{ name string }

func (s Sink1) Write(b []byte) (int, error) { println("in Write"); return 0, nil }
func (s Sink1) Sync() error { println("in Sync"); return nil }
func (s Sink1) Close() error { println("in Close", s.name); return nil }
func newS1(name string) Sink { return Sink1{name} }
func newS1p(name string) Sink { return &Sink1{name} }

type Sink2 struct{ name string }

func (s *Sink2) Write(b []byte) (int, error) { println("in Write"); return 0, nil }
func (s *Sink2) Sync() error { println("in Sync"); return nil }
func (s *Sink2) Close() error { println("in Close", s.name); return nil }
func newS2(name string) Sink { return Sink1{name} }

func main() {
tmpfile, err := ioutil.TempFile("", "xxx")
if err != nil {
panic(err)
}
defer os.Remove(tmpfile.Name())
closers := []io.Closer{}
sink, err := newFileSink(tmpfile.Name())
if err != nil {
panic(err)
}
closers = append(closers, sink)

s1p := newS1p("ptr")
s1 := newS1("struct")
s2 := newS2("ptr2")
closers = append(closers, s1p, s1, s2)
for _, closer := range closers {
fmt.Println(closer.Close())
}
}

// Output:
// <nil>
// in Close ptr
// <nil>
// in Close struct
// <nil>
// in Close ptr2
// <nil>
72 changes: 49 additions & 23 deletions interp/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,20 +952,24 @@ func genFunctionWrapper(n *node) func(*frame) reflect.Value {
d[i] = reflect.New(t).Elem()
}

// Copy method receiver as first argument, if defined.
if rcvr != nil {
if rcvr == nil {
d = d[numRet:]
} else {
// Copy method receiver as first argument.
src, dest := rcvr(f), d[numRet]
if src.Type().Kind() != dest.Type().Kind() {
sk, dk := src.Type().Kind(), dest.Type().Kind()
switch {
case sk == reflect.Ptr && dk != reflect.Ptr:
dest.Set(src.Elem())
case sk != reflect.Ptr && dk == reflect.Ptr:
dest.Set(src.Addr())
} else {
default:
if wrappedSrc, ok := src.Interface().(valueInterface); ok {
src = wrappedSrc.value
}
dest.Set(src)
}
d = d[numRet+1:]
} else {
d = d[numRet:]
}

// Copy function input arguments in local frame.
Expand Down Expand Up @@ -1034,32 +1038,38 @@ func genInterfaceWrapper(n *node, typ reflect.Type) func(*frame) reflect.Value {
if tc != structT && v.Type().Implements(typ) {
return v
}
vv := v
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
if v.IsNil() {
return reflect.New(typ).Elem()
}
if v.Kind() == reflect.Ptr {
vv = v.Elem()
}
}
var n2 *node
if vi, ok := v.Interface().(valueInterface); ok {
n2 = vi.node
}
v = getConcreteValue(v)
w := reflect.New(wrap).Elem()
w.Field(0).Set(v)
for i, m := range methods {
if m == nil {
if r := v.MethodByName(names[i]); r.IsValid() {
// First direct method lookup on field.
if r := methodByName(v, names[i]); r.IsValid() {
w.Field(i + 1).Set(r)
continue
}
o := vv.FieldByIndex(indexes[i])
if r := o.MethodByName(names[i]); r.IsValid() {
w.Field(i + 1).Set(r)
} else {
if n2 == nil {
panic(n.cfgErrorf("method not found: %s", names[i]))
}
continue
// Method lookup in embedded valueInterface.
m2, i2 := n2.typ.lookupMethod(names[i])
if m2 != nil {
nod := *m2
nod.recv = &receiver{n, v, i2}
w.Field(i + 1).Set(genFunctionWrapper(&nod)(f))
continue
}
panic(n.cfgErrorf("method not found: %s", names[i]))
}
nod := *m
nod.recv = &receiver{n, v, indexes[i]}
Expand All @@ -1069,6 +1079,17 @@ func genInterfaceWrapper(n *node, typ reflect.Type) func(*frame) reflect.Value {
}
}

// methodByName return the method corresponding to name on value, or nil if not found.
// The search is extended on valueInterface wrapper if present.
func methodByName(value reflect.Value, name string) reflect.Value {
if vi, ok := value.Interface().(valueInterface); ok {
if v := getConcreteValue(vi.value).MethodByName(name); v.IsValid() {
return v
}
}
return value.MethodByName(name)
}

func call(n *node) {
goroutine := n.anc.kind == goStmt
var method bool
Expand Down Expand Up @@ -1492,12 +1513,13 @@ func callBin(n *node) {
rvalues := make([]func(*frame) reflect.Value, funcType.NumOut())
for i := range rvalues {
c := n.anc.child[i]
if c.ident != "_" {
if isInterfaceSrc(c.typ) {
rvalues[i] = genValueInterfaceValue(c)
} else {
rvalues[i] = genValue(c)
}
if c.ident == "_" {
continue
}
if isInterfaceSrc(c.typ) {
rvalues[i] = genValueInterfaceValue(c)
} else {
rvalues[i] = genValue(c)
}
}
n.exec = func(f *frame) bltn {
Expand All @@ -1524,7 +1546,11 @@ func callBin(n *node) {
}
out := callFn(value(f), in)
for i, v := range out {
f.data[b+i].Set(v)
dest := f.data[b+i]
if _, ok := dest.Interface().(valueInterface); ok {
v = reflect.ValueOf(valueInterface{value: v})
}
dest.Set(v)
}
return tnext
}
Expand Down
12 changes: 9 additions & 3 deletions interp/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,14 @@ func genValueBinMethodOnInterface(n *node, defaultGen func(*frame) reflect.Value
}

func genValueRecvIndirect(n *node) func(*frame) reflect.Value {
v := genValueRecv(n)
return func(f *frame) reflect.Value { return v(f).Elem() }
vr := genValueRecv(n)
return func(f *frame) reflect.Value {
v := vr(f)
if vi, ok := v.Interface().(valueInterface); ok {
return vi.value
}
return v.Elem()
}
}

func genValueRecv(n *node) func(*frame) reflect.Value {
Expand Down Expand Up @@ -312,7 +318,7 @@ func genValueInterface(n *node) func(*frame) reflect.Value {
}

// empty interface, do not wrap.
if nod.typ.cat == interfaceT && len(nod.typ.field) == 0 {
if nod != nil && isEmptyInterface(nod.typ) {
return v
}

Expand Down

0 comments on commit d3bbe01

Please sign in to comment.