From 1f7663f210335775f9467305c89eceebf531c4d0 Mon Sep 17 00:00:00 2001 From: Srinivas Devaki Date: Wed, 22 Sep 2021 16:43:16 +0530 Subject: [PATCH] add pointer support (#7) --- codecs.go | 38 ++++++++++- codecs_test.go | 181 +++++++++++++++++++++++++++++++++++++++++++++++++ scanner.go | 9 +++ 3 files changed, 227 insertions(+), 1 deletion(-) diff --git a/codecs.go b/codecs.go index 65da58d..6e9cc02 100644 --- a/codecs.go +++ b/codecs.go @@ -196,6 +196,37 @@ func (c *varuintSliceCodec) DecodeTo(d *Decoder, rv reflect.Value) (err error) { return } +type reflectPointerCodec struct { + elemCodec Codec +} + +func (c *reflectPointerCodec) EncodeTo(e *Encoder, rv reflect.Value) (err error) { + if rv.IsNil() { + e.writeBool(true) + return + } + e.writeBool(false) + err = c.elemCodec.EncodeTo(e, rv.Elem()) + if err != nil { + return err + } + return nil +} + +func (c *reflectPointerCodec) DecodeTo(d *Decoder, rv reflect.Value) (err error) { + isNil, err := d.ReadBool() + if err != nil { + return err + } + if isNil { + return + } + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + return c.elemCodec.DecodeTo(d, rv.Elem()) +} + // ------------------------------------------------------------------------------ type reflectStructCodec []fieldCodec @@ -218,7 +249,12 @@ func (c *reflectStructCodec) EncodeTo(e *Encoder, rv reflect.Value) (err error) // Decode decodes into a reflect value from the decoder. func (c *reflectStructCodec) DecodeTo(d *Decoder, rv reflect.Value) (err error) { for _, i := range *c { - if v := rv.Field(i.Index); v.CanSet() { + v := rv.Field(i.Index) + if v.Kind() == reflect.Ptr { + if err = i.Codec.DecodeTo(d, v); err != nil { + return + } + } else if v.CanSet() { if err = i.Codec.DecodeTo(d, reflect.Indirect(v)); err != nil { return } diff --git a/codecs_test.go b/codecs_test.go index 265a7cf..598772a 100644 --- a/codecs_test.go +++ b/codecs_test.go @@ -6,6 +6,7 @@ package binary import ( "bytes" "errors" + "math/rand" "reflect" "testing" "time" @@ -331,6 +332,186 @@ func TestSliceOfStructWithStruct(t *testing.T) { } +func TestBasicTypePointers(t *testing.T) { + type BT struct { + B *bool + S *string + I *int + I8 *int8 + I16 *int16 + I32 *int32 + I64 *int64 + Ui *uint + Ui8 *uint8 + Ui16 *uint16 + Ui32 *uint32 + Ui64 *uint64 + F32 *float32 + F64 *float64 + C64 *complex64 + C128 *complex128 + } + toss := func(chance float32) bool { + return rand.Float32() < chance + } + fuzz := func(bt *BT, nilChance float32) { + if toss(nilChance) { + k := rand.Intn(2) == 1 + bt.B = &k + } + if toss(nilChance) { + b := make([]byte, rand.Intn(32)) + rand.Read(b) + sb := string(b) + bt.S = &sb + } + if toss(nilChance) { + i := rand.Int() + bt.I = &i + } + if toss(nilChance) { + i8 := int8(rand.Int()) + bt.I8 = &i8 + } + if toss(nilChance) { + i16 := int16(rand.Int()) + bt.I16 = &i16 + } + if toss(nilChance) { + i32 := rand.Int31() + bt.I32 = &i32 + } + if toss(nilChance) { + i64 := rand.Int63() + bt.I64 = &i64 + } + if toss(nilChance) { + ui := uint(rand.Uint64()) + bt.Ui = &ui + } + if toss(nilChance) { + ui8 := uint8(rand.Uint32()) + bt.Ui8 = &ui8 + } + if toss(nilChance) { + ui16 := uint16(rand.Uint32()) + bt.Ui16 = &ui16 + } + if toss(nilChance) { + ui32 := rand.Uint32() + bt.Ui32 = &ui32 + } + if toss(nilChance) { + ui64 := rand.Uint64() + bt.Ui64 = &ui64 + } + if toss(nilChance) { + f32 := rand.Float32() + bt.F32 = &f32 + } + if toss(nilChance) { + f64 := rand.Float64() + bt.F64 = &f64 + } + if toss(nilChance) { + c64 := complex(rand.Float32(), rand.Float32()) + bt.C64 = &c64 + } + if toss(nilChance) { + c128 := complex(rand.Float64(), rand.Float64()) + bt.C128 = &c128 + } + } + for _, nilChance := range []float32{.5, 0, 1} { + for i := 0; i < 10; i += 1 { + btOrig := &BT{} + fuzz(btOrig, nilChance) + payload, err := Marshal(btOrig) + if err != nil { + t.Errorf("marshalling failed basic type struct for: %+v, err=%+v", btOrig, err) + continue + } + btDecoded := &BT{} + err = Unmarshal(payload, btDecoded) + if err != nil { + t.Errorf("unmarshalling failed for: %+v, err=%+v", btOrig, err) + continue + } + } + } +} + +func TestPointerOfPointer(t *testing.T) { + type S struct { + V **int + } + i := rand.Int() + pi := &i + ppi := &pi + sOrig := &S{ + V: ppi, + } + payload, err := Marshal(sOrig) + if err != nil { + t.Errorf("marshalling failed pointer of pointer type for: %+v, err=%+v", sOrig, err) + return + } + sDecoded := &S{} + err = Unmarshal(payload, sDecoded) + if err != nil { + t.Errorf("unmarshalling failed pointer of pointer type for: %+v, err=%+v", sOrig, err) + return + } + if sDecoded.V == nil { + t.Errorf("unmarshalling failed for pointer of pointer: expected non-nil pointer of pointer value") + return + } + + if *sDecoded.V == nil { + t.Errorf("unmarshalling failed for pointer of pointer: expected non-nil pointer value") + return + } + if **sDecoded.V != i { + t.Errorf("unmarshalling failed for pointer of pointer: expected: %d, actual: %d", i, **sDecoded.V) + return + } +} + +func TestStructPointer(t *testing.T) { + type T struct { + V int + } + type S struct { + T *T + } + sOrig := &S{ + T: &T{ + V: rand.Int(), + }, + } + payload, err := Marshal(sOrig) + if err != nil { + t.Errorf("marshalling failed for struct containing pointer of another struct: %+v, err=%+v", sOrig, err) + return + } + sDecoded := &S{} + err = Unmarshal(payload, sDecoded) + if err != nil { + t.Errorf("unmarshalling failed for struct containing pointer of another struct: %+v, err=%+v", sOrig, err) + return + } + if sDecoded.T == nil { + t.Errorf("unmarshalling failed for struct containing pointer of another struct: expecting non-nil pointer value") + return + } + if sDecoded.T.V != sOrig.T.V { + t.Errorf( + "unmarshalling failed for struct containing pointer of another struct: expected: %d, actual: %d", + sOrig.T.V, sDecoded.T.V, + ) + } +} + func TestMarshalNonPointer(t *testing.T) { type S struct { A int diff --git a/scanner.go b/scanner.go index dffd2a1..587a1de 100644 --- a/scanner.go +++ b/scanner.go @@ -62,6 +62,15 @@ func scanType(t reflect.Type) (Codec, error) { } switch t.Kind() { + case reflect.Ptr: + elemCodec, err := scanType(t.Elem()) + if err != nil { + return nil, err + } + + return &reflectPointerCodec{ + elemCodec: elemCodec, + }, nil case reflect.Array: elemCodec, err := scanType(t.Elem()) if err != nil {