Skip to content

Commit

Permalink
code suggestion (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
notJoon authored Jul 18, 2024
1 parent 45363f2 commit 96b6b5d
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 177 deletions.
9 changes: 9 additions & 0 deletions formatter/fmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,15 @@ func TestFormatIssuesWithArrows_UnnecessaryElse(t *testing.T) {
| ~~~~~~~~~~~~~~~~~~~~
| unnecessary else block
Suggestion:
4 | if condition {
5 | return true
6 | }
7 | return false
Note: Unnecessary 'else' block removed.
The code inside the 'else' block has been moved outside, as it will only be executed when the 'if' condition is false.
`

result := FormatIssuesWithArrows(issues, sourceCode)
Expand Down
11 changes: 6 additions & 5 deletions formatter/general.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
const tabWidth = 8

var (
errorStyle = color.New(color.FgRed, color.Bold)
ruleStyle = color.New(color.FgYellow, color.Bold)
fileStyle = color.New(color.FgCyan, color.Bold)
lineStyle = color.New(color.FgBlue, color.Bold)
messageStyle = color.New(color.FgRed, color.Bold)
errorStyle = color.New(color.FgRed, color.Bold)
ruleStyle = color.New(color.FgYellow, color.Bold)
fileStyle = color.New(color.FgCyan, color.Bold)
lineStyle = color.New(color.FgBlue, color.Bold)
messageStyle = color.New(color.FgRed, color.Bold)
suggestionStyle = color.New(color.FgGreen, color.Bold)
)

// GeneralIssueFormatter is a formatter for general lint issues.
Expand Down
60 changes: 51 additions & 9 deletions formatter/unnecessary_else.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@ func (f *UnnecessaryElseFormatter) Format(
) string {
var result strings.Builder
ifStartLine, elseEndLine := issue.Start.Line-2, issue.End.Line
maxLineNumberStr := fmt.Sprintf("%d", elseEndLine)
padding := strings.Repeat(" ", len(maxLineNumberStr)-1)

code := strings.Join(snippet.Lines, "\n")
problemSnippet := internal.ExtractSnippet(issue, code, ifStartLine-1, elseEndLine-1)
suggestion, err := internal.RemoveUnnecessaryElse(problemSnippet)
if err != nil {
suggestion = problemSnippet
}

maxLineNumWidth := calculateMaxLineNumWidth(elseEndLine)
padding := strings.Repeat(" ", maxLineNumWidth-1)
result.WriteString(lineStyle.Sprintf(" %s|\n", padding))

maxLen := 0
maxLen := calculateMaxLineLength(snippet.Lines, ifStartLine, elseEndLine)
for i := ifStartLine; i <= elseEndLine; i++ {
if len(snippet.Lines[i-1]) > maxLen {
maxLen = len(snippet.Lines[i-1])
}
line := expandTabs(snippet.Lines[i-1])
lineNumberStr := fmt.Sprintf("%d", i)
linePadding := strings.Repeat(" ", len(maxLineNumberStr)-len(lineNumberStr))
result.WriteString(lineStyle.Sprintf("%s%s | ", linePadding, lineNumberStr))
lineNumberStr := fmt.Sprintf("%*d", maxLineNumWidth, i)
result.WriteString(lineStyle.Sprintf("%s | ", lineNumberStr))
result.WriteString(line + "\n")
}

Expand All @@ -38,5 +41,44 @@ func (f *UnnecessaryElseFormatter) Format(
result.WriteString(lineStyle.Sprintf(" %s| ", padding))
result.WriteString(messageStyle.Sprintf("%s\n\n", issue.Message))

result.WriteString(formatSuggestion(issue, suggestion, ifStartLine))
result.WriteString("\n")

return result.String()
}

func calculateMaxLineNumWidth(endLine int) int {
return len(fmt.Sprintf("%d", endLine))
}

func calculateMaxLineLength(lines []string, start, end int) int {
maxLen := 0
for i := start - 1; i < end; i++ {
if len(lines[i]) > maxLen {
maxLen = len(lines[i])
}
}
return maxLen
}

func formatSuggestion(issue internal.Issue, improvedSnippet string, startLine int) string {
var result strings.Builder
lines := strings.Split(improvedSnippet, "\n")
maxLineNumWidth := calculateMaxLineNumWidth(issue.End.Line)

result.WriteString(suggestionStyle.Sprint("Suggestion:\n"))

for i, line := range lines {
lineNum := fmt.Sprintf("%*d", maxLineNumWidth, startLine+i)
result.WriteString(lineStyle.Sprintf("%s | ", lineNum))
result.WriteString(fmt.Sprintln(line))
}

// Add a note explaining the improvement
result.WriteString("\n")
result.WriteString(suggestionStyle.Sprint("Note: "))
result.WriteString("Unnecessary 'else' block removed.\n")
result.WriteString("The code inside the 'else' block has been moved outside, as it will only be executed when the 'if' condition is false.\n")

return result.String()
}
157 changes: 115 additions & 42 deletions internal/fixer.go
Original file line number Diff line number Diff line change
@@ -1,81 +1,154 @@
package internal

import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"strings"
)

// TODO: Must flattening the nested unnecessary if-else blocks.
func RemoveUnnecessaryElse(snippet string) (string, error) {
wrappedSnippet := "package main\nfunc main() {\n" + snippet + "\n}"

// improveCode refactors the input source code and returns the formatted version.
func improveCode(src []byte) (string, error) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
file, err := parser.ParseFile(fset, "", wrappedSnippet, parser.ParseComments)
if err != nil {
return "", err
}

err = refactorAST(file)
var funcBody *ast.BlockStmt
ast.Inspect(file, func(n ast.Node) bool {
if fd, ok := n.(*ast.FuncDecl); ok {
funcBody = fd.Body
return false
}
return true
})

removeUnnecessaryElseRecursive(funcBody)

var buf strings.Builder
err = format.Node(&buf, fset, funcBody)
if err != nil {
return "", err
}

return formatSource(fset, file)
result := cleanUpResult(buf.String())

return result, nil
}

// refactorAST processes the AST to modify specific patterns.
func refactorAST(file *ast.File) error {
ast.Inspect(file, func(n ast.Node) bool {
ifStmt, ok := n.(*ast.IfStmt)
if !ok || ifStmt.Else == nil {
return true
}
func cleanUpResult(result string) string {
result = strings.TrimSpace(result)
result = strings.TrimPrefix(result, "{")
result = strings.TrimSuffix(result, "}")
result = strings.TrimSpace(result)

lines := strings.Split(result, "\n")
for i, line := range lines {
lines[i] = strings.TrimPrefix(line, "\t")
}
return strings.Join(lines, "\n")
}

blockStmt, ok := ifStmt.Else.(*ast.BlockStmt)
if !ok || len(ifStmt.Body.List) == 0 {
return true
func removeUnnecessaryElseRecursive(node ast.Node) {
ast.Inspect(node, func(n ast.Node) bool {
if ifStmt, ok := n.(*ast.IfStmt); ok {
processIfStmt(ifStmt, node)
removeUnnecessaryElseRecursive(ifStmt.Body)
if ifStmt.Else != nil {
removeUnnecessaryElseRecursive(ifStmt.Else)
}
return false
}
return true
})
}

_, isReturn := ifStmt.Body.List[len(ifStmt.Body.List)-1].(*ast.ReturnStmt)
if !isReturn {
return true
func processIfStmt(ifStmt *ast.IfStmt, node ast.Node) {
if ifStmt.Else != nil && endsWithReturn(ifStmt.Body) {
parent := findParentBlockStmt(node, ifStmt)
if parent != nil {
switch elseBody := ifStmt.Else.(type) {
case *ast.BlockStmt:
insertStatementsAfter(parent, ifStmt, elseBody.List)
case *ast.IfStmt:
insertStatementsAfter(parent, ifStmt, []ast.Stmt{elseBody})
}
ifStmt.Else = nil
}
} else if ifStmt.Else != nil {
if elseIfStmt, ok := ifStmt.Else.(*ast.IfStmt); ok && endsWithReturn(elseIfStmt.Body) {
processIfStmt(elseIfStmt, ifStmt)
}
}
}

mergeElseIntoIf(file, ifStmt, blockStmt)
ifStmt.Else = nil
func endsWithReturn(block *ast.BlockStmt) bool {
if len(block.List) == 0 {
return false
}
_, isReturn := block.List[len(block.List)-1].(*ast.ReturnStmt)
return isReturn
}

func findParentBlockStmt(root ast.Node, child ast.Node) *ast.BlockStmt {
var parent *ast.BlockStmt
ast.Inspect(root, func(n ast.Node) bool {
if n == child {
return false
}
if block, ok := n.(*ast.BlockStmt); ok {
for _, stmt := range block.List {
if stmt == child {
parent = block
return false
}
}
}
return true
})
return nil
return parent
}

// mergeElseIntoIf merges the statements of an 'else' block into the enclosing function body.
func mergeElseIntoIf(file *ast.File, ifStmt *ast.IfStmt, blockStmt *ast.BlockStmt) {
for _, list := range file.Decls {
decl, ok := list.(*ast.FuncDecl)
if !ok {
continue
}
for i, stmt := range decl.Body.List {
if ifStmt != stmt {
continue
}
decl.Body.List = append(decl.Body.List[:i+1], append(blockStmt.List, decl.Body.List[i+1:]...)...)
func insertStatementsAfter(block *ast.BlockStmt, target ast.Stmt, stmts []ast.Stmt) {
for i, stmt := range block.List {
if stmt == target {
block.List = append(block.List[:i+1], append(stmts, block.List[i+1:]...)...)
break
}
}
}

// formatSource formats the AST back to source code.
func formatSource(fset *token.FileSet, file *ast.File) (string, error) {
var buf bytes.Buffer
err := format.Node(&buf, fset, file)
if err != nil {
return "", err
func ExtractSnippet(issue Issue, code string, startLine, endLine int) string {
lines := strings.Split(code, "\n")

// ensure we don't go out of bounds
if startLine < 0 {
startLine = 0
}
return strings.TrimRight(buf.String(), "\n"), nil
if endLine > len(lines) {
endLine = len(lines)
}

// extract the relevant lines
snippet := lines[startLine:endLine]

// trim any leading empty lines
for len(snippet) > 0 && strings.TrimSpace(snippet[0]) == "" {
snippet = snippet[1:]
}

// ensure the last line is included if it's a closing brace
if endLine < len(lines) && strings.TrimSpace(lines[endLine]) == "}" {
snippet = append(snippet, lines[endLine])
}

// trim any trailing empty lines
for len(snippet) > 0 && strings.TrimSpace(snippet[len(snippet)-1]) == "" {
snippet = snippet[:len(snippet)-1]
}

return strings.Join(snippet, "\n")
}
Loading

0 comments on commit 96b6b5d

Please sign in to comment.