From f17a0013fdf783786958c276b690953405565583 Mon Sep 17 00:00:00 2001 From: adrianiacobghiula <2491756+adrianiacobghiula@users.noreply.github.com> Date: Mon, 6 May 2024 20:12:10 +0200 Subject: [PATCH] feat: allow custom template in avrogen (#392) --- cmd/avrogen/main.go | 48 +++++++++---- gen/gen.go | 147 +++++++++++++-------------------------- gen/output_template.tmpl | 50 +++++++++++++ go.mod | 2 + go.sum | 4 ++ 5 files changed, 141 insertions(+), 110 deletions(-) create mode 100644 gen/output_template.tmpl diff --git a/cmd/avrogen/main.go b/cmd/avrogen/main.go index 6b175213..fea65f04 100644 --- a/cmd/avrogen/main.go +++ b/cmd/avrogen/main.go @@ -5,7 +5,6 @@ import ( "errors" "flag" "fmt" - "go/format" "io" "os" "path/filepath" @@ -13,9 +12,12 @@ import ( "github.com/hamba/avro/v2" "github.com/hamba/avro/v2/gen" + "golang.org/x/tools/imports" ) type config struct { + TemplateFileName string + Pkg string Out string Tags string @@ -38,6 +40,7 @@ func realMain(args []string, stdout, stderr io.Writer) int { flgs.BoolVar(&cfg.FullName, "fullname", false, "Use the full name of the Record schema to create the struct name.") flgs.BoolVar(&cfg.Encoders, "encoders", false, "Generate encoders for the structs.") flgs.StringVar(&cfg.Initialisms, "initialisms", "", "Custom initialisms [,...] for struct and field names.") + flgs.StringVar(&cfg.TemplateFileName, "templateFileName", "", "Override output template with one loaded from file.") flgs.Usage = func() { _, _ = fmt.Fprintln(stderr, "Usage: avrogen [options] schemas") _, _ = fmt.Fprintln(stderr, "Options:") @@ -64,10 +67,17 @@ func realMain(args []string, stdout, stderr io.Writer) int { return 1 } + template, err := loadTemplate(cfg.TemplateFileName) + if err != nil { + _, _ = fmt.Fprintln(stderr, "Error: "+err.Error()) + return 1 + } + opts := []gen.OptsFunc{ gen.WithFullName(cfg.FullName), gen.WithEncoders(cfg.Encoders), gen.WithInitialisms(initialisms), + gen.WithTemplate(string(template)), } g := gen.NewGenerator(cfg.Pkg, tags, opts...) for _, file := range flgs.Args() { @@ -84,30 +94,37 @@ func realMain(args []string, stdout, stderr io.Writer) int { _, _ = fmt.Fprintf(stderr, "Error: could not generate code: %v\n", err) return 3 } - formatted, err := format.Source(buf.Bytes()) + formatted, err := imports.Process("", buf.Bytes(), nil) if err != nil { - _, _ = fmt.Fprintf(stderr, "Error: could not format code: %v\n", err) + _ = writeOut(cfg.Out, stdout, buf.Bytes()) + _, _ = fmt.Fprintf(stderr, "Error: generated code could not be formatted: %v\n", err) return 3 } + err = writeOut(cfg.Out, stdout, formatted) + if err != nil { + _, _ = fmt.Fprintf(stderr, "Error: %v\n", err) + return 4 + } + return 0 +} + +func writeOut(filename string, stdout io.Writer, bytes []byte) error { writer := stdout - if cfg.Out != "" { - file, err := os.Create(cfg.Out) + if filename != "" { + file, err := os.Create(filepath.Clean(filename)) if err != nil { - _, _ = fmt.Fprintf(stderr, "Error: could not create output file: %v\n", err) - return 4 + return fmt.Errorf("could not create output file: %w", err) } defer func() { _ = file.Close() }() writer = file } - if _, err := writer.Write(formatted); err != nil { - _, _ = fmt.Fprintf(stderr, "Error: could not write code: %v\n", err) - return 4 + if _, err := writer.Write(bytes); err != nil { + return fmt.Errorf("could not write code: %w", err) } - - return 0 + return nil } func validateOpts(nargs int, cfg config) error { @@ -172,3 +189,10 @@ func parseInitialisms(raw string) ([]string, error) { return result, nil } + +func loadTemplate(templateFileName string) ([]byte, error) { + if templateFileName == "" { + return nil, nil + } + return os.ReadFile(filepath.Clean(templateFileName)) +} diff --git a/gen/gen.go b/gen/gen.go index 78454d90..3bc18dc5 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -3,15 +3,17 @@ package gen import ( "bytes" + _ "embed" "errors" "fmt" - "go/format" "io" + "maps" "strings" "text/template" "github.com/ettle/strcase" "github.com/hamba/avro/v2" + "golang.org/x/tools/imports" ) // Config configures the code generation. @@ -39,60 +41,8 @@ const ( UpperCamel TagStyle = "upper-camel" ) -const outputTemplate = `package {{ .PackageName }} - -// Code generated by avro/gen. DO NOT EDIT. - -{{- $encoders := .WithEncoders }} -{{ if len .Imports }} -import ( - {{- range .Imports }} - "{{ . }}" - {{- end }} - {{ if len .ThirdPartyImports }} - - {{- range .ThirdPartyImports }} - "{{ . }}" - {{- end }} - {{ end }} -) -{{ else if len .ThirdPartyImports }} -import ( - {{- range .ThirdPartyImports }} - "{{ . }}" - {{- end }} -) -{{ end }} - - - -{{- range .Typedefs }} -// {{ .Name }} is a generated struct. -type {{ .Name }} struct { - {{- range .Fields }} - {{ .Name }} {{ .Type }} {{ .Tag }} - {{- end }} -} - -{{- if $encoders }} -var schema{{ .Name }} = avro.MustParse(` + "`{{ .Schema }}`" + `) - -// Schema returns the schema for {{ .Name }}. -func (o *{{ .Name }}) Schema() avro.Schema { - return schema{{ .Name }} -} - -// Unmarshal decodes b into the receiver. -func (o *{{ .Name }}) Unmarshal(b []byte) error { - return avro.Unmarshal(o.Schema(), b, o) -} - -// Marshal encodes the receiver. -func (o *{{ .Name }}) Marshal() ([]byte, error) { - return avro.Marshal(o.Schema(), o) -} -{{- end }} -{{ end }}` +//go:embed output_template.tmpl +var outputTemplate string var primitiveMappings = map[avro.Type]string{ "string": "string", @@ -133,9 +83,10 @@ func StructFromSchema(schema avro.Schema, w io.Writer, cfg Config) error { return err } - formatted, err := format.Source(buf.Bytes()) + formatted, err := imports.Process("", buf.Bytes(), nil) if err != nil { - return fmt.Errorf("could not format code: %w", err) + _, _ = w.Write(buf.Bytes()) + return fmt.Errorf("generated code could not be formatted: %w", err) } _, err = w.Write(formatted) @@ -172,8 +123,19 @@ func WithInitialisms(ss []string) OptsFunc { } } +// WithTemplate configures the generator to use a custom template provided by the user. +func WithTemplate(template string) OptsFunc { + return func(g *Generator) { + if template == "" { + return + } + g.template = template + } +} + // Generator generates Go structs from schemas. type Generator struct { + template string pkg string tags map[string]TagStyle fullName bool @@ -189,9 +151,13 @@ type Generator struct { // NewGenerator returns a generator. func NewGenerator(pkg string, tags map[string]TagStyle, opts ...OptsFunc) *Generator { + clonedTags := maps.Clone(tags) + delete(clonedTags, "avro") + g := &Generator{ - pkg: pkg, - tags: tags, + template: outputTemplate, + pkg: pkg, + tags: clonedTags, } for _, opt := range opts { @@ -266,8 +232,7 @@ func (g *Generator) resolveRecordSchema(schema *avro.RecordSchema) string { fields := make([]field, len(schema.Fields())) for i, f := range schema.Fields() { typ := g.generate(f.Type()) - tag := f.Name() - fields[i] = g.newField(g.nameCaser.ToPascal(f.Name()), typ, tag) + fields[i] = g.newField(g.nameCaser.ToPascal(f.Name()), typ, f.Doc(), f.Name()) } typeName := g.resolveTypeName(schema) @@ -334,35 +299,13 @@ func (g *Generator) resolveLogicalSchema(logicalType avro.LogicalType) string { return typ } -func (g *Generator) newField(name, typ, tag string) field { - tagLine := fmt.Sprintf(`avro:"%s"`, tag) - for tagName, style := range g.tags { - if tagName == "avro" { - continue - } - tagLine += fmt.Sprintf(` %s:"%s"`, tagName, formatTag(tag, style)) - } +func (g *Generator) newField(name, typ, avroFieldDoc, avroFieldName string) field { return field{ - Name: name, - Type: typ, - Tag: fmt.Sprintf("`%s`", tagLine), - } -} - -func formatTag(tag string, style TagStyle) string { - switch style { - case Kebab: - return strcase.ToKebab(tag) - case UpperCamel: - return strcase.ToPascal(tag) - case Camel: - return strcase.ToCamel(tag) - case Snake: - return strcase.ToSnake(tag) - case Original: - fallthrough - default: - return tag + Name: name, + Type: typ, + AvroFieldName: avroFieldName, + AvroFieldDoc: avroFieldDoc, + Tags: g.tags, } } @@ -386,7 +329,14 @@ func (g *Generator) addThirdPartyImport(pkg string) { // Write writes Go code from the parsed schemas. func (g *Generator) Write(w io.Writer) error { - parsed, err := template.New("out").Parse(outputTemplate) + parsed, err := template.New("out"). + Funcs(template.FuncMap{ + "kebab": strcase.ToKebab, + "upperCamel": strcase.ToPascal, + "camel": strcase.ToCamel, + "snake": strcase.ToSnake, + }). + Parse(g.template) if err != nil { return err } @@ -398,11 +348,10 @@ func (g *Generator) Write(w io.Writer) error { ThirdPartyImports []string Typedefs []typedef }{ - WithEncoders: g.encoders, - PackageName: g.pkg, - Imports: g.imports, - ThirdPartyImports: g.thirdPartyImports, - Typedefs: g.typedefs, + WithEncoders: g.encoders, + PackageName: g.pkg, + Imports: append(g.imports, g.thirdPartyImports...), + Typedefs: g.typedefs, } return parsed.Execute(w, data) } @@ -422,7 +371,9 @@ func newType(name string, fields []field, schema string) typedef { } type field struct { - Name string - Type string - Tag string + Name string + Type string + AvroFieldName string + AvroFieldDoc string + Tags map[string]TagStyle } diff --git a/gen/output_template.tmpl b/gen/output_template.tmpl new file mode 100644 index 00000000..94cbcf67 --- /dev/null +++ b/gen/output_template.tmpl @@ -0,0 +1,50 @@ +package {{ .PackageName }} + +// Code generated by avro/gen. DO NOT EDIT. + +{{- $encoders := .WithEncoders }} +{{ if len .Imports }} + import ( + {{- range .Imports }} + "{{ . }}" + {{- end }} + ) +{{ end }} + +{{- range .Typedefs }} + // {{ .Name }} is a generated struct. + type {{ .Name }} struct { + {{- range .Fields }} + {{- $f := . }} + {{ .Name }} {{ .Type }} `avro:"{{ $f.AvroFieldName }}" + {{- range $tag, $style := .Tags }} + {{- " "}}{{ $tag }}:" + {{- if eq $style "kebab" }}{{ kebab $f.AvroFieldName }} + {{- else if eq $style "upper-camel"}}{{ upperCamel $f.AvroFieldName }} + {{- else if eq $style "camel"}}{{ camel $f.AvroFieldName }} + {{- else if eq $style "snake"}}{{ snake $f.AvroFieldName }} + {{- else}}{{ $f.AvroFieldName }} + {{- end}}" + {{- end }}` + {{- end }} + } + + {{- if $encoders }} + var schema{{ .Name }} = avro.MustParse(`{{ .Schema }}`) + + // Schema returns the schema for {{ .Name }}. + func (o *{{ .Name }}) Schema() avro.Schema { + return schema{{ .Name }} + } + + // Unmarshal decodes b into the receiver. + func (o *{{ .Name }}) Unmarshal(b []byte) error { + return avro.Unmarshal(o.Schema(), b, o) + } + + // Marshal encodes the receiver. + func (o *{{ .Name }}) Marshal() ([]byte, error) { + return avro.Marshal(o.Schema(), o) + } + {{- end }} +{{ end }} \ No newline at end of file diff --git a/go.mod b/go.mod index 4c1a3bd4..4586f114 100644 --- a/go.mod +++ b/go.mod @@ -16,5 +16,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/tools v0.20.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 99fc2b0a..3db61f1f 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,10 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY= +golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=