Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor code #1539

Merged
merged 4 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/swag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func initAction(ctx *cli.Context) error {
Tags: ctx.String(tagsFlag),
PackageName: ctx.String(packageName),
Debugger: logger,
OpenAPIVersion: ctx.Bool(openAPIVersionFlag),
GenerateOpenAPI3Doc: ctx.Bool(openAPIVersionFlag),
CollectionFormat: collectionFormat,
})
}
Expand Down
209 changes: 119 additions & 90 deletions gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gen
import (
"bufio"
"bytes"
"embed"
"encoding/json"
"fmt"
"go/format"
Expand All @@ -17,8 +18,9 @@ import (

jsoniter "github.com/json-iterator/go"

"github.com/go-openapi/spec"
openapi "github.com/sv-tools/openapi/spec"
v2 "github.com/go-openapi/spec"
v3 "github.com/sv-tools/openapi/spec"

"github.com/swaggo/swag"
"sigs.k8s.io/yaml"
)
Expand All @@ -28,18 +30,20 @@ var open = os.Open
// DefaultOverridesFile is the location swagger will look for type overrides.
const DefaultOverridesFile = ".swaggo"

type genTypeWriter func(*Config, *spec.Swagger) error
type genTypeWriter func(*Config, interface{}) error

// Gen presents a generate tool for swag.
type Gen struct {
json func(data interface{}) ([]byte, error)
jsonIndent func(data interface{}) ([]byte, error)
jsonToYAML func(data []byte) ([]byte, error)
outputTypeMap map[string]genTypeWriter
outputTypeMapV3 map[string]openAPITypeWriter
debug Debugger
json func(data interface{}) ([]byte, error)
jsonIndent func(data interface{}) ([]byte, error)
jsonToYAML func(data []byte) ([]byte, error)
outputTypeMap map[string]genTypeWriter
debug Debugger
}

//go:embed src/*.tmpl
var tmpl embed.FS

// Debugger is the interface that wraps the basic Printf method.
type Debugger interface {
Printf(format string, v ...interface{})
Expand All @@ -50,25 +54,17 @@ func New() *Gen {
gen := Gen{
json: json.Marshal,
jsonIndent: func(data interface{}) ([]byte, error) {
var json = jsoniter.ConfigCompatibleWithStandardLibrary
return json.MarshalIndent(&data, "", " ")
return jsoniter.ConfigCompatibleWithStandardLibrary.MarshalIndent(&data, "", " ")
},
jsonToYAML: yaml.JSONToYAML,
debug: log.New(os.Stdout, "", log.LstdFlags),
}

gen.outputTypeMap = map[string]genTypeWriter{
"go": gen.writeDocSwagger,
"json": gen.writeJSONSwagger,
"yaml": gen.writeYAMLSwagger,
"yml": gen.writeYAMLSwagger,
}

gen.outputTypeMapV3 = map[string]openAPITypeWriter{
"go": gen.writeDocOpenAPI,
"json": gen.writeJSONOpenAPI,
"yaml": gen.writeYAMLOpenAPI,
"yml": gen.writeYAMLOpenAPI,
"go": gen.writeDoc,
"json": gen.writeJSON,
"yaml": gen.writeYAML,
"yml": gen.writeYAML,
}

return &gen
Expand Down Expand Up @@ -139,8 +135,9 @@ type Config struct {
// include only tags mentioned when searching, comma separated
Tags string

// if true, OpenAPI V3.1 spec will be generated
OpenAPIVersion bool
// GenerateOpenAPI3Doc if true, OpenAPI V3.1 spec will be generated
GenerateOpenAPI3Doc bool

// PackageName defines package name of generated `docs.go`
PackageName string

Expand Down Expand Up @@ -196,7 +193,7 @@ func (g *Gen) Build(config *Config) error {
swag.SetOverrides(overrides),
swag.ParseUsingGoList(config.ParseGoList),
swag.SetTags(config.Tags),
swag.SetOpenAPIVersion(config.OpenAPIVersion),
swag.GenerateOpenAPI3Doc(config.GenerateOpenAPI3Doc),
swag.SetCollectionFormat(config.CollectionFormat),
)

Expand All @@ -213,45 +210,18 @@ func (g *Gen) Build(config *Config) error {
return err
}

if config.OpenAPIVersion {
openAPI := p.GetOpenAPI()
err := g.writeOpenAPI(config, openAPI)
if err != nil {
return err
}

return nil
}

swagger := p.GetSwagger()
err := g.writeSwagger(config, swagger)
if err != nil {
return err
if config.GenerateOpenAPI3Doc {
return g.writeOpenAPI(config, p.GetOpenAPI())
}

return nil
return g.writeOpenAPI(config, p.GetSwagger())
}

func (g *Gen) writeOpenAPI(config *Config, o *openapi.OpenAPI) error {
for _, outputType := range config.OutputTypes {
outputType = strings.ToLower(strings.TrimSpace(outputType))
if typeWriter, ok := g.outputTypeMapV3[outputType]; ok {
if err := typeWriter(config, o); err != nil {
return err
}
} else {
log.Printf("output type '%s' not supported", outputType)
}
}

return nil
}

func (g *Gen) writeSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeOpenAPI(config *Config, doc interface{}) error {
for _, outputType := range config.OutputTypes {
outputType = strings.ToLower(strings.TrimSpace(outputType))
if typeWriter, ok := g.outputTypeMap[outputType]; ok {
if err := typeWriter(config, swagger); err != nil {
if err := typeWriter(config, doc); err != nil {
return err
}
} else {
Expand All @@ -262,7 +232,7 @@ func (g *Gen) writeSwagger(config *Config, swagger *spec.Swagger) error {
return nil
}

func (g *Gen) writeDocSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeDoc(config *Config, doc interface{}) error {
var filename = "docs.go"

if config.InstanceName != swag.Name {
Expand Down Expand Up @@ -291,17 +261,25 @@ func (g *Gen) writeDocSwagger(config *Config, swagger *spec.Swagger) error {
defer docs.Close()

// Write doc
err = g.writeGoDoc(packageName, docs, swagger, config)
if err != nil {
return err
}
switch spec := doc.(type) {
case *v2.Swagger:
err = g.writeGoDoc(packageName, docs, spec, config)
if err != nil {
return err

}
case *v3.OpenAPI:
err = g.writeGoDocV3(packageName, docs, spec, config)
if err != nil {
return nil
}
}
g.debug.Printf("create docs.go at %+v", docFileName)

return nil
}

func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeJSON(config *Config, spec interface{}) error {
var filename = "swagger.json"

if config.InstanceName != swag.Name {
Expand All @@ -310,7 +288,7 @@ func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {

jsonFileName := path.Join(config.OutputDir, filename)

b, err := g.jsonIndent(swagger)
b, err := g.jsonIndent(spec)
if err != nil {
return err
}
Expand All @@ -325,7 +303,7 @@ func (g *Gen) writeJSONSwagger(config *Config, swagger *spec.Swagger) error {
return nil
}

func (g *Gen) writeYAMLSwagger(config *Config, swagger *spec.Swagger) error {
func (g *Gen) writeYAML(config *Config, swagger interface{}) error {
var filename = "swagger.yaml"

if config.InstanceName != swag.Name {
Expand Down Expand Up @@ -421,29 +399,29 @@ func parseOverrides(r io.Reader) (map[string]string, error) {
return overrides, nil
}

func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swagger, config *Config) error {
generator, err := template.New("swagger_info").Funcs(template.FuncMap{
func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *v2.Swagger, config *Config) error {
generator, err := template.New("oas2.tmpl").Funcs(template.FuncMap{
"printDoc": func(v string) string {
// Add schemes
v = "{\n \"schemes\": {{ marshal .Schemes }}," + v[1:]
// Sanitize backticks
return strings.Replace(v, "`", "`+\"`\"+`", -1)
},
}).Parse(packageTemplate)
}).ParseFS(tmpl, "src/*.tmpl")
if err != nil {
return err
}

swaggerSpec := &spec.Swagger{
swaggerSpec := &v2.Swagger{
VendorExtensible: swagger.VendorExtensible,
SwaggerProps: spec.SwaggerProps{
SwaggerProps: v2.SwaggerProps{
ID: swagger.ID,
Consumes: swagger.Consumes,
Produces: swagger.Produces,
Swagger: swagger.Swagger,
Info: &spec.Info{
Info: &v2.Info{
VendorExtensible: swagger.Info.VendorExtensible,
InfoProps: spec.InfoProps{
InfoProps: v2.InfoProps{
Description: "{{escape .Description}}",
Title: "{{.Title}}",
TermsOfService: swagger.Info.TermsOfService,
Expand Down Expand Up @@ -510,27 +488,78 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa
return err
}

var packageTemplate = `// Code generated by swaggo/swag{{ if .GeneratedTime }} at {{ .Timestamp }}{{ end }}. DO NOT EDIT.
func (g *Gen) writeGoDocV3(packageName string, output io.Writer, openAPI *v3.OpenAPI, config *Config) error {
generator, err := template.New("oas3.tmpl").Funcs(template.FuncMap{
"printDoc": func(v string) string {
// Add schemes
v = "{\n \"schemes\": {{ marshal .Schemes }}," + v[1:]
// Sanitize backticks
return strings.Replace(v, "`", "`+\"`\"+`", -1)
},
}).ParseFS(tmpl, "src/*.tmpl")
if err != nil {
return err
}

openAPISpec := v3.OpenAPI{
Components: openAPI.Components,
OpenAPI: openAPI.OpenAPI,
Info: &v3.Extendable[v3.Info]{
Spec: &v3.Info{
Description: "{{escape .Description}}",
Title: "{{.Title}}",
Version: "{{.Version}}",
TermsOfService: openAPI.Info.Spec.TermsOfService,
Contact: openAPI.Info.Spec.Contact,
License: openAPI.Info.Spec.License,
Summary: openAPI.Info.Spec.Summary,
},
Extensions: openAPI.Info.Extensions,
},
ExternalDocs: openAPI.ExternalDocs,
Paths: openAPI.Paths,
WebHooks: openAPI.WebHooks,
JsonSchemaDialect: openAPI.JsonSchemaDialect,
Security: openAPI.Security,
Tags: openAPI.Tags,
Servers: openAPI.Servers,
}

package docs
// crafted docs.json
buf, err := g.jsonIndent(openAPISpec)
if err != nil {
return err
}

import "github.com/swaggo/swag"
buffer := &bytes.Buffer{}

const docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = ` + "`{{ printDoc .Doc}}`" + `
err = generator.Execute(buffer, struct {
Timestamp time.Time
Doc string
PackageName string
Title string
Description string
Version string
InstanceName string
GeneratedTime bool
}{
Timestamp: time.Now(),
GeneratedTime: config.GeneratedTime,
Doc: string(buf),
PackageName: packageName,
Title: openAPI.Info.Spec.Title,
Description: openAPI.Info.Spec.Description,
Version: openAPI.Info.Spec.Version,
InstanceName: config.InstanceName,
})
if err != nil {
return err
}

// SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} holds exported Swagger Info so clients can modify it
var SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }} = &swag.Spec{
Version: {{ printf "%q" .Version}},
Host: {{ printf "%q" .Host}},
BasePath: {{ printf "%q" .BasePath}},
Schemes: []string{ {{ range $index, $schema := .Schemes}}{{if gt $index 0}},{{end}}{{printf "%q" $schema}}{{end}} },
Title: {{ printf "%q" .Title}},
Description: {{ printf "%q" .Description}},
InfoInstanceName: {{ printf "%q" .InstanceName }},
SwaggerTemplate: docTemplate{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }},
}
code := g.formatSource(buffer.Bytes())

func init() {
swag.Register(SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }}.InstanceName(), SwaggerInfo{{ if ne .InstanceName "swagger" }}{{ .InstanceName }} {{- end }})
// write
_, err = output.Write(code)

return err
}
`
Loading