Skip to content

Commit

Permalink
feat: organise go imports on format (#793)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Hesketh <[email protected]>
  • Loading branch information
joerdav and a-h authored Jun 22, 2024
1 parent 8d27ad1 commit 69bfdb1
Show file tree
Hide file tree
Showing 84 changed files with 1,052 additions and 1,079 deletions.
2 changes: 1 addition & 1 deletion .version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.724
0.2.725
14 changes: 5 additions & 9 deletions benchmarks/templ/template_templ.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 13 additions & 6 deletions cmd/templ/fmtcmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@ import (
"sync"
"time"

imports "github.com/a-h/templ/cmd/templ/import"
"github.com/a-h/templ/cmd/templ/processor"
parser "github.com/a-h/templ/parser/v2"
"github.com/natefinch/atomic"
)

type Arguments struct {
ToStdout bool
Files []string
WorkerCount int
ToStdout bool
StdinFilepath string
Files []string
WorkerCount int
}

func Run(log *slog.Logger, stdin io.Reader, stdout io.Writer, args Arguments) (err error) {
// If no files are provided, read from stdin and write to stdout.
if len(args.Files) == 0 {
return format(writeToWriter(stdout), readFromReader(stdin))
return format(writeToWriter(stdout), readFromReader(stdin, args.StdinFilepath))
}
process := func(fileName string) error {
read := readFromFile(fileName)
Expand Down Expand Up @@ -82,13 +84,13 @@ func (f *Formatter) Run() (err error) {

type reader func() (fileName, src string, err error)

func readFromReader(r io.Reader) func() (fileName, src string, err error) {
func readFromReader(r io.Reader, stdinFilepath string) func() (fileName, src string, err error) {
return func() (fileName, src string, err error) {
b, err := io.ReadAll(r)
if err != nil {
return "", "", fmt.Errorf("failed to read stdin: %w", err)
}
return "stdin.templ", string(b), nil
return stdinFilepath, string(b), nil
}
}

Expand Down Expand Up @@ -128,6 +130,11 @@ func format(write writer, read reader) (err error) {
if err != nil {
return err
}
t.Filepath = fileName
t, err = imports.Process(t)
if err != nil {
return err
}
w := new(bytes.Buffer)
if err = t.Write(w); err != nil {
return fmt.Errorf("formatting error: %w", err)
Expand Down
4 changes: 2 additions & 2 deletions cmd/templ/fmtcmd/testdata.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
package test

templ a() {
<div><p>A
<div><p class={templ.Class("mapped")}>A
</p></div>
}
-- a.templ --
package test

templ a() {
<div>
<p>
<p class={ templ.Class("mapped") }>
A
</p>
</div>
Expand Down
14 changes: 5 additions & 9 deletions cmd/templ/generatecmd/testwatch/testdata/templates_templ.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

147 changes: 147 additions & 0 deletions cmd/templ/import/process.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package imports

import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/token"
"log"
"path"
"slices"
"strings"

goparser "go/parser"

"golang.org/x/sync/errgroup"
"golang.org/x/tools/imports"

"github.com/a-h/templ/generator"
"github.com/a-h/templ/parser/v2"
)

var internalImports = []string{"github.com/a-h/templ", "github.com/a-h/templ/runtime"}

func convertTemplToGoURI(templURI string) (isTemplFile bool, goURI string) {
base, fileName := path.Split(templURI)
if !strings.HasSuffix(fileName, ".templ") {
return
}
return true, base + (strings.TrimSuffix(fileName, ".templ") + "_templ.go")
}

var fset = token.NewFileSet()

func updateImports(name, src string) (updated []*ast.ImportSpec, err error) {
// Apply auto imports.
updatedGoCode, err := imports.Process(name, []byte(src), nil)
if err != nil {
return updated, fmt.Errorf("failed to process go code %q: %w", src, err)
}
// Get updated imports.
gofile, err := goparser.ParseFile(fset, name, updatedGoCode, goparser.ImportsOnly)
if err != nil {
return updated, fmt.Errorf("failed to get imports from updated go code: %w", err)
}
for _, imp := range gofile.Imports {
if !slices.Contains(internalImports, strings.Trim(imp.Path.Value, "\"")) {
updated = append(updated, imp)
}
}
return updated, nil
}

func Process(t parser.TemplateFile) (parser.TemplateFile, error) {
if t.Filepath == "" {
return t, nil
}
isTemplFile, fileName := convertTemplToGoURI(t.Filepath)
if !isTemplFile {
return t, fmt.Errorf("invalid filepath: %s", t.Filepath)
}

// The first node always contains existing imports.
// If there isn't one, create it.
if len(t.Nodes) == 0 {
t.Nodes = append(t.Nodes, parser.TemplateFileGoExpression{})
}
// If there is one, ensure it is a Go expression.
if _, ok := t.Nodes[0].(parser.TemplateFileGoExpression); !ok {
t.Nodes = append([]parser.TemplateFileNode{parser.TemplateFileGoExpression{}}, t.Nodes...)
}

// Find all existing imports.
importsNode := t.Nodes[0].(parser.TemplateFileGoExpression)

// Generate code.
gw := bytes.NewBuffer(nil)
var updatedImports []*ast.ImportSpec
var eg errgroup.Group
eg.Go(func() (err error) {
if _, _, err := generator.Generate(t, gw); err != nil {
return fmt.Errorf("failed to generate go code: %w", err)
}
updatedImports, err = updateImports(fileName, gw.String())
if err != nil {
return fmt.Errorf("failed to get imports from generated go code: %w", err)
}
return nil
})

var gofile *ast.File
// Update the template with the imports.
// Ensure that there is a Go expression to add the imports to as the first node.
eg.Go(func() (err error) {
gofile, err = goparser.ParseFile(fset, fileName, t.Package.Expression.Value+"\n"+importsNode.Expression.Value, goparser.AllErrors)
if err != nil {
log.Printf("failed to parse go code: %v", importsNode.Expression.Value)
return fmt.Errorf("failed to parse imports section: %w", err)
}
return nil
})
if err := eg.Wait(); err != nil {
return t, err
}
slices.SortFunc(updatedImports, func(a, b *ast.ImportSpec) int {
return strings.Compare(a.Path.Value, b.Path.Value)
})
newImportDecl := &ast.GenDecl{
Tok: token.IMPORT,
Specs: convertSlice(updatedImports),
}
// Delete all the existing imports.
var indicesToDelete []int
for i, decl := range gofile.Decls {
if decl, ok := decl.(*ast.GenDecl); ok && decl.Tok == token.IMPORT {
indicesToDelete = append(indicesToDelete, i)
}
}
for i := len(indicesToDelete) - 1; i >= 0; i-- {
gofile.Decls = append(gofile.Decls[:indicesToDelete[i]], gofile.Decls[indicesToDelete[i]+1:]...)
}
if len(updatedImports) > 0 {
gofile.Imports = updatedImports
gofile.Decls = append([]ast.Decl{newImportDecl}, gofile.Decls...)
}
// Write out the Go code with the imports.
updatedGoCode := new(strings.Builder)
err := format.Node(updatedGoCode, fset, gofile)
if err != nil {
return t, fmt.Errorf("failed to write updated go code: %w", err)
}
importsNode.Expression.Value = strings.TrimSpace(strings.SplitN(updatedGoCode.String(), "\n", 2)[1])
if len(updatedImports) == 0 && importsNode.Expression.Value == "" {
t.Nodes = t.Nodes[1:]
return t, nil
}
t.Nodes[0] = importsNode
return t, nil
}

func convertSlice(slice []*ast.ImportSpec) []ast.Spec {
result := make([]ast.Spec, len(slice))
for i, v := range slice {
result[i] = ast.Spec(v)
}
return result
}
62 changes: 62 additions & 0 deletions cmd/templ/import/process_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package imports

import (
"bytes"
"path/filepath"
"strings"
"testing"

"github.com/a-h/templ/parser/v2"
"github.com/google/go-cmp/cmp"
"golang.org/x/tools/txtar"
)

func TestFormatting(t *testing.T) {
files, _ := filepath.Glob("testdata/*.txtar")
if len(files) == 0 {
t.Errorf("no test files found")
}
for _, file := range files {
t.Run(filepath.Base(file), func(t *testing.T) {
a, err := txtar.ParseFile(file)
if err != nil {
t.Fatalf("failed to parse txtar file: %v", err)
}
if len(a.Files) != 2 {
t.Fatalf("expected 2 files, got %d", len(a.Files))
}
template, err := parser.ParseString(clean(a.Files[0].Data))
if err != nil {
t.Fatalf("failed to parse %v", err)
}
template.Filepath = a.Files[0].Name
tf, err := Process(template)
if err != nil {
t.Fatalf("failed to process file: %v", err)
}
expected := string(a.Files[1].Data)
actual := new(strings.Builder)
if err := tf.Write(actual); err != nil {
t.Fatalf("failed to write template file: %v", err)
}
if diff := cmp.Diff(expected, actual.String()); diff != "" {
t.Errorf("%s:\n%s", file, diff)
t.Errorf("expected:\n%s", showWhitespace(expected))
t.Errorf("actual:\n%s", showWhitespace(actual.String()))
}
})
}
}

func showWhitespace(s string) string {
s = strings.ReplaceAll(s, "\n", "⏎\n")
s = strings.ReplaceAll(s, "\t", "→")
s = strings.ReplaceAll(s, " ", "·")
return s
}

func clean(b []byte) string {
b = bytes.ReplaceAll(b, []byte("$\n"), []byte("\n"))
b = bytes.TrimSuffix(b, []byte("\n"))
return string(b)
}
14 changes: 14 additions & 0 deletions cmd/templ/import/testdata/deleteimports.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- fmt.templ --
package test

import "strconv"

templ Hello() {
<div>Hello</div>
}
-- fmt.templ --
package test

templ Hello() {
<div>Hello</div>
}
10 changes: 10 additions & 0 deletions cmd/templ/import/testdata/header.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- fmt_templ.templ --
package test

var x = fmt.Sprintf("Hello")
-- fmt_templ.templ --
package test

import "fmt"

var x = fmt.Sprintf("Hello")
12 changes: 12 additions & 0 deletions cmd/templ/import/testdata/noimports.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- fmt.templ --
package test

templ Hello() {
<div>Hello</div>
}
-- fmt.templ --
package test

templ Hello() {
<div>Hello</div>
}
14 changes: 14 additions & 0 deletions cmd/templ/import/testdata/stringexp.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- fmt.templ --
package test

templ Hello(name string) {
{ fmt.Sprintf("Hello, %s!", name) }
}
-- fmt.templ --
package test

import "fmt"

templ Hello(name string) {
{ fmt.Sprintf("Hello, %s!", name) }
}
Loading

0 comments on commit 69bfdb1

Please sign in to comment.