diff --git a/protoc-gen-grpc-gateway/BUILD.bazel b/protoc-gen-grpc-gateway/BUILD.bazel index 24b8783ad5a..5aa586cbe56 100644 --- a/protoc-gen-grpc-gateway/BUILD.bazel +++ b/protoc-gen-grpc-gateway/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test") load("@io_bazel_rules_go//proto:compiler.bzl", "go_proto_compiler") package(default_visibility = ["//visibility:private"]) @@ -41,3 +41,10 @@ go_proto_compiler( "@org_golang_google_protobuf//proto:go_default_library", ], ) + +go_test( + name = "go_default_test", + srcs = ["main_test.go"], + embed = [":go_default_library"], + deps = ["//internal/descriptor:go_default_library"], +) diff --git a/protoc-gen-grpc-gateway/main.go b/protoc-gen-grpc-gateway/main.go index 8fba8b38df2..3ebecc81a6b 100644 --- a/protoc-gen-grpc-gateway/main.go +++ b/protoc-gen-grpc-gateway/main.go @@ -107,6 +107,10 @@ func main() { } func parseFlags(reg *descriptor.Registry, parameter string) { + if parameter == "" { + return + } + for _, p := range strings.Split(parameter, ",") { spec := strings.SplitN(p, "=", 2) if len(spec) == 1 { diff --git a/protoc-gen-grpc-gateway/main_test.go b/protoc-gen-grpc-gateway/main_test.go new file mode 100644 index 00000000000..f22d535b55f --- /dev/null +++ b/protoc-gen-grpc-gateway/main_test.go @@ -0,0 +1,31 @@ +package main + +import ( + "testing" + + "github.com/grpc-ecosystem/grpc-gateway/v2/internal/descriptor" +) + +func TestParseFlagsEmptyNoPanic(t *testing.T) { + reg := descriptor.NewRegistry() + parseFlags(reg, "") +} + +func TestParseFlags(t *testing.T) { + reg := descriptor.NewRegistry() + parseFlags(reg, "allow_repeated_fields_in_body=true") + if *allowRepeatedFieldsInBody != true { + t.Errorf("flag allow_repeated_fields_in_body was not set correctly, wanted true got %v", *allowRepeatedFieldsInBody) + } +} + +func TestParseFlagsMultiple(t *testing.T) { + reg := descriptor.NewRegistry() + parseFlags(reg, "allow_repeated_fields_in_body=true,import_prefix=foo") + if *allowRepeatedFieldsInBody != true { + t.Errorf("flag allow_repeated_fields_in_body was not set correctly, wanted 'true' got '%v'", *allowRepeatedFieldsInBody) + } + if *importPrefix != "foo" { + t.Errorf("flag importPrefix was not set correctly, wanted 'foo' got '%v'", *importPrefix) + } +}