diff --git a/protoc-gen-grpc-gateway/descriptor/registry.go b/protoc-gen-grpc-gateway/descriptor/registry.go index 091c41e3019..f31cb1a311d 100644 --- a/protoc-gen-grpc-gateway/descriptor/registry.go +++ b/protoc-gen-grpc-gateway/descriptor/registry.go @@ -44,7 +44,22 @@ func (r *Registry) Load(req *plugin.CodeGeneratorRequest) error { for _, file := range req.GetProtoFile() { r.loadFile(file) } - for _, target := range req.FileToGenerate { + + var targetPkg string + for _, name := range req.FileToGenerate { + target := r.files[name] + if target == nil { + return fmt.Errorf("no such file: %s", name) + } + name := packageIdentityName(target.FileDescriptorProto) + if targetPkg == "" { + targetPkg = name + } else { + if targetPkg != name { + return fmt.Errorf("inconsistent package names: %s %s", targetPkg, name) + } + } + if err := r.loadServices(target); err != nil { return err } @@ -170,17 +185,20 @@ func (r *Registry) goPackagePath(f *descriptor.FileDescriptorProto) string { if pkg, ok := r.pkgMap[name]; ok { return path.Join(r.prefix, pkg) } - - ext := filepath.Ext(name) - if ext == ".protodevel" || ext == ".proto" { - name = strings.TrimSuffix(name, ext) - } - return path.Join(r.prefix, fmt.Sprintf("%s.pb", name)) + return path.Join(r.prefix, path.Dir(name)) } // defaultGoPackageName returns the default go package name to be used for go files generated from "f". // You might need to use an unique alias for the package when you import it. Use ReserveGoPackageAlias to get a unique alias. func defaultGoPackageName(f *descriptor.FileDescriptorProto) string { + name := packageIdentityName(f) + return strings.Replace(name, ".", "_", -1) +} + +// packageIdentityName returns the identity of packages. +// protoc-gen-grpc-gateway rejects CodeGenerationRequests which contains more than one packages +// as protoc-gen-go does. +func packageIdentityName(f *descriptor.FileDescriptorProto) string { if f.Options != nil && f.Options.GoPackage != nil { return f.Options.GetGoPackage() } @@ -190,5 +208,5 @@ func defaultGoPackageName(f *descriptor.FileDescriptorProto) string { ext := filepath.Ext(base) return strings.TrimSuffix(base, ext) } - return strings.Replace(f.GetPackage(), ".", "_", -1) + return f.GetPackage() } diff --git a/protoc-gen-grpc-gateway/descriptor/registry_test.go b/protoc-gen-grpc-gateway/descriptor/registry_test.go index 95069f17bba..ff3f4dc1799 100644 --- a/protoc-gen-grpc-gateway/descriptor/registry_test.go +++ b/protoc-gen-grpc-gateway/descriptor/registry_test.go @@ -5,9 +5,10 @@ import ( "github.com/golang/protobuf/proto" descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" + plugin "github.com/golang/protobuf/protoc-gen-go/plugin" ) -func load(t *testing.T, reg *Registry, src string) *descriptor.FileDescriptorProto { +func loadFile(t *testing.T, reg *Registry, src string) *descriptor.FileDescriptorProto { var file descriptor.FileDescriptorProto if err := proto.UnmarshalText(src, &file); err != nil { t.Fatalf("proto.UnmarshalText(%s, &file) failed with %v; want success", src, err) @@ -16,9 +17,17 @@ func load(t *testing.T, reg *Registry, src string) *descriptor.FileDescriptorPro return &file } +func load(t *testing.T, reg *Registry, src string) error { + var req plugin.CodeGeneratorRequest + if err := proto.UnmarshalText(src, &req); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &file) failed with %v; want success", src, err) + } + return reg.Load(&req) +} + func TestLoadFile(t *testing.T) { reg := NewRegistry() - fd := load(t, reg, ` + fd := loadFile(t, reg, ` name: 'example.proto' package: 'example' message_type < @@ -37,7 +46,7 @@ func TestLoadFile(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") return } - wantPkg := GoPackage{Path: "example.pb", Name: "example"} + wantPkg := GoPackage{Path: ".", Name: "example"} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -74,7 +83,7 @@ func TestLoadFile(t *testing.T) { func TestLoadFileNestedPackage(t *testing.T) { reg := NewRegistry() - load(t, reg, ` + loadFile(t, reg, ` name: 'example.proto' package: 'example.nested.nested2' `) @@ -84,7 +93,7 @@ func TestLoadFileNestedPackage(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") return } - wantPkg := GoPackage{Path: "example.pb", Name: "example_nested_nested2"} + wantPkg := GoPackage{Path: ".", Name: "example_nested_nested2"} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -92,7 +101,7 @@ func TestLoadFileNestedPackage(t *testing.T) { func TestLoadFileWithDir(t *testing.T) { reg := NewRegistry() - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/example.proto' package: 'example' `) @@ -102,7 +111,7 @@ func TestLoadFileWithDir(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") return } - wantPkg := GoPackage{Path: "path/to/example.pb", Name: "example"} + wantPkg := GoPackage{Path: "path/to", Name: "example"} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -110,7 +119,7 @@ func TestLoadFileWithDir(t *testing.T) { func TestLoadFileWithoutPackage(t *testing.T) { reg := NewRegistry() - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/example_file.proto' `) @@ -119,7 +128,7 @@ func TestLoadFileWithoutPackage(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") return } - wantPkg := GoPackage{Path: "path/to/example_file.pb", Name: "example_file"} + wantPkg := GoPackage{Path: "path/to", Name: "example_file"} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -128,7 +137,7 @@ func TestLoadFileWithoutPackage(t *testing.T) { func TestLoadFileWithMapping(t *testing.T) { reg := NewRegistry() reg.AddPkgMap("path/to/example.proto", "example.com/proj/example/proto") - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/example.proto' package: 'example' `) @@ -146,18 +155,18 @@ func TestLoadFileWithMapping(t *testing.T) { func TestLoadFileWithPackageNameCollision(t *testing.T) { reg := NewRegistry() - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/another.proto' package: 'example' `) - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/example.proto' package: 'example' `) if err := reg.ReserveGoPackageAlias("ioutil", "io/ioutil"); err != nil { t.Fatalf("reg.ReserveGoPackageAlias(%q) failed with %v; want success", "ioutil", err) } - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/ioutil.proto' package: 'ioutil' `) @@ -167,7 +176,7 @@ func TestLoadFileWithPackageNameCollision(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/another.proto") return } - wantPkg := GoPackage{Path: "path/to/another.pb", Name: "example"} + wantPkg := GoPackage{Path: "path/to", Name: "example"} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -177,7 +186,7 @@ func TestLoadFileWithPackageNameCollision(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/example.proto") return } - wantPkg = GoPackage{Path: "path/to/example.pb", Name: "example", Alias: "example_0"} + wantPkg = GoPackage{Path: "path/to", Name: "example", Alias: ""} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -187,7 +196,7 @@ func TestLoadFileWithPackageNameCollision(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "path/to/ioutil.proto") return } - wantPkg = GoPackage{Path: "path/to/ioutil.pb", Name: "ioutil", Alias: "ioutil_0"} + wantPkg = GoPackage{Path: "path/to", Name: "ioutil", Alias: "ioutil_0"} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -197,11 +206,11 @@ func TestLoadFileWithIdenticalGoPkg(t *testing.T) { reg := NewRegistry() reg.AddPkgMap("path/to/another.proto", "example.com/example") reg.AddPkgMap("path/to/example.proto", "example.com/example") - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/another.proto' package: 'example' `) - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/example.proto' package: 'example' `) @@ -230,7 +239,7 @@ func TestLoadFileWithIdenticalGoPkg(t *testing.T) { func TestLoadFileWithPrefix(t *testing.T) { reg := NewRegistry() reg.SetPrefix("third_party") - load(t, reg, ` + loadFile(t, reg, ` name: 'path/to/example.proto' package: 'example' `) @@ -240,7 +249,7 @@ func TestLoadFileWithPrefix(t *testing.T) { t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto") return } - wantPkg := GoPackage{Path: "third_party/path/to/example.pb", Name: "example"} + wantPkg := GoPackage{Path: "third_party/path/to", Name: "example"} if got, want := file.GoPkg, wantPkg; got != want { t.Errorf("file.GoPkg = %#v; want %#v", got, want) } @@ -248,7 +257,7 @@ func TestLoadFileWithPrefix(t *testing.T) { func TestLookupMsgWithoutPackage(t *testing.T) { reg := NewRegistry() - fd := load(t, reg, ` + fd := loadFile(t, reg, ` name: 'example.proto' message_type < name: 'ExampleMessage' @@ -273,7 +282,7 @@ func TestLookupMsgWithoutPackage(t *testing.T) { func TestLookupMsgWithNestedPackage(t *testing.T) { reg := NewRegistry() - fd := load(t, reg, ` + fd := loadFile(t, reg, ` name: 'example.proto' package: 'nested.nested2.mypackage' message_type < @@ -344,3 +353,181 @@ func TestLookupMsgWithNestedPackage(t *testing.T) { } } } + +func TestLoadWithInconsistentTargetPackage(t *testing.T) { + for _, spec := range []struct { + req string + consistent bool + }{ + // root package, no explicit go package + { + req: ` + file_to_generate: 'a.proto' + file_to_generate: 'b.proto' + proto_file < + name: 'a.proto' + message_type < name: 'A' > + service < + name: "AService" + method < + name: "Meth" + input_type: "A" + output_type: "A" + options < + [google.api.http] < post: "/v1/a" body: "*" > + > + > + > + > + proto_file < + name: 'b.proto' + message_type < name: 'B' > + service < + name: "BService" + method < + name: "Meth" + input_type: "B" + output_type: "B" + options < + [google.api.http] < post: "/v1/b" body: "*" > + > + > + > + > + `, + consistent: false, + }, + // named package, no explicit go package + { + req: ` + file_to_generate: 'a.proto' + file_to_generate: 'b.proto' + proto_file < + name: 'a.proto' + package: 'example.foo' + message_type < name: 'A' > + service < + name: "AService" + method < + name: "Meth" + input_type: "A" + output_type: "A" + options < + [google.api.http] < post: "/v1/a" body: "*" > + > + > + > + > + proto_file < + name: 'b.proto' + package: 'example.foo' + message_type < name: 'B' > + service < + name: "BService" + method < + name: "Meth" + input_type: "B" + output_type: "B" + options < + [google.api.http] < post: "/v1/b" body: "*" > + > + > + > + > + `, + consistent: true, + }, + // root package, explicit go package + { + req: ` + file_to_generate: 'a.proto' + file_to_generate: 'b.proto' + proto_file < + name: 'a.proto' + options < go_package: 'foo' > + message_type < name: 'A' > + service < + name: "AService" + method < + name: "Meth" + input_type: "A" + output_type: "A" + options < + [google.api.http] < post: "/v1/a" body: "*" > + > + > + > + > + proto_file < + name: 'b.proto' + options < go_package: 'foo' > + message_type < name: 'B' > + service < + name: "BService" + method < + name: "Meth" + input_type: "B" + output_type: "B" + options < + [google.api.http] < post: "/v1/b" body: "*" > + > + > + > + > + `, + consistent: true, + }, + // named package, explicit go package + { + req: ` + file_to_generate: 'a.proto' + file_to_generate: 'b.proto' + proto_file < + name: 'a.proto' + package: 'example.foo' + options < go_package: 'foo' > + message_type < name: 'A' > + service < + name: "AService" + method < + name: "Meth" + input_type: "A" + output_type: "A" + options < + [google.api.http] < post: "/v1/a" body: "*" > + > + > + > + > + proto_file < + name: 'b.proto' + package: 'example.foo' + options < go_package: 'foo' > + message_type < name: 'B' > + service < + name: "BService" + method < + name: "Meth" + input_type: "B" + output_type: "B" + options < + [google.api.http] < post: "/v1/b" body: "*" > + > + > + > + > + `, + consistent: true, + }, + } { + reg := NewRegistry() + err := load(t, reg, spec.req) + if got, want := err == nil, spec.consistent; got != want { + if want { + t.Errorf("reg.Load(%s) failed with %v; want success", spec.req, err) + continue + } + t.Errorf("reg.Load(%s) succeeded; want an package inconsistency error", spec.req) + } + } +} diff --git a/protoc-gen-grpc-gateway/descriptor/services.go b/protoc-gen-grpc-gateway/descriptor/services.go index 49ac4c317d9..728dda8967b 100644 --- a/protoc-gen-grpc-gateway/descriptor/services.go +++ b/protoc-gen-grpc-gateway/descriptor/services.go @@ -14,12 +14,8 @@ import ( // loadServices registers services and their methods from "targetFile" to "r". // It must be called after loadFile is called for all files so that loadServices // can resolve names of message types and their fields. -func (r *Registry) loadServices(targetFile string) error { - glog.V(1).Infof("Loading services from %s", targetFile) - file := r.files[targetFile] - if file == nil { - return fmt.Errorf("no such file: %s", targetFile) - } +func (r *Registry) loadServices(file *File) error { + glog.V(1).Infof("Loading services from %s", file.GetName()) var svcs []*Service for _, sd := range file.GetService() { glog.V(2).Infof("Registering %s", sd.GetName()) diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go index 323147904f0..f2df37011d5 100644 --- a/protoc-gen-grpc-gateway/descriptor/services_test.go +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -22,7 +22,7 @@ func testExtractServices(t *testing.T, input []*descriptor.FileDescriptorProto, for _, file := range input { reg.loadFile(file) } - err := reg.loadServices(target) + err := reg.loadServices(reg.files[target]) if err != nil { t.Errorf("loadServices(%q) failed with %v; want success; files=%v", target, err, input) } @@ -914,7 +914,7 @@ func TestExtractServicesWithError(t *testing.T) { reg.loadFile(&fd) fds = append(fds, &fd) } - err := reg.loadServices(spec.target) + err := reg.loadServices(reg.files[spec.target]) if err == nil { t.Errorf("loadServices(%q) succeeded; want an error; files=%v", spec.target, spec.srcs) } diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 454eb1fbbc3..32a719c2b23 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -65,14 +65,14 @@ func main() { } } + g := gengateway.New(reg) + reg.SetPrefix(*importPrefix) if err := reg.Load(req); err != nil { emitError(err) return } - g := gengateway.New(reg) - var targets []*descriptor.File for _, target := range req.FileToGenerate { f, err := reg.LookupFile(target)