Skip to content

Commit

Permalink
testing/protocmp: add MessageTypeResolver.
Browse files Browse the repository at this point in the history
Fixes golang/protobuf#1377

Change-Id: Idf06ba21fea3e2ede8176a8408eb08490707242b
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/552455
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Cassondra Foesch <[email protected]>
Reviewed-by: Damien Neil <[email protected]>
Reviewed-by: Michael Stapelberg <[email protected]>
Auto-Submit: Damien Neil <[email protected]>
  • Loading branch information
Tommie Gannert authored and gopherbot committed Jan 4, 2024
1 parent 7b78149 commit 18202d2
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 22 deletions.
2 changes: 1 addition & 1 deletion testing/protocmp/reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestReflect(t *testing.T) {

for _, src := range tests {
dst := src.ProtoReflect().Type().New().Interface()
proto.Merge(dst, transformMessage(src.ProtoReflect()))
proto.Merge(dst, newTransformer().transformMessage(src.ProtoReflect()))
if diff := cmp.Diff(src, dst, Transform()); diff != "" {
t.Errorf("Merge mismatch (-want +got):\n%s", diff)
}
Expand Down
63 changes: 43 additions & 20 deletions testing/protocmp/xform.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,32 @@ func (m Message) String() string {
}
}

type option struct{}
type transformer struct {
resolver protoregistry.MessageTypeResolver
}

func newTransformer(opts ...option) *transformer {
xf := &transformer{
resolver: protoregistry.GlobalTypes,
}
for _, opt := range opts {
opt(xf)
}
return xf
}

type option func(*transformer)

// MessageTypeResolver overrides the resolver used for messages packed
// inside Any. The default is protoregistry.GlobalTypes, which is
// sufficient for all compiled-in Protobuf messages. Overriding the
// resolver is useful in tests that dynamically create Protobuf
// descriptors and messages, e.g. in proxies using dynamicpb.
func MessageTypeResolver(r protoregistry.MessageTypeResolver) option {
return func(xf *transformer) {
xf.resolver = r
}
}

// Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message].
// The transformation does not mutate nor alias any converted messages.
Expand All @@ -172,10 +197,9 @@ type option struct{}
// This does not directly transform higher-order composite Go types.
// For example, []*foopb.Message is not transformed into []Message,
// but rather the individual message elements of the slice are transformed.
//
// Note that there are currently no custom options for Transform,
// but the use of an unexported type keeps the future open.
func Transform(...option) cmp.Option {
func Transform(opts ...option) cmp.Option {
xf := newTransformer(opts...)

// addrType returns a pointer to t if t isn't a pointer or interface.
addrType := func(t reflect.Type) reflect.Type {
if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
Expand Down Expand Up @@ -218,7 +242,7 @@ func Transform(...option) cmp.Option {
case !m.IsValid():
return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
default:
return transformMessage(m)
return xf.transformMessage(m)
}
}))
}
Expand All @@ -231,7 +255,7 @@ func isMessageType(t reflect.Type) bool {
return t.Implements(messageV1Type) || t.Implements(messageV2Type)
}

func transformMessage(m protoreflect.Message) Message {
func (xf *transformer) transformMessage(m protoreflect.Message) Message {
mx := Message{}
mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}

Expand All @@ -243,11 +267,11 @@ func transformMessage(m protoreflect.Message) Message {
}
switch {
case fd.IsList():
mx[s] = transformList(fd, v.List())
mx[s] = xf.transformList(fd, v.List())
case fd.IsMap():
mx[s] = transformMap(fd, v.Map())
mx[s] = xf.transformMap(fd, v.Map())
default:
mx[s] = transformSingular(fd, v)
mx[s] = xf.transformSingular(fd, v)
}
return true
})
Expand All @@ -263,15 +287,14 @@ func transformMessage(m protoreflect.Message) Message {

// Expand Any messages.
if mt.md.FullName() == genid.Any_message_fullname {
// TODO: Expose Transform option to specify a custom resolver?
s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
mt, err := protoregistry.GlobalTypes.FindMessageByURL(s)
mt, err := xf.resolver.FindMessageByURL(s)
if mt != nil && err == nil {
m2 := mt.New()
err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
if err == nil {
mx[string(genid.Any_Value_field_name)] = transformMessage(m2)
mx[string(genid.Any_Value_field_name)] = xf.transformMessage(m2)
}
}
}
Expand All @@ -280,37 +303,37 @@ func transformMessage(m protoreflect.Message) Message {
return mx
}

func transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
func (xf *transformer) transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
t := protoKindToGoType(fd.Kind())
rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
for i := 0; i < lv.Len(); i++ {
v := reflect.ValueOf(transformSingular(fd, lv.Get(i)))
v := reflect.ValueOf(xf.transformSingular(fd, lv.Get(i)))
rv.Index(i).Set(v)
}
return rv.Interface()
}

func transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
func (xf *transformer) transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
kfd := fd.MapKey()
vfd := fd.MapValue()
kt := protoKindToGoType(kfd.Kind())
vt := protoKindToGoType(vfd.Kind())
rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
kv := reflect.ValueOf(transformSingular(kfd, k.Value()))
vv := reflect.ValueOf(transformSingular(vfd, v))
kv := reflect.ValueOf(xf.transformSingular(kfd, k.Value()))
vv := reflect.ValueOf(xf.transformSingular(vfd, v))
rv.SetMapIndex(kv, vv)
return true
})
return rv.Interface()
}

func transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
func (xf *transformer) transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
switch fd.Kind() {
case protoreflect.EnumKind:
return Enum{num: v.Enum(), ed: fd.Enum()}
case protoreflect.MessageKind, protoreflect.GroupKind:
return transformMessage(v.Message())
return xf.transformMessage(v.Message())
case protoreflect.BytesKind:
// The protoreflect API does not specify whether an empty bytes is
// guaranteed to be nil or not. Always return non-nil bytes to avoid
Expand Down
55 changes: 54 additions & 1 deletion testing/protocmp/xform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
package protocmp

import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"

"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/testing/protopack"
"google.golang.org/protobuf/types/known/anypb"

testpb "google.golang.org/protobuf/internal/testprotos/test"
)
Expand Down Expand Up @@ -254,7 +257,7 @@ func TestTransform(t *testing.T) {
}}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got := transformMessage(tt.in.ProtoReflect())
got := newTransformer().transformMessage(tt.in.ProtoReflect())
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("Transform() mismatch (-want +got):\n%v", diff)
}
Expand All @@ -263,6 +266,34 @@ func TestTransform(t *testing.T) {
}
})
}

t.Run("messageTypeResolver", func(t *testing.T) {
r := unaryMessageTypeResolver{
Type: (&testpb.TestAllTypes{}).ProtoReflect().Type(),
}
m := &testpb.TestAllTypes{OptionalBool: proto.Bool(true)}
in, err := anypb.New(m)
if err != nil {
t.Fatalf("anypb.New() failed: %v", err)
}
in.TypeUrl = "type.googleapis.com/MagicTestMessage"

got := newTransformer(MessageTypeResolver(r)).transformMessage(in.ProtoReflect())
want := Message{
messageTypeKey: messageMetaOf(&anypb.Any{}),
"type_url": "type.googleapis.com/MagicTestMessage",
"value": Message{
messageTypeKey: messageMetaOf(m),
"optional_bool": true,
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Transform() mismatch (-want +got):\n%v", diff)
}
if got.Unwrap() != in {
t.Errorf("got.Unwrap() = %p, want %p", got.Unwrap(), in)
}
})
}

func enumOf(e protoreflect.Enum) Enum {
Expand All @@ -272,3 +303,25 @@ func enumOf(e protoreflect.Enum) Enum {
func messageMetaOf(m protoreflect.ProtoMessage) messageMeta {
return messageMeta{m: m, md: m.ProtoReflect().Descriptor()}
}

// A unaryMessageTypeResolver can only resolve one type, and it's
// called "MagicTestMessage".
type unaryMessageTypeResolver struct {
Type protoreflect.MessageType
}

func (r unaryMessageTypeResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
if message != "MagicTestMessage" {
return nil, protoregistry.NotFound
}
return r.Type, nil
}

func (r unaryMessageTypeResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) {
const prefix = "type.googleapis.com/"

if !strings.HasPrefix(url, prefix) {
return nil, protoregistry.NotFound
}
return r.FindMessageByName(protoreflect.FullName(strings.TrimPrefix(url, prefix)))
}

0 comments on commit 18202d2

Please sign in to comment.