Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add middleware support #1

Merged
merged 9 commits into from
Jan 23, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"strings"

"github.com/golang/glog"
"github.com/grpc-ecosystem/grpc-gateway/examples/examplepb"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/shilkin/grpc-gateway/examples/examplepb"
"github.com/shilkin/grpc-gateway/runtime"
"golang.org/x/net/context"
"google.golang.org/grpc"
)
Expand Down
2 changes: 1 addition & 1 deletion examples/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"testing"
"time"

server "github.com/grpc-ecosystem/grpc-gateway/examples/server"
server "github.com/shilkin/grpc-gateway/examples/server"
)

func runServers() <-chan error {
Expand Down
2 changes: 2 additions & 0 deletions options/generate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/usr/bin/env bash
protoc --go_out=Mgoogle/protobuf/descriptor.proto=github.com/golang/protobuf/protoc-gen-go/descriptor:. middleware.proto
11 changes: 11 additions & 0 deletions options/middleware.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

option go_package = "options";

package gengo.grpc.gateway;

import "google/protobuf/descriptor.proto";

extend google.protobuf.MethodOptions {
repeated string middleware = 72295730;
}
8 changes: 8 additions & 0 deletions protoc-gen-grpc-gateway/descriptor/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package descriptor

import options "github.com/shilkin/grpc-gateway/third_party/googleapis/google/api"

type apiOptions struct {
httpRule *options.HttpRule
middleware []string
}
86 changes: 53 additions & 33 deletions protoc-gen-grpc-gateway/descriptor/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import (
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
options "github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis/google/api"
gateway_options "github.com/shilkin/grpc-gateway/options"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/httprule"
google_options "github.com/shilkin/grpc-gateway/third_party/googleapis/google/api"
)

// loadServices registers services and their methods from "targetFile" to "r".
Expand All @@ -24,7 +25,7 @@ func (r *Registry) loadServices(file *File) error {
ServiceDescriptorProto: sd,
}
for _, md := range sd.GetMethod() {
glog.V(2).Infof("Processing %s.%s", sd.GetName(), md.GetName())
glog.V(2).Infof("> Processing %s.%s", sd.GetName(), md.GetName())

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for what?

opts, err := extractAPIOptions(md)
if err != nil {
glog.Errorf("Failed to extract ApiMethodOptions from %s.%s: %v", svc.GetName(), md.GetName(), err)
Expand All @@ -33,6 +34,7 @@ func (r *Registry) loadServices(file *File) error {
if opts == nil {
glog.V(1).Infof("Found non-target method: %s.%s", svc.GetName(), md.GetName())
}
glog.V(2).Infof("API options for %s.%s: %#v", svc.GetName(), md.GetName(), opts)
meth, err := r.newMethod(svc, md, opts)
if err != nil {
return err
Expand All @@ -49,7 +51,7 @@ func (r *Registry) loadServices(file *File) error {
return nil
}

func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) {
func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *apiOptions) (*Method, error) {
requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType())
if err != nil {
return nil, err
Expand All @@ -65,40 +67,40 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto,
ResponseType: responseType,
}

newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) {
newBinding := func(opts *apiOptions, idx int) (*Binding, error) {
var (
httpMethod string
pathTemplate string
)
switch {
case opts.GetGet() != "":
case opts.httpRule.GetGet() != "":
httpMethod = "GET"
pathTemplate = opts.GetGet()
if opts.Body != "" {
pathTemplate = opts.httpRule.GetGet()
if opts.httpRule.Body != "" {
return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName())
}

case opts.GetPut() != "":
case opts.httpRule.GetPut() != "":
httpMethod = "PUT"
pathTemplate = opts.GetPut()
pathTemplate = opts.httpRule.GetPut()

case opts.GetPost() != "":
case opts.httpRule.GetPost() != "":
httpMethod = "POST"
pathTemplate = opts.GetPost()
pathTemplate = opts.httpRule.GetPost()

case opts.GetDelete() != "":
case opts.httpRule.GetDelete() != "":
httpMethod = "DELETE"
pathTemplate = opts.GetDelete()
if opts.Body != "" && !r.allowDeleteBody {
pathTemplate = opts.httpRule.GetDelete()
if opts.httpRule.Body != "" && !r.allowDeleteBody {
return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName())
}

case opts.GetPatch() != "":
case opts.httpRule.GetPatch() != "":
httpMethod = "PATCH"
pathTemplate = opts.GetPatch()
pathTemplate = opts.httpRule.GetPatch()

case opts.GetCustom() != nil:
custom := opts.GetCustom()
case opts.httpRule.GetCustom() != nil:
custom := opts.httpRule.GetCustom()
httpMethod = custom.Kind
pathTemplate = custom.Path

Expand All @@ -122,6 +124,7 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto,
Index: idx,
PathTmpl: tmpl,
HTTPMethod: httpMethod,
Middleware: opts.middleware,
}

for _, f := range tmpl.Fields {
Expand All @@ -134,7 +137,7 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto,

// TODO(yugui) Handle query params

b.Body, err = r.newBody(meth, opts.Body)
b.Body, err = r.newBody(meth, opts.httpRule.Body)
if err != nil {
return nil, err
}
Expand All @@ -149,11 +152,12 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto,
if b != nil {
meth.Bindings = append(meth.Bindings, b)
}
for i, additional := range opts.GetAdditionalBindings() {
for i, additional := range opts.httpRule.GetAdditionalBindings() {
if len(additional.AdditionalBindings) > 0 {
return nil, fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName())
}
b, err := newBinding(additional, i+1)
apiOpts := &apiOptions{httpRule: additional, middleware: opts.middleware}
b, err := newBinding(apiOpts, i+1)
if err != nil {
return nil, err
}
Expand All @@ -163,22 +167,38 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto,
return meth, nil
}

func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) {
func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*apiOptions, error) { // (*options.HttpRule, error) {
var opts apiOptions

if meth.Options == nil {
return nil, nil
}
if !proto.HasExtension(meth.Options, options.E_Http) {
return nil, nil
}
ext, err := proto.GetExtension(meth.Options, options.E_Http)
if err != nil {
return nil, err
// google api extension
if proto.HasExtension(meth.Options, google_options.E_Http) {
ext, err := proto.GetExtension(meth.Options, google_options.E_Http)
if err != nil {
return nil, err
}
httpRule, ok := ext.(*google_options.HttpRule)
if !ok {
return nil, fmt.Errorf("extension is %T; want an HttpRule", ext)
}
opts.httpRule = httpRule
}
opts, ok := ext.(*options.HttpRule)
if !ok {
return nil, fmt.Errorf("extension is %T; want an HttpRule", ext)
// grpc gateway middleware extension
if proto.HasExtension(meth.Options, gateway_options.E_Middleware) {
ext, err := proto.GetExtension(meth.Options, gateway_options.E_Middleware)
if err != nil {
return nil, err
}
middleware, ok := ext.([]string)
if !ok {
return nil, fmt.Errorf("extension is %T; want an []string", ext)
}
opts.middleware = middleware
}
return opts, nil

return &opts, nil
}

func (r *Registry) newParam(meth *Method, path string) (Parameter, error) {
Expand Down
2 changes: 1 addition & 1 deletion protoc-gen-grpc-gateway/descriptor/services_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/golang/protobuf/proto"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/httprule"
)

func compilePath(t *testing.T, path string) httprule.Template {
Expand Down
4 changes: 3 additions & 1 deletion protoc-gen-grpc-gateway/descriptor/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
gogen "github.com/golang/protobuf/protoc-gen-go/generator"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/httprule"
)

// GoPackage represents a golang package
Expand Down Expand Up @@ -151,6 +151,8 @@ type Binding struct {
PathParams []Parameter
// Body describes parameters provided in HTTP request body.
Body *Body
// Middleware is the list of middleware names
Middleware []string
}

// ExplicitParams returns a list of explicitly bound parameters of "b",
Expand Down
2 changes: 1 addition & 1 deletion protoc-gen-grpc-gateway/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package generator

import (
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
)

// Generator is an abstraction of code generators.
Expand Down
8 changes: 4 additions & 4 deletions protoc-gen-grpc-gateway/gengateway/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
gen "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/generator"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
gen "github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/generator"
)

var (
Expand All @@ -30,8 +30,8 @@ func New(reg *descriptor.Registry) gen.Generator {
for _, pkgpath := range []string{
"io",
"net/http",
"github.com/grpc-ecosystem/grpc-gateway/runtime",
"github.com/grpc-ecosystem/grpc-gateway/utilities",
"github.com/shilkin/grpc-gateway/runtime",
"github.com/shilkin/grpc-gateway/utilities",
"github.com/golang/protobuf/proto",
"golang.org/x/net/context",
"google.golang.org/grpc",
Expand Down
40 changes: 35 additions & 5 deletions protoc-gen-grpc-gateway/gengateway/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"text/template"

"github.com/golang/glog"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/utilities"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/shilkin/grpc-gateway/utilities"
)

type param struct {
Expand Down Expand Up @@ -299,6 +299,13 @@ var (
// Register{{$svc.GetName}}HandlerFromEndpoint is same as Register{{$svc.GetName}}Handler but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
middleware := map[string]runtime.Middleware{}
return Register{{$svc.GetName}}HandlerFromEndpointWithMiddleware(ctx, mux, middleware, endpoint, opts)
}

// Register{{$svc.GetName}}HandlerFromEndpointMiddlware is same as Register{{$svc.GetName}}HandlerMiddleware but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
func Register{{$svc.GetName}}HandlerFromEndpointWithMiddleware(ctx context.Context, mux *runtime.ServeMux, middleware map[string]runtime.Middleware, endpoint string, opts []grpc.DialOption) (err error) {
conn, err := grpc.Dial(endpoint, opts...)
if err != nil {
return err
Expand All @@ -318,16 +325,31 @@ func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runti
}()
}()

return Register{{$svc.GetName}}Handler(ctx, mux, conn)
return Register{{$svc.GetName}}HandlerWithMiddleware(ctx, mux, middleware, conn)
}

// Register{{$svc.GetName}}Handler registers the http handlers for service {{$svc.GetName}} to "mux".
// The handlers forward requests to the grpc endpoint over "conn".
func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {
middleware := map[string]runtime.Middleware{}
return Register{{$svc.GetName}}HandlerWithMiddleware(ctx, mux, middleware, conn)
}

// Register{{$svc.GetName}}HandlerMiddleware registers the http handlers for service {{$svc.GetName}} to "mux".
// The handlers forward requests to the grpc endpoint over "conn".
func Register{{$svc.GetName}}HandlerWithMiddleware(ctx context.Context, mux *runtime.ServeMux, middleware map[string]runtime.Middleware, conn *grpc.ClientConn) error {
client := New{{$svc.GetName}}Client(conn)
var handler runtime.HandlerFunc
var mw []string

{{range $m := $svc.Methods}}
{{range $b := $m.Bindings}}
mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {

mw = []string{ {{range $name := $b.Middleware}}
"{{$name}}",
{{end}} }

handler = func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if cn, ok := w.(http.CloseNotifier); ok {
Expand Down Expand Up @@ -355,7 +377,15 @@ func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux,
{{else}}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
{{end}}
})
}

for _, name := range mw {
if m, ok := middleware[name]; ok {
handler = m(handler)
}
}

mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, handler)
{{end}}
{{end}}
return nil
Expand Down
4 changes: 2 additions & 2 deletions protoc-gen-grpc-gateway/gengateway/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (

"github.com/golang/protobuf/proto"
protodescriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/httprule"
)

func crossLinkFixture(f *descriptor.File) *descriptor.File {
Expand Down
2 changes: 1 addition & 1 deletion protoc-gen-grpc-gateway/httprule/compile.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package httprule

import (
"github.com/grpc-ecosystem/grpc-gateway/utilities"
"github.com/shilkin/grpc-gateway/utilities"
)

const (
Expand Down
2 changes: 1 addition & 1 deletion protoc-gen-grpc-gateway/httprule/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"reflect"
"testing"

"github.com/grpc-ecosystem/grpc-gateway/utilities"
"github.com/shilkin/grpc-gateway/utilities"
)

const (
Expand Down
4 changes: 2 additions & 2 deletions protoc-gen-grpc-gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import (
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/gengateway"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/descriptor"
"github.com/shilkin/grpc-gateway/protoc-gen-grpc-gateway/gengateway"
)

var (
Expand Down
Loading