Skip to content

Commit

Permalink
gopls: add fill switch cases code action
Browse files Browse the repository at this point in the history
  • Loading branch information
martskins committed Feb 6, 2024
1 parent d077888 commit bfd0297
Show file tree
Hide file tree
Showing 7 changed files with 584 additions and 0 deletions.
350 changes: 350 additions & 0 deletions gopls/internal/analysis/fillswitch/fillswitch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,350 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package fillswitch defines an Analyzer that automatically
// fills the missing cases in type switches or switches over named types.
//
// The analyzer's diagnostic is merely a prompt.
// The actual fix is created by a separate direct call from gopls to
// the SuggestedFixes function.
// Tests of Analyzer.Run can be found in ./testdata/src.
// Tests of the SuggestedFixes logic live in ../../testdata/fillswitch.
package fillswitch

import (
"bytes"
"context"
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"slices"
"strings"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/gopls/internal/cache"
"golang.org/x/tools/gopls/internal/cache/parsego"
)

const FixCategory = "fillswitch" // recognized by gopls ApplyFix

// errNoSuggestedFix is returned when no suggested fix is available. This could
// be because all cases are already covered, or (in the case of a type switch)
// because the remaining cases are for types not accessible by the current
// package.
var errNoSuggestedFix = errors.New("no suggested fix")

// Diagnose computes diagnostics for switch statements with missing cases
// overlapping with the provided start and end position.
//
// The diagnostic contains a lazy fix; the actual patch is computed
// (via the ApplyFix command) by a call to [SuggestedFix].
//
// If either start or end is invalid, the entire package is inspected.
func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic {
var diags []analysis.Diagnostic
nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)}
inspect.Preorder(nodeFilter, func(n ast.Node) {
if expr, ok := n.(*ast.SwitchStmt); ok {
if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) {
return // non-overlapping
}

if defaultHandled(expr.Body) {
return
}

namedType, err := namedTypeFromSwitch(expr, info)
if err != nil {
return
}

if _, err := suggestedFixSwitch(expr, pkg, info); err != nil {
return
}

diags = append(diags, analysis.Diagnostic{
Message: "Switch has missing cases",
Pos: expr.Pos(),
End: expr.End(),
Category: FixCategory,
SuggestedFixes: []analysis.SuggestedFix{{
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
// No TextEdits => computed later by gopls.
}},
})
}

if expr, ok := n.(*ast.TypeSwitchStmt); ok {
if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) {
return // non-overlapping
}

if defaultHandled(expr.Body) {
return
}

namedType, err := namedTypeFromTypeSwitch(expr, info)
if err != nil {
return
}

if _, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil {
return
}

diags = append(diags, analysis.Diagnostic{
Message: "Switch has missing cases",
Pos: expr.Pos(),
End: expr.End(),
Category: FixCategory,
SuggestedFixes: []analysis.SuggestedFix{{
Message: fmt.Sprintf("Add cases for %v", namedType.Obj().Name()),
// No TextEdits => computed later by gopls.
}},
})
}
})

return diags
}

func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
namedType, err := namedTypeFromTypeSwitch(stmt, info)
if err != nil {
return nil, err
}

scope := namedType.Obj().Pkg().Scope()
variants := make([]string, 0)
for _, name := range scope.Names() {
obj := scope.Lookup(name)
if _, ok := obj.(*types.TypeName); !ok {
continue
}

if types.Identical(obj.Type(), namedType.Obj().Type()) {
continue
}

if types.AssignableTo(obj.Type(), namedType.Obj().Type()) {
if obj.Pkg().Name() != pkg.Name() {
if !obj.Exported() {
continue
}

variants = append(variants, obj.Pkg().Name()+"."+obj.Name())
} else {
variants = append(variants, obj.Name())
}
} else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) {
if obj.Pkg().Name() != pkg.Name() {
if !obj.Exported() {
continue
}

variants = append(variants, "*"+obj.Pkg().Name()+"."+obj.Name())
} else {
variants = append(variants, "*"+obj.Name())
}
}
}

handledVariants := getHandledVariants(stmt.Body)
if len(variants) == 0 || len(variants) == len(handledVariants) {
return nil, errNoSuggestedFix
}

newText := buildNewText(variants, handledVariants)
return &analysis.SuggestedFix{
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - 1,
End: stmt.End() - 1,
NewText: indent([]byte(newText), []byte{'\t'}),
}},
}, nil
}

func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
namedType, err := namedTypeFromSwitch(stmt, info)
if err != nil {
return nil, err
}

scope := namedType.Obj().Pkg().Scope()
variants := make([]string, 0)
for _, name := range scope.Names() {
obj := scope.Lookup(name)
if obj.Id() == namedType.Obj().Id() {
continue
}

if types.Identical(obj.Type(), namedType.Obj().Type()) {
// TODO: comparing the package name like this feels wrong, is it?
if obj.Pkg().Name() != pkg.Name() {
if !obj.Exported() {
continue
}

variants = append(variants, obj.Pkg().Name()+"."+obj.Name())
} else {
variants = append(variants, obj.Name())
}
}
}

handledVariants := getHandledVariants(stmt.Body)
if len(variants) == 0 || len(variants) == len(handledVariants) {
return nil, errNoSuggestedFix
}

newText := buildNewText(variants, handledVariants)
return &analysis.SuggestedFix{
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - 1,
End: stmt.End() - 1,
NewText: indent([]byte(newText), []byte{'\t'}),
}},
}, nil
}

func namedTypeFromSwitch(stmt *ast.SwitchStmt, info *types.Info) (*types.Named, error) {
typ := info.TypeOf(stmt.Tag)
if typ == nil {
return nil, errors.New("expected switch statement to have a tag")
}

namedType, ok := typ.(*types.Named)
if !ok {
return nil, errors.New("switch statement is not on a named type")
}

return namedType, nil
}

func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) (*types.Named, error) {
switch s := stmt.Assign.(type) {
case *ast.ExprStmt:
typ := s.X.(*ast.TypeAssertExpr)
namedType, ok := info.TypeOf(typ.X).(*types.Named)
if !ok {
return nil, errors.New("type switch expression is not on a named type")
}

return namedType, nil
case *ast.AssignStmt:
for _, expr := range s.Rhs {
typ, ok := expr.(*ast.TypeAssertExpr)
if !ok {
continue
}

namedType, ok := info.TypeOf(typ.X).(*types.Named)
if !ok {
continue
}

return namedType, nil
}

return nil, errors.New("expected type switch expression to have a named type")
default:
return nil, errors.New("node is not a type switch statement")
}
}

func defaultHandled(body *ast.BlockStmt) bool {
for _, bl := range body.List {
if len(bl.(*ast.CaseClause).List) == 0 {
return true
}
}

return false
}

func buildNewText(variants []string, handledVariants []string) string {
var textBuilder strings.Builder
for _, c := range variants {
if slices.Contains(handledVariants, c) {
continue
}

textBuilder.WriteString("case ")
textBuilder.WriteString(c)
textBuilder.WriteString(":\n")
}

return textBuilder.String()
}

func getHandledVariants(body *ast.BlockStmt) []string {
out := make([]string, 0)
for _, bl := range body.List {
for _, c := range bl.(*ast.CaseClause).List {
switch v := c.(type) {
case *ast.Ident:
out = append(out, v.Name)
case *ast.SelectorExpr:
out = append(out, v.X.(*ast.Ident).Name+"."+v.Sel.Name)
case *ast.StarExpr:
out = append(out, "*"+v.X.(*ast.Ident).Name)
}
}
}

return out
}

// SuggestedFix computes the suggested fix for the kinds of
// diagnostics produced by the Analyzer above.
func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
pos := start // don't use the end
path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos)
if len(path) < 2 {
return nil, nil, fmt.Errorf("no expression found")
}

switch stmt := path[0].(type) {
case *ast.SwitchStmt:
fix, err := suggestedFixSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
if err != nil {
return nil, nil, err
}

return pkg.FileSet(), fix, nil
case *ast.TypeSwitchStmt:
fix, err := suggestedFixTypeSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
if err != nil {
return nil, nil, err
}

return pkg.FileSet(), fix, nil
default:
return nil, nil, fmt.Errorf("no switch statement found")
}
}

// indent works line by line through str, prefixing each line with
// prefix.
func indent(str, prefix []byte) []byte {
split := bytes.Split(str, []byte("\n"))
newText := bytes.NewBuffer(nil)
for i, s := range split {
if i != 0 {
newText.Write(prefix)
}

newText.Write(s)
if i < len(split)-1 {
newText.WriteByte('\n')
}
}
return newText.Bytes()
}
38 changes: 38 additions & 0 deletions gopls/internal/analysis/fillswitch/fillswitch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package fillswitch_test

import (
"go/token"
"testing"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/analysistest"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/gopls/internal/analysis/fillswitch"
)

// analyzer allows us to test the fillswitch code action using the analysistest
// harness. (fillswitch used to be a gopls analyzer.)
var analyzer = &analysis.Analyzer{
Name: "fillswitch",
Doc: "test only",
Requires: []*analysis.Analyzer{inspect.Analyzer},
Run: func(pass *analysis.Pass) (any, error) {
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
for _, d := range fillswitch.Diagnose(inspect, token.NoPos, token.NoPos, pass.Pkg, pass.TypesInfo) {
pass.Report(d)
}
return nil, nil
},
URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/fillswitch",
RunDespiteErrors: true,
}

func Test(t *testing.T) {
testdata := analysistest.TestData()
analysistest.Run(t, testdata, analyzer, "a")
}
Loading

0 comments on commit bfd0297

Please sign in to comment.