Skip to content

Commit

Permalink
rework reflection detection with ssa
Browse files Browse the repository at this point in the history
This is significantly more robust, than the ast based detection and can
record very complex cases of indirect parameter reflection.

Fixes #554
  • Loading branch information
lu4p committed May 4, 2023
1 parent b0ff2fb commit 299918a
Show file tree
Hide file tree
Showing 5 changed files with 569 additions and 259 deletions.
4 changes: 4 additions & 0 deletions go_std_tables.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

291 changes: 34 additions & 257 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ import (
"unicode"
"unicode/utf8"

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/mod/module"
"golang.org/x/mod/semver"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ssa"

"mvdan.cc/garble/internal/linker"
"mvdan.cc/garble/internal/literals"
)
Expand Down Expand Up @@ -904,7 +905,27 @@ func transformCompile(args []string) ([]string, error) {
return nil, err
}

tf.findReflectFunctions(files)
ssaProg := ssa.NewProgram(fset, 0)

// Create SSA packages for all imports.
// Order is not significant.
created := make(map[*types.Package]bool)
var createAll func(pkgs []*types.Package)
createAll = func(pkgs []*types.Package) {
for _, p := range pkgs {
if !created[p] {
created[p] = true
ssaProg.CreatePackage(p, nil, nil, true)
createAll(p.Imports())
}
}
}
createAll(tf.pkg.Imports())

ssaPkg := ssaProg.CreatePackage(tf.pkg, files, tf.info, false)
ssaPkg.Build()

tf.recordReflection(ssaPkg)
newImportCfg, err := processImportCfg(flags)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1240,11 +1261,6 @@ type (
funcFullName = string // as per go/types.Func.FullName
objectString = string // as per recordedObjectString

reflectParameter struct {
Position int // 0-indexed
Variadic bool // ...int
}

typeName struct {
PkgPath, Name string
}
Expand All @@ -1267,7 +1283,7 @@ var cachedOutput = struct {
//
// TODO: we're not including fmt.Printf, as it would have many false positives,
// unless we were smart enough to detect which arguments get used as %#v or %T.
KnownReflectAPIs map[funcFullName][]reflectParameter
KnownReflectAPIs map[funcFullName]map[int]bool

// KnownCannotObfuscate is filled with the fully qualified names from each
// package that we cannot obfuscate.
Expand All @@ -1283,9 +1299,9 @@ var cachedOutput = struct {
// bearing in mind that it may be owned by a different package.
KnownEmbeddedAliasFields map[objectString]typeName
}{
KnownReflectAPIs: map[funcFullName][]reflectParameter{
"reflect.TypeOf": {{Position: 0, Variadic: false}},
"reflect.ValueOf": {{Position: 0, Variadic: false}},
KnownReflectAPIs: map[funcFullName]map[int]bool{
"reflect.TypeOf": {0: true},
"reflect.ValueOf": {0: true},
},
KnownCannotObfuscate: map[objectString]struct{}{},
KnownEmbeddedAliasFields: map[objectString]typeName{},
Expand Down Expand Up @@ -1344,90 +1360,6 @@ func loadCachedOutputs() error {
return nil
}

func (tf *transformer) findReflectFunctions(files []*ast.File) {
seenReflectParams := make(map[*types.Var]bool)
visitFuncDecl := func(funcDecl *ast.FuncDecl) {
funcObj := tf.info.Defs[funcDecl.Name].(*types.Func)
funcType := funcObj.Type().(*types.Signature)
funcParams := funcType.Params()

maps.Clear(seenReflectParams)
for i := 0; i < funcParams.Len(); i++ {
seenReflectParams[funcParams.At(i)] = false
}

ast.Inspect(funcDecl, func(node ast.Node) bool {
call, ok := node.(*ast.CallExpr)
if !ok {
return true
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return true
}
calledFunc, _ := tf.info.Uses[sel.Sel].(*types.Func)
if calledFunc == nil || calledFunc.Pkg() == nil {
return true
}

fullName := calledFunc.FullName()
for _, reflectParam := range cachedOutput.KnownReflectAPIs[fullName] {
// We need a range to handle any number of variadic arguments,
// which could be 0 or multiple.
// The non-variadic case is always one argument,
// but we still use the range to deduplicate code.
argStart := reflectParam.Position
argEnd := argStart + 1
if reflectParam.Variadic {
argEnd = len(call.Args)
}
for _, arg := range call.Args[argStart:argEnd] {
ident, ok := arg.(*ast.Ident)
if !ok {
continue
}
obj, _ := tf.info.Uses[ident].(*types.Var)
if obj == nil {
continue
}
if _, ok := seenReflectParams[obj]; ok {
seenReflectParams[obj] = true
}
}
}

var reflectParams []reflectParameter
for i := 0; i < funcParams.Len(); i++ {
if seenReflectParams[funcParams.At(i)] {
reflectParams = append(reflectParams, reflectParameter{
Position: i,
Variadic: funcType.Variadic() && i == funcParams.Len()-1,
})
}
}
if len(reflectParams) > 0 {
cachedOutput.KnownReflectAPIs[funcObj.FullName()] = reflectParams
}

return true
})
}

lenPrevKnownReflectAPIs := len(cachedOutput.KnownReflectAPIs)
for _, file := range files {
for _, decl := range file.Decls {
if decl, ok := decl.(*ast.FuncDecl); ok {
visitFuncDecl(decl)
}
}
}

// if a new reflectAPI is found we need to Re-evaluate all functions which might be using that API
if len(cachedOutput.KnownReflectAPIs) > lenPrevKnownReflectAPIs {
tf.findReflectFunctions(files)
}
}

// cmd/bundle will include a go:generate directive in its output by default.
// Ours specifies a version and doesn't assume bundle is in $PATH, so drop it.

Expand Down Expand Up @@ -1480,46 +1412,6 @@ func (tf *transformer) prefillObjectMaps(files []*ast.File) error {
}
tf.linkerVariableStrings[obj] = stringValue
})

visit := func(node ast.Node) bool {
call, ok := node.(*ast.CallExpr)
if !ok {
return true
}

ident, ok := call.Fun.(*ast.Ident)
if !ok {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return true
}

ident = sel.Sel
}

fnType, _ := tf.info.Uses[ident].(*types.Func)
if fnType == nil || fnType.Pkg() == nil {
return true
}

fullName := fnType.FullName()
for _, reflectParam := range cachedOutput.KnownReflectAPIs[fullName] {
argStart := reflectParam.Position
argEnd := argStart + 1
if reflectParam.Variadic {
argEnd = len(call.Args)
}
for _, arg := range call.Args[argStart:argEnd] {
argType := tf.info.TypeOf(arg)
tf.recursivelyRecordAsNotObfuscated(argType)
}
}

return true
}
for _, file := range files {
ast.Inspect(file, visit)
}
return nil
}

Expand Down Expand Up @@ -1548,10 +1440,13 @@ type transformer struct {
func newTransformer() *transformer {
return &transformer{
info: &types.Info{
Types: make(map[ast.Expr]types.TypeAndValue),
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
Implicits: make(map[ast.Node]types.Object),
Types: make(map[ast.Expr]types.TypeAndValue),
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
Implicits: make(map[ast.Node]types.Object),
Scopes: make(map[ast.Node]*types.Scope),
Selections: make(map[*ast.SelectorExpr]*types.Selection),
Instances: make(map[*ast.Ident]types.Instance),
},
recordTypeDone: make(map[*types.Named]bool),
fieldToStruct: make(map[*types.Var]*types.Struct),
Expand Down Expand Up @@ -1648,80 +1543,6 @@ func (tf *transformer) recordType(used, origin types.Type) {
}
}

// TODO: consider caching recordedObjectString via a map,
// if that shows an improvement in our benchmark

func recordedObjectString(obj types.Object) objectString {
pkg := obj.Pkg()
if obj, ok := obj.(*types.Var); ok && obj.IsField() {
// For exported fields, "pkgpath.Field" is not unique,
// because two exported top-level types could share "Field".
//
// Moreover, note that not all fields belong to named struct types;
// an API could be exposing:
//
// var usedInReflection = struct{Field string}
//
// For now, a hack: assume that packages don't declare the same field
// more than once in the same line. This works in practice, but one
// could craft Go code to break this assumption.
// Also note that the compiler's object files include filenames and line
// numbers, but not column numbers nor byte offsets.
// TODO(mvdan): give this another think, and add tests involving anon types.
pos := fset.Position(obj.Pos())
return fmt.Sprintf("%s.%s - %s:%d", pkg.Path(), obj.Name(),
filepath.Base(pos.Filename), pos.Line)
}
// Names which are not at the top level cannot be imported,
// so we don't need to record them either.
// Note that this doesn't apply to fields, which are never top-level.
if pkg.Scope() != obj.Parent() {
return ""
}
// For top-level exported names, "pkgpath.Name" is unique.
return pkg.Path() + "." + obj.Name()
}

// recordAsNotObfuscated records all the objects whose names we cannot obfuscate.
// An object is any named entity, such as a declared variable or type.
//
// As of June 2022, this only records types which are used in reflection.
// TODO(mvdan): If this is still the case in a year's time,
// we should probably rename "not obfuscated" and "cannot obfuscate" to be
// directly about reflection, e.g. "used in reflection".
func recordAsNotObfuscated(obj types.Object) {
if obj.Pkg().Path() != curPkg.ImportPath {
panic("called recordedAsNotObfuscated with a foreign object")
}
if !obj.Exported() {
// Unexported names will never be used by other packages,
// so we don't need to bother recording them in cachedOutput.
knownCannotObfuscateUnexported[obj] = true
return
}

objStr := recordedObjectString(obj)
if objStr == "" {
// If the object can't be described via a qualified string,
// then other packages can't use it.
// TODO: should we still record it in knownCannotObfuscateUnexported?
return
}
cachedOutput.KnownCannotObfuscate[objStr] = struct{}{}
}

func recordedAsNotObfuscated(obj types.Object) bool {
if knownCannotObfuscateUnexported[obj] {
return true
}
objStr := recordedObjectString(obj)
if objStr == "" {
return false
}
_, ok := cachedOutput.KnownCannotObfuscate[objStr]
return ok
}

// isSafeForInstanceType returns true if the passed type is safe for var declaration.
// Unsafe types: generic types and non-method interfaces.
func isSafeForInstanceType(typ types.Type) bool {
Expand Down Expand Up @@ -2058,50 +1879,6 @@ func (tf *transformer) transformGoFile(file *ast.File) *ast.File {
return astutil.Apply(file, pre, post).(*ast.File)
}

// recursivelyRecordAsNotObfuscated calls recordAsNotObfuscated on any named
// types and fields under typ.
//
// Only the names declared in the current package are recorded. This is to ensure
// that reflection detection only happens within the package declaring a type.
// Detecting it in downstream packages could result in inconsistencies.
func (tf *transformer) recursivelyRecordAsNotObfuscated(t types.Type) {
switch t := t.(type) {
case *types.Named:
obj := t.Obj()
if pkg := obj.Pkg(); pkg == nil || pkg != tf.pkg {
return // not from the specified package
}
if recordedAsNotObfuscated(obj) {
return // prevent endless recursion
}
recordAsNotObfuscated(obj)

// Record the underlying type, too.
tf.recursivelyRecordAsNotObfuscated(t.Underlying())

case *types.Struct:
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)

// This check is similar to the one in *types.Named.
// It's necessary for unnamed struct types,
// as they aren't named but still have named fields.
if field.Pkg() == nil || field.Pkg() != tf.pkg {
return // not from the specified package
}

// Record the field itself, too.
recordAsNotObfuscated(field)

tf.recursivelyRecordAsNotObfuscated(field.Type())
}

case interface{ Elem() types.Type }:
// Get past pointers, slices, etc.
tf.recursivelyRecordAsNotObfuscated(t.Elem())
}
}

// named tries to obtain the *types.Named behind a type, if there is one.
// This is useful to obtain "testing.T" from "*testing.T", or to obtain the type
// declaration object from an embedded field.
Expand Down
Loading

0 comments on commit 299918a

Please sign in to comment.