Skip to content

Commit

Permalink
Merge branch 'master' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
matryer authored Mar 7, 2018
2 parents 1b1675c + 9e83319 commit ee06b7e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 222 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package main

import (
"bytes"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"errors"

"github.com/matryer/moq/pkg/moq"
)
Expand Down
186 changes: 0 additions & 186 deletions pkg/moq/importer.go

This file was deleted.

129 changes: 94 additions & 35 deletions pkg/moq/moq.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ import (
"go/types"
"io"
"os"
"path"
"path/filepath"
"strings"
"text/template"

"golang.org/x/tools/go/loader"
)

// This list comes from the golint codebase. Golint will complain about any of
Expand Down Expand Up @@ -75,11 +79,13 @@ func New(src, packageName string) (*Mocker, error) {
noTestFiles := func(i os.FileInfo) bool {
return !strings.HasSuffix(i.Name(), "_test.go")
}

pkgs, err := parser.ParseDir(fset, src, noTestFiles, parser.SpuriousErrors)
if err != nil {
return nil, err
}
if len(packageName) == 0 {

for pkgName := range pkgs {
if strings.Contains(pkgName, "_test") {
continue
Expand Down Expand Up @@ -110,57 +116,56 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error {
if len(name) == 0 {
return errors.New("must specify one interface")
}

pkgInfo, err := m.pkgInfoFromPath(m.src)
if err != nil {
return err
}

doc := doc{
PackageName: m.pkgName,
Imports: moqImports,
}

mocksMethods := false
for _, pkg := range m.pkgs {
i := 0
files := make([]*ast.File, len(pkg.Files))
for _, file := range pkg.Files {
files[i] = file
i++

tpkg := pkgInfo.Pkg
for _, n := range name {
iface := tpkg.Scope().Lookup(n)
if iface == nil {
return fmt.Errorf("cannot find interface %s", n)
}
conf := types.Config{Importer: newImporter(m.src)}
tpkg, err := conf.Check(m.src, m.fset, files, nil)
if err != nil {
return err
if !types.IsInterface(iface.Type()) {
return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String())
}
for _, n := range name {
iface := tpkg.Scope().Lookup(n)
if iface == nil {
return fmt.Errorf("cannot find interface %s", n)
}
if !types.IsInterface(iface.Type()) {
return fmt.Errorf("%s (%s) not an interface", n, iface.Type().String())
}
iiface := iface.Type().Underlying().(*types.Interface).Complete()
obj := obj{
InterfaceName: n,
}
for i := 0; i < iiface.NumMethods(); i++ {
mocksMethods = true
meth := iiface.Method(i)
sig := meth.Type().(*types.Signature)
method := &method{
Name: meth.Name(),
}
obj.Methods = append(obj.Methods, method)
method.Params = m.extractArgs(sig, sig.Params(), "in%d")
method.Returns = m.extractArgs(sig, sig.Results(), "out%d")
iiface := iface.Type().Underlying().(*types.Interface).Complete()
obj := obj{
InterfaceName: n,
}
for i := 0; i < iiface.NumMethods(); i++ {
mocksMethods = true
meth := iiface.Method(i)
sig := meth.Type().(*types.Signature)
method := &method{
Name: meth.Name(),
}
doc.Objects = append(doc.Objects, obj)
obj.Methods = append(obj.Methods, method)
method.Params = m.extractArgs(sig, sig.Params(), "in%d")
method.Returns = m.extractArgs(sig, sig.Results(), "out%d")
}
doc.Objects = append(doc.Objects, obj)
}

if mocksMethods {
doc.Imports = append(doc.Imports, "sync")
}

for pkgToImport := range m.imports {
doc.Imports = append(doc.Imports, pkgToImport)
doc.Imports = append(doc.Imports, stripVendorPath(pkgToImport))
}

var buf bytes.Buffer
err := m.tmpl.Execute(&buf, doc)
err = m.tmpl.Execute(&buf, doc)
if err != nil {
return err
}
Expand Down Expand Up @@ -211,6 +216,32 @@ func (m *Mocker) extractArgs(sig *types.Signature, list *types.Tuple, nameFormat
return params
}

func (*Mocker) pkgInfoFromPath(src string) (*loader.PackageInfo, error) {

abs, err := filepath.Abs(src)
if err != nil {
return nil, err
}
pkgFull := stripGopath(abs)

conf := loader.Config{
ParserMode: parser.SpuriousErrors,
Cwd: src,
}
conf.Import(pkgFull)
lprog, err := conf.Load()
if err != nil {
return nil, err
}

pkgInfo := lprog.Package(pkgFull)
if pkgInfo == nil {
return nil, errors.New("package was nil")
}

return pkgInfo, nil
}

type doc struct {
PackageName string
Objects []obj
Expand Down Expand Up @@ -291,3 +322,31 @@ var templateFuncs = template.FuncMap{
return strings.ToUpper(s[0:1]) + s[1:]
},
}

// stripVendorPath strips the vendor dir prefix from a package path.
// For example we might encounter an absolute path like
// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved
// to github.com/pkg/errors.
func stripVendorPath(p string) string {
parts := strings.Split(p, "/vendor/")
if len(parts) == 1 {
return p
}
return strings.TrimLeft(path.Join(parts[1:]...), "/")
}

// stripGopath takes the directory to a package and remove the gopath to get the
// canonical package name.
//
// taken from https://github.com/ernesto-jimenez/gogen
// Copyright (c) 2015 Ernesto Jiménez
func stripGopath(p string) string {
for _, gopath := range gopaths() {
p = strings.TrimPrefix(p, path.Join(gopath, "src")+"/")
}
return p
}

func gopaths() []string {
return strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator))
}

0 comments on commit ee06b7e

Please sign in to comment.