Skip to content

Commit

Permalink
internal/vulncheck: manipulate packages from PackageGraph
Browse files Browse the repository at this point in the history
Packages obtained by querying PackageGraph are fixed up. Most notably,
this means that stdlib packages will have stdlib module set. We should
then use these packages. We in fact do, at least for the parts that
matter, but this CL tries to refactor code so that is made explicit.

Change-Id: I194b819ec40eba6726a68be7766ec220b80ec2f8
Reviewed-on: https://go-review.googlesource.com/c/vuln/+/564155
Run-TryBot: Zvonimir Pavlinovic <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Maceo Thompson <[email protected]>
TryBot-Result: Gopher Robot <[email protected]>
  • Loading branch information
zpavlinovic committed Feb 15, 2024
1 parent 27078ae commit bb77557
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 122 deletions.
76 changes: 27 additions & 49 deletions internal/scan/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,51 +29,48 @@ func runSource(ctx context.Context, handler govulncheck.Handler, cfg *config, cl
if !gomodExists(dir) {
return errNoGoMod
}
var pkgs []*packages.Package
var mods []*packages.Module
graph := vulncheck.NewPackageGraph(cfg.GoVersion)
pkgConfig := &packages.Config{
Dir: dir,
Tests: cfg.test,
Env: cfg.env,
}
pkgs, mods, err = graph.LoadPackagesAndMods(pkgConfig, cfg.tags, cfg.patterns)
if err != nil {
if err := graph.LoadPackagesAndMods(pkgConfig, cfg.tags, cfg.patterns); err != nil {
if isGoVersionMismatchError(err) {
return fmt.Errorf("%v\n\n%v", errGoVersionMismatch, err)
}
return fmt.Errorf("loading packages: %w", err)
}

if err := handler.Progress(sourceProgressMessage(pkgs, len(mods)-1, cfg.ScanLevel)); err != nil {
if err := handler.Progress(sourceProgressMessage(graph, cfg.ScanLevel)); err != nil {
return err
}

if cfg.ScanLevel.WantPackages() && len(pkgs) == 0 {
if cfg.ScanLevel.WantPackages() && len(graph.TopPkgs()) == 0 {
return nil // early exit
}
return vulncheck.Source(ctx, handler, pkgs, mods, &cfg.Config, client, graph)
return vulncheck.Source(ctx, handler, &cfg.Config, client, graph)
}

// sourceProgressMessage returns a string of the form
//
// "Scanning your code and P packages across M dependent modules for known vulnerabilities..."
//
// P is the number of strictly dependent packages of
// topPkgs and Y is the number of their modules. If P
// is 0, then the following message is returned
// graph.TopPkgs() and Y is the number of their modules.
// If P is 0, then the following message is returned
//
// "No packages matching the provided pattern."
func sourceProgressMessage(topPkgs []*packages.Package, mods int, mode govulncheck.ScanLevel) *govulncheck.Progress {
func sourceProgressMessage(graph *vulncheck.PackageGraph, mode govulncheck.ScanLevel) *govulncheck.Progress {
var pkgsPhrase, modsPhrase string

mods := uniqueAnalyzableMods(graph)
if mode.WantPackages() {
if len(topPkgs) == 0 {
if len(graph.TopPkgs()) == 0 {
// The package pattern is valid, but no packages are matching.
// Example is pkg/strace/... (see #59623).
return &govulncheck.Progress{Message: "No packages matching the provided pattern."}
}
pkgs := depPkgs(topPkgs)
pkgs := len(graph.DepPkgs())
pkgsPhrase = fmt.Sprintf(" and %d package%s", pkgs, choose(pkgs != 1, "s", ""))
}
modsPhrase = fmt.Sprintf(" %d dependent module%s", mods, choose(mods != 1, "s", ""))
Expand All @@ -82,43 +79,24 @@ func sourceProgressMessage(topPkgs []*packages.Package, mods int, mode govulnche
return &govulncheck.Progress{Message: msg}
}

// depPkgs returns the number of packages that topPkgs depend on
func depPkgs(topPkgs []*packages.Package) int {
tops := make(map[string]bool)
depPkgs := make(map[string]bool)

for _, t := range topPkgs {
tops[t.PkgPath] = true
}

var visit func(*packages.Package, bool)
visit = func(p *packages.Package, top bool) {
path := p.PkgPath
if depPkgs[path] {
return
}
if tops[path] && !top {
// A top package that is a dependency
// will not be in depPkgs, so we skip
// reiterating on it here.
return
// uniqueAnalyzableMods returns the number of unique modules
// that are analyzable. Those are basically all modules except
// those that are replaced. The latter won't be analyzed as
// their code is never reachable.
func uniqueAnalyzableMods(graph *vulncheck.PackageGraph) int {
replaced := 0
mods := graph.Modules()
for _, m := range mods {
if m.Replace == nil {
continue
}

// We don't count a top-level package as
// a dependency even when they are used
// as a dependent package.
if !tops[path] {
depPkgs[path] = true
}

for _, d := range p.Imports {
visit(d, false)
if m.Path == m.Replace.Path {
// If the replacing path is the same as
// the one being replaced, then only one
// of these modules is in mods.
continue
}
replaced++
}

for _, t := range topPkgs {
visit(t, true)
}

return len(depPkgs)
return len(mods) - replaced - 1 // don't include stdlib
}
127 changes: 86 additions & 41 deletions internal/vulncheck/packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ import (
)

// PackageGraph holds a complete module and package graph.
// Its primary purpose is to allow fast access to the nodes by path.
// Its primary purpose is to allow fast access to the nodes
// by path and make sure all(stdlib) packages have a module.
type PackageGraph struct {
modules map[string]*packages.Module
packages map[string]*packages.Package
// topPkgs are top-level packages specified by the user.
// Empty in binary mode.
topPkgs []*packages.Package
modules map[string]*packages.Module // all modules (even replacing ones)
packages map[string]*packages.Package // all packages (even dependencies)
}

func NewPackageGraph(goVersion string) *PackageGraph {
Expand All @@ -40,8 +44,69 @@ func NewPackageGraph(goVersion string) *PackageGraph {
return graph
}

func (g *PackageGraph) TopPkgs() []*packages.Package {
return g.topPkgs
}

// DepPkgs returns the number of packages that graph.TopPkgs()
// strictly depend on. This does not include topPkgs even if
// they are dependency of each other.
func (g *PackageGraph) DepPkgs() []*packages.Package {
topPkgs := g.TopPkgs()
tops := make(map[string]bool)
depPkgs := make(map[string]*packages.Package)

for _, t := range topPkgs {
tops[t.PkgPath] = true
}

var visit func(*packages.Package, bool)
visit = func(p *packages.Package, top bool) {
path := p.PkgPath
if _, ok := depPkgs[path]; ok {
return
}
if tops[path] && !top {
// A top package that is a dependency
// will not be in depPkgs, so we skip
// reiterating on it here.
return
}

// We don't count a top-level package as
// a dependency even when they are used
// as a dependent package.
if !tops[path] {
depPkgs[path] = p
}

for _, d := range p.Imports {
visit(d, false)
}
}

for _, t := range topPkgs {
visit(t, true)
}

var deps []*packages.Package
for _, d := range depPkgs {
deps = append(deps, g.GetPackage(d.PkgPath))
}
return deps
}

func (g *PackageGraph) Modules() []*packages.Module {
var mods []*packages.Module
for _, m := range g.modules {
mods = append(mods, m)
}
return mods
}

// AddModules adds the modules and any replace modules provided.
// It will ignore modules that have duplicate paths to ones the graph already holds.
// It will ignore modules that have duplicate paths to ones the
// graph already holds.
func (g *PackageGraph) AddModules(mods ...*packages.Module) {
for _, mod := range mods {
if _, found := g.modules[mod.Path]; found {
Expand All @@ -55,7 +120,8 @@ func (g *PackageGraph) AddModules(mods ...*packages.Module) {
}
}

// .
// GetModule gets module at path if one exists. Otherwise,
// it creates a module and returns it.
func (g *PackageGraph) GetModule(path string) *packages.Module {
if mod, ok := g.modules[path]; ok {
return mod
Expand All @@ -68,8 +134,9 @@ func (g *PackageGraph) GetModule(path string) *packages.Module {
return mod
}

// AddPackages adds the packages and the full graph of imported packages.
// It will ignore packages that have duplicate paths to ones the graph already holds.
// AddPackages adds the packages and their full graph of imported packages.
// It also adds the modules of the added packages. It will ignore packages
// that have duplicate paths to ones the graph already holds.
func (g *PackageGraph) AddPackages(pkgs ...*packages.Package) {
for _, pkg := range pkgs {
if _, found := g.packages[pkg.PkgPath]; found {
Expand All @@ -84,6 +151,9 @@ func (g *PackageGraph) AddPackages(pkgs ...*packages.Package) {
}
}

// fixupPackage adds the module of pkg, if any, to the set
// of all modules in g. If packages is not assigned a module
// (likely stdlib package), a module set for pkg.
func (g *PackageGraph) fixupPackage(pkg *packages.Package) {
if pkg.Module != nil {
g.AddModules(pkg.Module)
Expand Down Expand Up @@ -124,7 +194,7 @@ func (g *PackageGraph) GetPackage(path string) *packages.Package {

// LoadPackages loads the packages specified by the patterns into the graph.
// See golang.org/x/tools/go/packages.Load for details of how it works.
func (g *PackageGraph) LoadPackagesAndMods(cfg *packages.Config, tags []string, patterns []string) ([]*packages.Package, []*packages.Module, error) {
func (g *PackageGraph) LoadPackagesAndMods(cfg *packages.Config, tags []string, patterns []string) error {
if len(tags) > 0 {
cfg.BuildFlags = []string{fmt.Sprintf("-tags=%s", strings.Join(tags, ","))}
}
Expand All @@ -139,7 +209,7 @@ func (g *PackageGraph) LoadPackagesAndMods(cfg *packages.Config, tags []string,

pkgs, err := packages.Load(cfg, patterns...)
if err != nil {
return nil, nil, err
return err
}
var perrs []packages.Error
packages.Visit(pkgs, nil, func(p *packages.Package) {
Expand All @@ -148,41 +218,16 @@ func (g *PackageGraph) LoadPackagesAndMods(cfg *packages.Config, tags []string,
if len(perrs) > 0 {
err = &packageError{perrs}
}
g.AddPackages(pkgs...)
return pkgs, extractModules(pkgs), err
}

// extractModules collects modules in `pkgs` up to uniqueness of
// module path and version.
func extractModules(pkgs []*packages.Package) []*packages.Module {
modMap := map[string]*packages.Module{}
seen := map[*packages.Package]bool{}
var extract func(*packages.Package, map[string]*packages.Module)
extract = func(pkg *packages.Package, modMap map[string]*packages.Module) {
if pkg == nil || seen[pkg] {
return
}
if pkg.Module != nil {
if pkg.Module.Replace != nil {
modMap[pkg.Module.Replace.Path] = pkg.Module
} else {
modMap[pkg.Module.Path] = pkg.Module
}
}
seen[pkg] = true
for _, imp := range pkg.Imports {
extract(imp, modMap)
}
}
for _, pkg := range pkgs {
extract(pkg, modMap)
}
// Add all packages, top-level ones and their imports.
// This will also add their respective modules.
g.AddPackages(pkgs...)

modules := []*packages.Module{}
for _, mod := range modMap {
modules = append(modules, mod)
// save top-level packages
for _, p := range pkgs {
g.topPkgs = append(g.topPkgs, g.GetPackage(p.PkgPath))
}
return modules
return err
}

// packageError contains errors from loading a set of packages.
Expand Down
4 changes: 2 additions & 2 deletions internal/vulncheck/slicing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ func Do(i I, input string) {
})

graph := NewPackageGraph("go1.18")
pkgs, _, err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "/module/slice")})
err := graph.LoadPackagesAndMods(e.Config, nil, []string{path.Join(e.Temp(), "/module/slice")})
if err != nil {
t.Fatal(err)
}
prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
prog, ssaPkgs := ssautil.AllPackages(graph.TopPkgs(), 0)
prog.Build()

pkg := ssaPkgs[0]
Expand Down
Loading

0 comments on commit bb77557

Please sign in to comment.