diff --git a/protoc-gen-grpc-gateway/gengateway/generator.go b/protoc-gen-grpc-gateway/gengateway/generator.go index cb2f5e14a9f..e0a3f8a1b9f 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/gengateway/generator.go @@ -21,13 +21,14 @@ var ( ) type generator struct { - reg *descriptor.Registry - baseImports []descriptor.GoPackage - useRequestContext bool + reg *descriptor.Registry + baseImports []descriptor.GoPackage + useRequestContext bool + registerFuncSuffix string } // New returns a new generator which generates grpc gateway files. -func New(reg *descriptor.Registry, useRequestContext bool) gen.Generator { +func New(reg *descriptor.Registry, useRequestContext bool, registerFuncSuffix string) gen.Generator { var imports []descriptor.GoPackage for _, pkgpath := range []string{ "io", @@ -57,7 +58,12 @@ func New(reg *descriptor.Registry, useRequestContext bool) gen.Generator { } imports = append(imports, pkg) } - return &generator{reg: reg, baseImports: imports, useRequestContext: useRequestContext} + return &generator{ + reg: reg, + baseImports: imports, + useRequestContext: useRequestContext, + registerFuncSuffix: registerFuncSuffix, + } } func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) { @@ -111,5 +117,11 @@ func (g *generator) generate(file *descriptor.File) (string, error) { imports = append(imports, pkg) } } - return applyTemplate(param{File: file, Imports: imports, UseRequestContext: g.useRequestContext}) + params := param{ + File: file, + Imports: imports, + UseRequestContext: g.useRequestContext, + RegisterFuncSuffix: g.registerFuncSuffix, + } + return applyTemplate(params) } diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index ef52049a625..702d8b7b57f 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -13,8 +13,9 @@ import ( type param struct { *descriptor.File - Imports []descriptor.GoPackage - UseRequestContext bool + Imports []descriptor.GoPackage + UseRequestContext bool + RegisterFuncSuffix string } type binding struct { @@ -68,8 +69,9 @@ func (f queryParamFilter) String() string { } type trailerParams struct { - Services []*descriptor.Service - UseRequestContext bool + Services []*descriptor.Service + UseRequestContext bool + RegisterFuncSuffix string } func applyTemplate(p param) (string, error) { @@ -102,8 +104,9 @@ func applyTemplate(p param) (string, error) { } tp := trailerParams{ - Services: targetServices, - UseRequestContext: p.UseRequestContext, + Services: targetServices, + UseRequestContext: p.UseRequestContext, + RegisterFuncSuffix: p.RegisterFuncSuffix, } if err := trailerTemplate.Execute(w, tp); err != nil { return "", err @@ -314,9 +317,9 @@ var ( trailerTemplate = template.Must(template.New("trailer").Parse(` {{$UseRequestContext := .UseRequestContext}} {{range $svc := .Services}} -// Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but +// Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint is same as Register{{$svc.GetName}}{{$.RegisterFuncSuffix}} but // automatically dials to "endpoint" and closes the connection when "ctx" gets done. -func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { +func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { conn, err := grpc.Dial(endpoint, opts...) if err != nil { return err @@ -336,21 +339,21 @@ func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runti }() }() - return Register{{$svc.GetName}}Handler(ctx, mux, conn) + return Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}(ctx, mux, conn) } -// Register{{$svc.GetName}}Handler registers the http handlers for service {{$svc.GetName}} to "mux". +// Register{{$svc.GetName}}{{$.RegisterFuncSuffix}} registers the http handlers for service {{$svc.GetName}} to "mux". // The handlers forward requests to the grpc endpoint over "conn". -func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { - return Register{{$svc.GetName}}HandlerClient(ctx, mux, New{{$svc.GetName}}Client(conn)) +func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx, mux, New{{$svc.GetName}}Client(conn)) } -// Register{{$svc.GetName}}Handler registers the http handlers for service {{$svc.GetName}} to "mux". -// The handlers forward requests to the grpc endpoint over the given implementation of "{{$svc.GetName}}Client". +// Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client registers the http handlers for service {{$svc.GetName}} +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "{{$svc.GetName}}Client". // Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "{{$svc.GetName}}Client" // doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in // "{{$svc.GetName}}Client" to call the correct interceptors. -func Register{{$svc.GetName}}HandlerClient(ctx context.Context, mux *runtime.ServeMux, client {{$svc.GetName}}Client) error { +func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, mux *runtime.ServeMux, client {{$svc.GetName}}Client) error { {{range $m := $svc.Methods}} {{range $b := $m.Bindings}} mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { diff --git a/protoc-gen-grpc-gateway/gengateway/template_test.go b/protoc-gen-grpc-gateway/gengateway/template_test.go index c5fb2f93e39..d28c943e1d3 100644 --- a/protoc-gen-grpc-gateway/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/gengateway/template_test.go @@ -77,7 +77,7 @@ func TestApplyTemplateHeader(t *testing.T) { }, }, } - got, err := applyTemplate(param{File: crossLinkFixture(&file)}) + got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler"}) if err != nil { t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err) return @@ -222,7 +222,7 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { }, }, } - got, err := applyTemplate(param{File: crossLinkFixture(&file)}) + got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler"}) if err != nil { t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err) return @@ -383,7 +383,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { }, }, } - got, err := applyTemplate(param{File: crossLinkFixture(&file)}) + got, err := applyTemplate(param{File: crossLinkFixture(&file), RegisterFuncSuffix: "Handler"}) if err != nil { t.Errorf("applyTemplate(%#v) failed with %v; want success", file, err) return diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 4b875d51b36..b15a2345fb6 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -23,10 +23,11 @@ import ( ) var ( - importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") - importPath = flag.String("import_path", "", "used as the package if no input files declare go_package. If it contains slashes, everything up to the rightmost slash is ignored.") - useRequestContext = flag.Bool("request_context", true, "determine whether to use http.Request's context or not") - allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body") + importPrefix = flag.String("import_prefix", "", "prefix to be added to go package paths for imported proto files") + importPath = flag.String("import_path", "", "used as the package if no input files declare go_package. If it contains slashes, everything up to the rightmost slash is ignored.") + registerFuncSuffix = flag.String("register_func_suffix", "Handler", "used to construct names of generated Register* methods.") + useRequestContext = flag.Bool("request_context", true, "determine whether to use http.Request's context or not") + allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body") ) func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) { @@ -76,7 +77,7 @@ func main() { } } - g := gengateway.New(reg, *useRequestContext) + g := gengateway.New(reg, *useRequestContext, *registerFuncSuffix) reg.SetPrefix(*importPrefix) reg.SetImportPath(*importPath)