Skip to content

Commit

Permalink
Fix pointer handling for schema registration and logical types.
Browse files Browse the repository at this point in the history
  • Loading branch information
lostluck committed Oct 31, 2022
1 parent 30b384a commit 9956094
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 22 deletions.
5 changes: 4 additions & 1 deletion sdks/go/pkg/beam/core/runtime/graphx/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ func (r *Registry) reconcileRegistrations() (deferedErr error) {
check := func(ut reflect.Type) bool {
return coder.LookupCustomCoder(ut) != nil
}
if check(ut) || check(reflect.PtrTo(ut)) {
// We could have either a pointer or non pointer here,
// so we strip pointerness and then check both.
vT := reflectx.SkipPtr(ut)
if check(vT) && check(reflect.PtrTo(vT)) {
continue
}
if err := r.registerType(ut, map[reflect.Type]struct{}{}); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions sdks/go/pkg/beam/core/runtime/graphx/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/testing/protocmp"
)

Expand Down Expand Up @@ -792,6 +793,7 @@ func TestSchemaConversion(t *testing.T) {
// real embedded type.
if !hasEmbeddedField(test.rt) && !test.rt.AssignableTo(got) {
t.Errorf("%v not assignable to %v", test.rt, got)
t.Errorf("%v for schema %v", test.rt, prototext.Format(test.st))
if d := cmp.Diff(reflect.New(test.rt).Elem().Interface(), reflect.New(got).Elem().Interface()); d != "" {
t.Errorf("diff (-want, +got): %v", d)
}
Expand Down
28 changes: 20 additions & 8 deletions sdks/go/pkg/beam/register/emitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ type myTestTypeEmitter1 struct {

func TestEmitter1(t *testing.T) {
Emitter1[int]()
if !exec.IsEmitterRegistered(reflect.TypeOf((*func(int))(nil)).Elem()) {
t.Fatalf("exec.IsEmitterRegistered(reflect.TypeOf((*func(int))(nil)).Elem()) = false, want true")
e1T := reflect.TypeOf((*func(int))(nil)).Elem()
if !exec.IsEmitterRegistered(e1T) {
t.Fatalf("exec.IsEmitterRegistered(%v) = false, want true", e1T)
}

Emitter1[myTestTypeEmitter1]()
Expand All @@ -50,11 +51,17 @@ type myTestTypeEmitter2B struct {

func TestEmitter2(t *testing.T) {
Emitter2[int, string]()
if !exec.IsEmitterRegistered(reflect.TypeOf((*func(int, string))(nil)).Elem()) {
t.Fatalf("exec.IsEmitterRegistered(reflect.TypeOf((*func(int, string))(nil)).Elem()) = false, want true")
e2isT := reflect.TypeOf((*func(int, string))(nil)).Elem()
if !exec.IsEmitterRegistered(e2isT) {
t.Fatalf("exec.IsEmitterRegistered(%v) = false, want true", e2isT)
}

Emitter2[*myTestTypeEmitter2A, myTestTypeEmitter2B]()
e2ABT := reflect.TypeOf((*func(*myTestTypeEmitter2A, myTestTypeEmitter2B))(nil)).Elem()
if !exec.IsEmitterRegistered(e2ABT) {
t.Fatalf("exec.IsEmitterRegistered(%v) = false, want true", e2ABT)
}

Emitter2[myTestTypeEmitter2A, myTestTypeEmitter2B]()
tA := reflect.TypeOf((*myTestTypeEmitter2A)(nil)).Elem()
checkRegisterations(t, tA)
tB := reflect.TypeOf((*myTestTypeEmitter2B)(nil)).Elem()
Expand All @@ -63,8 +70,9 @@ func TestEmitter2(t *testing.T) {

func TestEmitter2_WithTimestamp(t *testing.T) {
Emitter2[typex.EventTime, string]()
if !exec.IsEmitterRegistered(reflect.TypeOf((*func(typex.EventTime, string))(nil)).Elem()) {
t.Fatalf("exec.IsEmitterRegistered(reflect.TypeOf((*func(typex.EventTime, string))(nil)).Elem()) = false, want true")
e2tssT := reflect.TypeOf((*func(typex.EventTime, string))(nil)).Elem()
if !exec.IsEmitterRegistered(e2tssT) {
t.Fatalf("exec.IsEmitterRegistered(%v) = false, want true", e2tssT)
}
}

Expand All @@ -82,7 +90,11 @@ func TestEmitter3(t *testing.T) {
t.Fatalf("exec.IsEmitterRegistered(reflect.TypeOf((*func(typex.EventTime, int, string))(nil)).Elem()) = false, want true")
}

Emitter3[typex.EventTime, myTestTypeEmitter3A, myTestTypeEmitter3B]()
Emitter3[typex.EventTime, myTestTypeEmitter3A, *myTestTypeEmitter3B]()
e3tsABT := reflect.TypeOf((*func(typex.EventTime, myTestTypeEmitter3A, *myTestTypeEmitter3B))(nil)).Elem()
if !exec.IsEmitterRegistered(e3tsABT) {
t.Fatalf("exec.IsEmitterRegistered(%v) = false, want true", e3tsABT)
}
tA := reflect.TypeOf((*myTestTypeEmitter3A)(nil)).Elem()
checkRegisterations(t, tA)
tB := reflect.TypeOf((*myTestTypeEmitter3B)(nil)).Elem()
Expand Down
3 changes: 3 additions & 0 deletions sdks/go/pkg/beam/register/iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/schema"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)

type iter1[T any] struct {
Expand Down Expand Up @@ -107,6 +108,8 @@ func (v *iter2[T1, T2]) invoke(key *T1, value *T2) bool {
}

func registerType(t reflect.Type) {
// strip the pointer if present.
t = reflectx.SkipPtr(t)
if _, ok := runtime.TypeKey(t); !ok {
return
}
Expand Down
34 changes: 21 additions & 13 deletions sdks/go/pkg/beam/register/iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,41 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/schema"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)

type myTestTypeIter1 struct {
Int int
}

func checkRegisterations(t *testing.T, rt reflect.Type) {
func checkRegisterations(t *testing.T, ort reflect.Type) {
t.Helper()
// Strip pointers for the original type since type key doesn't support them.
// Pointer handling is done elsewhere.
rt := reflectx.SkipPtr(ort)
key, ok := runtime.TypeKey(rt)
if !ok {
t.Fatalf("runtime.TypeKey(%v): no typekey for type", rt)
}
if _, ok := runtime.LookupType(key); !ok {
t.Errorf("want type %v to be available with key %q", rt, key)
}
if !schema.Registered(rt) {
t.Errorf("want type %v to be registered with schemas", rt)
if !schema.Registered(ort) {
t.Errorf("want type %v to be registered with schemas", ort)
}
}

func TestIter1(t *testing.T) {
Iter1[int]()
if !exec.IsInputRegistered(reflect.TypeOf((*func(*int) bool)(nil)).Elem()) {
t.Fatalf("exec.IsInputRegistered(reflect.TypeOf(((*func(*int) bool)(nil)).Elem()) = false, want true")
itiT := reflect.TypeOf((*func(*int) bool)(nil)).Elem()
if !exec.IsInputRegistered(itiT) {
t.Fatalf("exec.IsInputRegistered(%v) = false, want true", itiT)
}

Iter1[myTestTypeIter1]()
if !exec.IsInputRegistered(reflect.TypeOf((*func(*int) bool)(nil)).Elem()) {
t.Fatalf("exec.IsInputRegistered(reflect.TypeOf(((*func(*int) bool)(nil)).Elem()) = false, want true")
it1T := reflect.TypeOf((*func(*int) bool)(nil)).Elem()
if !exec.IsInputRegistered(it1T) {
t.Fatalf("exec.IsInputRegistered(%v) = false, want true", it1T)
}

ttrt := reflect.TypeOf((*myTestTypeIter1)(nil)).Elem()
Expand All @@ -69,18 +75,20 @@ type myTestTypeIter2B struct {

func TestIter2(t *testing.T) {
Iter2[int, string]()
if !exec.IsInputRegistered(reflect.TypeOf((*func(*int, *string) bool)(nil)).Elem()) {
t.Fatalf("exec.IsInputRegistered(reflect.TypeOf((*func(*int, *string) bool)(nil)).Elem()) = false, want true")
it2isT := reflect.TypeOf((*func(*int, *string) bool)(nil)).Elem()
if !exec.IsInputRegistered(it2isT) {
t.Fatalf("exec.IsInputRegistered(%v) = false, want true", it2isT)
}

Iter2[myTestTypeIter2A, myTestTypeIter2B]()
if !exec.IsInputRegistered(reflect.TypeOf((*func(*int) bool)(nil)).Elem()) {
t.Fatalf("exec.IsInputRegistered(reflect.TypeOf(((*func(*int) bool)(nil)).Elem()) = false, want true")
Iter2[myTestTypeIter2A, *myTestTypeIter2B]()
it2ABT := reflect.TypeOf((*func(*myTestTypeIter2A, **myTestTypeIter2B) bool)(nil)).Elem()
if !exec.IsInputRegistered(it2ABT) {
t.Fatalf("exec.IsInputRegistered(%v) = false, want true", it2ABT)
}

ttArt := reflect.TypeOf((*myTestTypeIter2A)(nil)).Elem()
checkRegisterations(t, ttArt)
ttBrt := reflect.TypeOf((*myTestTypeIter2B)(nil)).Elem()
ttBrt := reflect.TypeOf((*myTestTypeIter2B)(nil))
checkRegisterations(t, ttBrt)
}

Expand Down

0 comments on commit 9956094

Please sign in to comment.