-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
224 additions
and
177 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,154 @@ | ||
package internal | ||
|
||
import ( | ||
"bytes" | ||
"go/ast" | ||
"go/format" | ||
"go/parser" | ||
"go/token" | ||
"strings" | ||
) | ||
|
||
// TODO: Must flattening the nested unnecessary if-else blocks. | ||
func RemoveUnnecessaryElse(snippet string) (string, error) { | ||
wrappedSnippet := "package main\nfunc main() {\n" + snippet + "\n}" | ||
|
||
// improveCode refactors the input source code and returns the formatted version. | ||
func improveCode(src []byte) (string, error) { | ||
fset := token.NewFileSet() | ||
file, err := parser.ParseFile(fset, "", src, parser.ParseComments) | ||
file, err := parser.ParseFile(fset, "", wrappedSnippet, parser.ParseComments) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
err = refactorAST(file) | ||
var funcBody *ast.BlockStmt | ||
ast.Inspect(file, func(n ast.Node) bool { | ||
if fd, ok := n.(*ast.FuncDecl); ok { | ||
funcBody = fd.Body | ||
return false | ||
} | ||
return true | ||
}) | ||
|
||
removeUnnecessaryElseRecursive(funcBody) | ||
|
||
var buf strings.Builder | ||
err = format.Node(&buf, fset, funcBody) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
return formatSource(fset, file) | ||
result := cleanUpResult(buf.String()) | ||
|
||
return result, nil | ||
} | ||
|
||
// refactorAST processes the AST to modify specific patterns. | ||
func refactorAST(file *ast.File) error { | ||
ast.Inspect(file, func(n ast.Node) bool { | ||
ifStmt, ok := n.(*ast.IfStmt) | ||
if !ok || ifStmt.Else == nil { | ||
return true | ||
} | ||
func cleanUpResult(result string) string { | ||
result = strings.TrimSpace(result) | ||
result = strings.TrimPrefix(result, "{") | ||
result = strings.TrimSuffix(result, "}") | ||
result = strings.TrimSpace(result) | ||
|
||
lines := strings.Split(result, "\n") | ||
for i, line := range lines { | ||
lines[i] = strings.TrimPrefix(line, "\t") | ||
} | ||
return strings.Join(lines, "\n") | ||
} | ||
|
||
blockStmt, ok := ifStmt.Else.(*ast.BlockStmt) | ||
if !ok || len(ifStmt.Body.List) == 0 { | ||
return true | ||
func removeUnnecessaryElseRecursive(node ast.Node) { | ||
ast.Inspect(node, func(n ast.Node) bool { | ||
if ifStmt, ok := n.(*ast.IfStmt); ok { | ||
processIfStmt(ifStmt, node) | ||
removeUnnecessaryElseRecursive(ifStmt.Body) | ||
if ifStmt.Else != nil { | ||
removeUnnecessaryElseRecursive(ifStmt.Else) | ||
} | ||
return false | ||
} | ||
return true | ||
}) | ||
} | ||
|
||
_, isReturn := ifStmt.Body.List[len(ifStmt.Body.List)-1].(*ast.ReturnStmt) | ||
if !isReturn { | ||
return true | ||
func processIfStmt(ifStmt *ast.IfStmt, node ast.Node) { | ||
if ifStmt.Else != nil && endsWithReturn(ifStmt.Body) { | ||
parent := findParentBlockStmt(node, ifStmt) | ||
if parent != nil { | ||
switch elseBody := ifStmt.Else.(type) { | ||
case *ast.BlockStmt: | ||
insertStatementsAfter(parent, ifStmt, elseBody.List) | ||
case *ast.IfStmt: | ||
insertStatementsAfter(parent, ifStmt, []ast.Stmt{elseBody}) | ||
} | ||
ifStmt.Else = nil | ||
} | ||
} else if ifStmt.Else != nil { | ||
if elseIfStmt, ok := ifStmt.Else.(*ast.IfStmt); ok && endsWithReturn(elseIfStmt.Body) { | ||
processIfStmt(elseIfStmt, ifStmt) | ||
} | ||
} | ||
} | ||
|
||
mergeElseIntoIf(file, ifStmt, blockStmt) | ||
ifStmt.Else = nil | ||
func endsWithReturn(block *ast.BlockStmt) bool { | ||
if len(block.List) == 0 { | ||
return false | ||
} | ||
_, isReturn := block.List[len(block.List)-1].(*ast.ReturnStmt) | ||
return isReturn | ||
} | ||
|
||
func findParentBlockStmt(root ast.Node, child ast.Node) *ast.BlockStmt { | ||
var parent *ast.BlockStmt | ||
ast.Inspect(root, func(n ast.Node) bool { | ||
if n == child { | ||
return false | ||
} | ||
if block, ok := n.(*ast.BlockStmt); ok { | ||
for _, stmt := range block.List { | ||
if stmt == child { | ||
parent = block | ||
return false | ||
} | ||
} | ||
} | ||
return true | ||
}) | ||
return nil | ||
return parent | ||
} | ||
|
||
// mergeElseIntoIf merges the statements of an 'else' block into the enclosing function body. | ||
func mergeElseIntoIf(file *ast.File, ifStmt *ast.IfStmt, blockStmt *ast.BlockStmt) { | ||
for _, list := range file.Decls { | ||
decl, ok := list.(*ast.FuncDecl) | ||
if !ok { | ||
continue | ||
} | ||
for i, stmt := range decl.Body.List { | ||
if ifStmt != stmt { | ||
continue | ||
} | ||
decl.Body.List = append(decl.Body.List[:i+1], append(blockStmt.List, decl.Body.List[i+1:]...)...) | ||
func insertStatementsAfter(block *ast.BlockStmt, target ast.Stmt, stmts []ast.Stmt) { | ||
for i, stmt := range block.List { | ||
if stmt == target { | ||
block.List = append(block.List[:i+1], append(stmts, block.List[i+1:]...)...) | ||
break | ||
} | ||
} | ||
} | ||
|
||
// formatSource formats the AST back to source code. | ||
func formatSource(fset *token.FileSet, file *ast.File) (string, error) { | ||
var buf bytes.Buffer | ||
err := format.Node(&buf, fset, file) | ||
if err != nil { | ||
return "", err | ||
func ExtractSnippet(issue Issue, code string, startLine, endLine int) string { | ||
lines := strings.Split(code, "\n") | ||
|
||
// ensure we don't go out of bounds | ||
if startLine < 0 { | ||
startLine = 0 | ||
} | ||
return strings.TrimRight(buf.String(), "\n"), nil | ||
if endLine > len(lines) { | ||
endLine = len(lines) | ||
} | ||
|
||
// extract the relevant lines | ||
snippet := lines[startLine:endLine] | ||
|
||
// trim any leading empty lines | ||
for len(snippet) > 0 && strings.TrimSpace(snippet[0]) == "" { | ||
snippet = snippet[1:] | ||
} | ||
|
||
// ensure the last line is included if it's a closing brace | ||
if endLine < len(lines) && strings.TrimSpace(lines[endLine]) == "}" { | ||
snippet = append(snippet, lines[endLine]) | ||
} | ||
|
||
// trim any trailing empty lines | ||
for len(snippet) > 0 && strings.TrimSpace(snippet[len(snippet)-1]) == "" { | ||
snippet = snippet[:len(snippet)-1] | ||
} | ||
|
||
return strings.Join(snippet, "\n") | ||
} |
Oops, something went wrong.