Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runtime: remove DisallowUnknownFields() #1386

Merged
merged 1 commit into from
May 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions examples/internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func testEcho(t *testing.T, port int, apiPrefix string, contentType string) {
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand Down Expand Up @@ -135,7 +135,7 @@ func testEchoOneof(t *testing.T, port int, apiPrefix string, contentType string)
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand Down Expand Up @@ -168,7 +168,7 @@ func testEchoOneof1(t *testing.T, port int, apiPrefix string, contentType string
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand Down Expand Up @@ -201,7 +201,7 @@ func testEchoOneof2(t *testing.T, port int, apiPrefix string, contentType string
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand All @@ -216,7 +216,7 @@ func testEchoOneof2(t *testing.T, port int, apiPrefix string, contentType string
}

func testEchoBody(t *testing.T, port int, apiPrefix string) {
sent := examplepb.SimpleMessage{Id: "example"}
sent := examplepb.UnannotatedSimpleMessage{Id: "example"}
payload, err := marshaler.Marshal(&sent)
if err != nil {
t.Fatalf("marshaler.Marshal(%#v) failed with %v; want success", payload, err)
Expand All @@ -240,12 +240,12 @@ func testEchoBody(t *testing.T, port int, apiPrefix string) {
t.Logf("%s", buf)
}

var received examplepb.SimpleMessage
var received examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &received); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
}
if diff := cmp.Diff(received, sent, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&received, &sent, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}

Expand Down Expand Up @@ -334,7 +334,7 @@ func testABECreate(t *testing.T, port int) {
t.Error("msg.Uuid is empty; want not empty")
}
msg.Uuid = ""
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}
}
Expand Down Expand Up @@ -442,7 +442,7 @@ func testABECreateBody(t *testing.T, port int) {
t.Error("msg.Uuid is empty; want not empty")
}
msg.Uuid = ""
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}
}
Expand Down Expand Up @@ -673,7 +673,7 @@ func testABELookup(t *testing.T, port int) {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
}
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}

Expand Down Expand Up @@ -1340,7 +1340,7 @@ func testABERepeated(t *testing.T, port int) {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
}
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}
}
Expand Down Expand Up @@ -1590,15 +1590,16 @@ func testResponseBodies(t *testing.T, port int) {
t.Logf("%s", buf)
}

var got []*examplepb.ResponseBodyOut_Response
var got []*examplepb.RepeatedResponseBodyOut_Response
err = marshaler.Unmarshal(buf, &got)
if err != nil {
t.Errorf("marshaler.Unmarshal failed with %v; want success", err)
return
}
want := []*examplepb.ResponseBodyOut_Response{
want := []*examplepb.RepeatedResponseBodyOut_Response{
{
Data: "foo",
Type: examplepb.RepeatedResponseBodyOut_Response_UNKNOWN,
},
}
if diff := cmp.Diff(got, want, protocmp.Transform()); diff != "" {
Expand Down Expand Up @@ -1708,15 +1709,16 @@ func testResponseStrings(t *testing.T, port int) {
t.Logf("%s", buf)
}

var got []*examplepb.ResponseBodyOut_Response
var got []*examplepb.RepeatedResponseBodyOut_Response
err = marshaler.Unmarshal(buf, &got)
if err != nil {
t.Errorf("marshaler.Unmarshal failed with %v; want success", err)
return
}
want := []*examplepb.ResponseBodyOut_Response{
want := []*examplepb.RepeatedResponseBodyOut_Response{
{
Data: "foo",
Type: examplepb.RepeatedResponseBodyOut_Response_UNKNOWN,
},
}
if diff := cmp.Diff(got, want, protocmp.Transform()); diff != "" {
Expand Down
7 changes: 6 additions & 1 deletion examples/internal/server/responsebody.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"github.com/golang/glog"
examples "github.com/grpc-ecosystem/grpc-gateway/v2/examples/internal/proto/examplepb"
)

Expand All @@ -16,6 +17,7 @@ func newResponseBodyServer() examples.ResponseBodyServiceServer {
}

func (s *responseBodyServer) GetResponseBody(ctx context.Context, req *examples.ResponseBodyIn) (*examples.ResponseBodyOut, error) {
glog.Info(req)
return &examples.ResponseBodyOut{
Response: &examples.ResponseBodyOut_Response{
Data: req.Data,
Expand All @@ -24,16 +26,18 @@ func (s *responseBodyServer) GetResponseBody(ctx context.Context, req *examples.
}

func (s *responseBodyServer) ListResponseBodies(ctx context.Context, req *examples.ResponseBodyIn) (*examples.RepeatedResponseBodyOut, error) {
glog.Info(req)
return &examples.RepeatedResponseBodyOut{
Response: []*examples.RepeatedResponseBodyOut_Response{
&examples.RepeatedResponseBodyOut_Response{
{
Data: req.Data,
},
},
}, nil
}

func (s *responseBodyServer) ListResponseStrings(ctx context.Context, req *examples.ResponseBodyIn) (*examples.RepeatedResponseStrings, error) {
glog.Info(req)
if req.Data == "empty" {
return &examples.RepeatedResponseStrings{
Values: []string{},
Expand All @@ -45,6 +49,7 @@ func (s *responseBodyServer) ListResponseStrings(ctx context.Context, req *examp
}

func (s *responseBodyServer) GetResponseBodyStream(req *examples.ResponseBodyIn, stream examples.ResponseBodyService_GetResponseBodyStreamServer) error {
glog.Info(req)
if err := stream.Send(&examples.ResponseBodyOut{
Response: &examples.ResponseBodyOut_Response{
Data: fmt.Sprintf("first %s", req.Data),
Expand Down
42 changes: 14 additions & 28 deletions runtime/marshal_jsonpb.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,29 @@ func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {

// Unmarshal unmarshals JSON "data" into "v"
func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
return unmarshalJSONPb(data, v)
return unmarshalJSONPb(data, j.UnmarshalOptions, v)
}

// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONPb) NewDecoder(r io.Reader) Decoder {
d := json.NewDecoder(r)
return DecoderWrapper{Decoder: d}
return DecoderWrapper{
Decoder: d,
UnmarshalOptions: j.UnmarshalOptions,
}
}

// DecoderWrapper is a wrapper around a *json.Decoder that adds
// support for protos to the Decode method.
type DecoderWrapper struct {
*json.Decoder
protojson.UnmarshalOptions
}

// Decode wraps the embedded decoder's Decode method to support
// protos using a jsonpb.Unmarshaler.
func (d DecoderWrapper) Decode(v interface{}) error {
return decodeJSONPb(d.Decoder, v)
return decodeJSONPb(d.Decoder, d.UnmarshalOptions, v)
}

// NewEncoder returns an Encoder which writes JSON stream into "w".
Expand All @@ -171,15 +175,15 @@ func (j *JSONPb) NewEncoder(w io.Writer) Encoder {
})
}

func unmarshalJSONPb(data []byte, v interface{}) error {
func unmarshalJSONPb(data []byte, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
d := json.NewDecoder(bytes.NewReader(data))
return decodeJSONPb(d, v)
return decodeJSONPb(d, unmarshaler, v)
}

func decodeJSONPb(d *json.Decoder, v interface{}) error {
func decodeJSONPb(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
return decodeNonProtoField(d, v)
return decodeNonProtoField(d, unmarshaler, v)
}

// Decode into bytes for marshalling
Expand All @@ -189,13 +193,10 @@ func decodeJSONPb(d *json.Decoder, v interface{}) error {
return err
}

unmarshaler := &protojson.UnmarshalOptions{
DiscardUnknown: allowUnknownFields,
}
return unmarshaler.Unmarshal([]byte(b), p)
}

func decodeNonProtoField(d *json.Decoder, v interface{}) error {
func decodeNonProtoField(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer", v)
Expand All @@ -212,9 +213,6 @@ func decodeNonProtoField(d *json.Decoder, v interface{}) error {
return err
}

unmarshaler := &protojson.UnmarshalOptions{
DiscardUnknown: allowUnknownFields,
}
return unmarshaler.Unmarshal([]byte(b), rv.Interface().(proto.Message))
}
rv = rv.Elem()
Expand All @@ -239,7 +237,7 @@ func decodeNonProtoField(d *json.Decoder, v interface{}) error {
}
bk := result[0]
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(*v), bv.Interface()); err != nil {
if err := unmarshalJSONPb([]byte(*v), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.SetMapIndex(bk, bv.Elem())
Expand All @@ -256,7 +254,7 @@ func decodeNonProtoField(d *json.Decoder, v interface{}) error {
}
for _, item := range sl {
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(item), bv.Interface()); err != nil {
if err := unmarshalJSONPb([]byte(item), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.Set(reflect.Append(rv, bv.Elem()))
Expand Down Expand Up @@ -294,18 +292,6 @@ func (j *JSONPb) Delimiter() []byte {
return []byte("\n")
}

// allowUnknownFields helps not to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
var allowUnknownFields = true

// DisallowUnknownFields enables option in decoder (unmarshaller) to
// return an error when it finds an unknown field. This function must be
// called before using the JSON marshaller.
func DisallowUnknownFields() {
allowUnknownFields = false
}

var (
convFromType = map[reflect.Kind]reflect.Value{
reflect.String: reflect.ValueOf(String),
Expand Down
8 changes: 5 additions & 3 deletions runtime/marshal_jsonpb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,16 +471,18 @@ func TestJSONPbDecoderFields(t *testing.T) {

func TestJSONPbDecoderUnknownField(t *testing.T) {
var (
m runtime.JSONPb
m = runtime.JSONPb{
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: false,
},
}
got examplepb.ABitOfEverything
)
data := `{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"unknownField": "111"
}`

runtime.DisallowUnknownFields()

r := strings.NewReader(data)
dec := m.NewDecoder(r)
if err := dec.Decode(&got); err == nil {
Expand Down
3 changes: 3 additions & 0 deletions runtime/marshaler_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ var (
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}
)
Expand Down