Skip to content

Commit

Permalink
interp: add a function to directly compile Go AST
Browse files Browse the repository at this point in the history
Adds CompileAST, which can be used to compile Go AST directly. This
allows users to delegate parsing of source to their own code instead of
relying on the interpreter.

CLoses #1251
  • Loading branch information
firelizzard18 authored Sep 23, 2021
1 parent c5c6012 commit 808f0bd
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 30 deletions.
48 changes: 24 additions & 24 deletions interp/ast.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package interp

import (
"errors"
"fmt"
"go/ast"
"go/constant"
Expand Down Expand Up @@ -362,21 +361,14 @@ func wrapInMain(src string) string {
return fmt.Sprintf("package main; func main() {%s\n}", src)
}

// Note: no type analysis is performed at this stage, it is done in pre-order
// processing of CFG, in order to accommodate forward type declarations.

// ast parses src string containing Go code and generates the corresponding AST.
// The package name and the AST root node are returned.
// The given name is used to set the filename of the relevant source file in the
// interpreter's FileSet.
func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error) {
var inFunc bool
func (interp *Interpreter) parse(src, name string, inc bool) (node ast.Node, err error) {
mode := parser.DeclarationErrors

// Allow incremental parsing of declarations or statements, by inserting
// them in a pseudo file package or function. Those statements or
// declarations will be always evaluated in the global scope.
var tok token.Token
var inFunc bool
if inc {
tok = interp.firstToken(src)
switch tok {
Expand All @@ -393,35 +385,51 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error
}

if ok, err := interp.buildOk(&interp.context, name, src); !ok || err != nil {
return "", nil, err // skip source not matching build constraints
return nil, err // skip source not matching build constraints
}

f, err := parser.ParseFile(interp.fset, name, src, mode)
if err != nil {
// only retry if we're on an expression/statement about a func
if !inc || tok != token.FUNC {
return "", nil, err
return nil, err
}
// do not bother retrying if we know it's an error we're going to ignore later on.
if ignoreError(err, src) {
return "", nil, err
return nil, err
}
// do not lose initial error, in case retrying fails.
initialError := err
// retry with default source code "wrapping", in the main function scope.
src := wrapInMain(strings.TrimPrefix(src, "package main;"))
f, err = parser.ParseFile(interp.fset, name, src, mode)
if err != nil {
return "", nil, initialError
return nil, initialError
}
}

if inFunc {
// return the body of the wrapper main function
return f.Decls[0].(*ast.FuncDecl).Body, nil
}

setYaegiTags(&interp.context, f.Comments)
return f, nil
}

// Note: no type analysis is performed at this stage, it is done in pre-order
// processing of CFG, in order to accommodate forward type declarations.

// ast parses src string containing Go code and generates the corresponding AST.
// The package name and the AST root node are returned.
// The given name is used to set the filename of the relevant source file in the
// interpreter's FileSet.
func (interp *Interpreter) ast(f ast.Node) (string, *node, error) {
var err error
var root *node
var anc astNode
var st nodestack
var pkgName string
pkgName := "main"

addChild := func(root **node, anc astNode, pos token.Pos, kind nkind, act action) *node {
var i interface{}
Expand Down Expand Up @@ -898,15 +906,7 @@ func (interp *Interpreter) ast(src, name string, inc bool) (string, *node, error
}
return true
})
if inFunc {
// Incremental parsing: statements were inserted in a pseudo function.
// Set root to function body so its statements are evaluated in global scope.
root = root.child[1].child[3]
root.anc = nil
}
if pkgName == "" {
return "", root, errors.New("no package name found")
}

interp.roots = append(interp.roots, root)
return pkgName, root, err
}
Expand Down
84 changes: 84 additions & 0 deletions interp/compile_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package interp

import (
"go/ast"
"go/parser"
"go/token"
"testing"

"github.com/traefik/yaegi/stdlib"
)

func TestCompileAST(t *testing.T) {
file, err := parser.ParseFile(token.NewFileSet(), "_.go", `
package main
import "fmt"
type Foo struct{}
var foo Foo
const bar = "asdf"
func main() {
fmt.Println(1)
}
`, 0)
if err != nil {
panic(err)
}
if len(file.Imports) != 1 || len(file.Decls) != 5 {
panic("wrong number of imports or decls")
}

dType := file.Decls[1].(*ast.GenDecl)
dVar := file.Decls[2].(*ast.GenDecl)
dConst := file.Decls[3].(*ast.GenDecl)
dFunc := file.Decls[4].(*ast.FuncDecl)

if dType.Tok != token.TYPE {
panic("decl[1] is not a type")
}
if dVar.Tok != token.VAR {
panic("decl[2] is not a var")
}
if dConst.Tok != token.CONST {
panic("decl[3] is not a const")
}

cases := []struct {
desc string
node ast.Node
skip string
}{
{desc: "file", node: file},
{desc: "import", node: file.Imports[0]},
{desc: "type", node: dType},
{desc: "var", node: dVar, skip: "not supported"},
{desc: "const", node: dConst},
{desc: "func", node: dFunc},
{desc: "block", node: dFunc.Body},
{desc: "expr", node: dFunc.Body.List[0]},
}

i := New(Options{})
_ = i.Use(stdlib.Symbols)

for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
if c.skip != "" {
t.Skip(c.skip)
}

i := i
if _, ok := c.node.(*ast.File); ok {
i = New(Options{})
_ = i.Use(stdlib.Symbols)
}
_, err := i.CompileAST(c.node)
if err != nil {
t.Fatalf("Failed to compile %s: %v", c.desc, err)
}
})
}
}
2 changes: 1 addition & 1 deletion interp/interp.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func isFile(filesystem fs.FS, path string) bool {
}

func (interp *Interpreter) eval(src, name string, inc bool) (res reflect.Value, err error) {
prog, err := interp.compile(src, name, inc)
prog, err := interp.compileSrc(src, name, inc)
if err != nil {
return res, err
}
Expand Down
22 changes: 18 additions & 4 deletions interp/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package interp

import (
"context"
"go/ast"
"io/ioutil"
"reflect"
"runtime"
Expand All @@ -17,7 +18,7 @@ type Program struct {

// Compile parses and compiles a Go code represented as a string.
func (interp *Interpreter) Compile(src string) (*Program, error) {
return interp.compile(src, "", true)
return interp.compileSrc(src, "", true)
}

// CompilePath parses and compiles a Go code located at the given path.
Expand All @@ -31,10 +32,10 @@ func (interp *Interpreter) CompilePath(path string) (*Program, error) {
if err != nil {
return nil, err
}
return interp.compile(string(b), path, false)
return interp.compileSrc(string(b), path, false)
}

func (interp *Interpreter) compile(src, name string, inc bool) (*Program, error) {
func (interp *Interpreter) compileSrc(src, name string, inc bool) (*Program, error) {
if name != "" {
interp.name = name
}
Expand All @@ -43,7 +44,20 @@ func (interp *Interpreter) compile(src, name string, inc bool) (*Program, error)
}

// Parse source to AST.
pkgName, root, err := interp.ast(src, interp.name, inc)
n, err := interp.parse(src, interp.name, inc)
if err != nil {
return nil, err
}

return interp.CompileAST(n)
}

// CompileAST builds a Program for the given Go code AST. Files and block
// statements can be compiled, as can most expressions. Var declaration nodes
// cannot be compiled.
func (interp *Interpreter) CompileAST(n ast.Node) (*Program, error) {
// Convert AST.
pkgName, root, err := interp.ast(n)
if err != nil || root == nil {
return nil, err
}
Expand Down
10 changes: 9 additions & 1 deletion interp/src.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,16 @@ func (interp *Interpreter) importSrc(rPath, importPath string, skipTest bool) (s
return "", err
}

n, err := interp.parse(string(buf), name, false)
if err != nil {
return "", err
}
if n == nil {
continue
}

var pname string
if pname, root, err = interp.ast(string(buf), name, false); err != nil {
if pname, root, err = interp.ast(n); err != nil {
return "", err
}
if root == nil {
Expand Down

0 comments on commit 808f0bd

Please sign in to comment.