Skip to content

Commit

Permalink
add pointer support (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
Srinivas Devaki authored Sep 22, 2021
1 parent 16852db commit 1f7663f
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 1 deletion.
38 changes: 37 additions & 1 deletion codecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,37 @@ func (c *varuintSliceCodec) DecodeTo(d *Decoder, rv reflect.Value) (err error) {
return
}

type reflectPointerCodec struct {
elemCodec Codec
}

func (c *reflectPointerCodec) EncodeTo(e *Encoder, rv reflect.Value) (err error) {
if rv.IsNil() {
e.writeBool(true)
return
}
e.writeBool(false)
err = c.elemCodec.EncodeTo(e, rv.Elem())
if err != nil {
return err
}
return nil
}

func (c *reflectPointerCodec) DecodeTo(d *Decoder, rv reflect.Value) (err error) {
isNil, err := d.ReadBool()
if err != nil {
return err
}
if isNil {
return
}
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
return c.elemCodec.DecodeTo(d, rv.Elem())
}

// ------------------------------------------------------------------------------

type reflectStructCodec []fieldCodec
Expand All @@ -218,7 +249,12 @@ func (c *reflectStructCodec) EncodeTo(e *Encoder, rv reflect.Value) (err error)
// Decode decodes into a reflect value from the decoder.
func (c *reflectStructCodec) DecodeTo(d *Decoder, rv reflect.Value) (err error) {
for _, i := range *c {
if v := rv.Field(i.Index); v.CanSet() {
v := rv.Field(i.Index)
if v.Kind() == reflect.Ptr {
if err = i.Codec.DecodeTo(d, v); err != nil {
return
}
} else if v.CanSet() {
if err = i.Codec.DecodeTo(d, reflect.Indirect(v)); err != nil {
return
}
Expand Down
181 changes: 181 additions & 0 deletions codecs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package binary
import (
"bytes"
"errors"
"math/rand"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -331,6 +332,186 @@ func TestSliceOfStructWithStruct(t *testing.T) {

}

func TestBasicTypePointers(t *testing.T) {
type BT struct {
B *bool
S *string
I *int
I8 *int8
I16 *int16
I32 *int32
I64 *int64
Ui *uint
Ui8 *uint8
Ui16 *uint16
Ui32 *uint32
Ui64 *uint64
F32 *float32
F64 *float64
C64 *complex64
C128 *complex128
}
toss := func(chance float32) bool {
return rand.Float32() < chance
}
fuzz := func(bt *BT, nilChance float32) {
if toss(nilChance) {
k := rand.Intn(2) == 1
bt.B = &k
}
if toss(nilChance) {
b := make([]byte, rand.Intn(32))
rand.Read(b)
sb := string(b)
bt.S = &sb
}
if toss(nilChance) {
i := rand.Int()
bt.I = &i
}
if toss(nilChance) {
i8 := int8(rand.Int())
bt.I8 = &i8
}
if toss(nilChance) {
i16 := int16(rand.Int())
bt.I16 = &i16
}
if toss(nilChance) {
i32 := rand.Int31()
bt.I32 = &i32
}
if toss(nilChance) {
i64 := rand.Int63()
bt.I64 = &i64
}
if toss(nilChance) {
ui := uint(rand.Uint64())
bt.Ui = &ui
}
if toss(nilChance) {
ui8 := uint8(rand.Uint32())
bt.Ui8 = &ui8
}
if toss(nilChance) {
ui16 := uint16(rand.Uint32())
bt.Ui16 = &ui16
}
if toss(nilChance) {
ui32 := rand.Uint32()
bt.Ui32 = &ui32
}
if toss(nilChance) {
ui64 := rand.Uint64()
bt.Ui64 = &ui64
}
if toss(nilChance) {
f32 := rand.Float32()
bt.F32 = &f32
}
if toss(nilChance) {
f64 := rand.Float64()
bt.F64 = &f64
}
if toss(nilChance) {
c64 := complex(rand.Float32(), rand.Float32())
bt.C64 = &c64
}
if toss(nilChance) {
c128 := complex(rand.Float64(), rand.Float64())
bt.C128 = &c128
}
}
for _, nilChance := range []float32{.5, 0, 1} {
for i := 0; i < 10; i += 1 {
btOrig := &BT{}
fuzz(btOrig, nilChance)
payload, err := Marshal(btOrig)
if err != nil {
t.Errorf("marshalling failed basic type struct for: %+v, err=%+v", btOrig, err)
continue
}
btDecoded := &BT{}
err = Unmarshal(payload, btDecoded)
if err != nil {
t.Errorf("unmarshalling failed for: %+v, err=%+v", btOrig, err)
continue
}
}
}
}

func TestPointerOfPointer(t *testing.T) {
type S struct {
V **int
}
i := rand.Int()
pi := &i
ppi := &pi
sOrig := &S{
V: ppi,
}
payload, err := Marshal(sOrig)
if err != nil {
t.Errorf("marshalling failed pointer of pointer type for: %+v, err=%+v", sOrig, err)
return
}
sDecoded := &S{}
err = Unmarshal(payload, sDecoded)
if err != nil {
t.Errorf("unmarshalling failed pointer of pointer type for: %+v, err=%+v", sOrig, err)
return
}
if sDecoded.V == nil {
t.Errorf("unmarshalling failed for pointer of pointer: expected non-nil pointer of pointer value")
return
}

if *sDecoded.V == nil {
t.Errorf("unmarshalling failed for pointer of pointer: expected non-nil pointer value")
return
}
if **sDecoded.V != i {
t.Errorf("unmarshalling failed for pointer of pointer: expected: %d, actual: %d", i, **sDecoded.V)
return
}
}

func TestStructPointer(t *testing.T) {
type T struct {
V int
}
type S struct {
T *T
}
sOrig := &S{
T: &T{
V: rand.Int(),
},
}
payload, err := Marshal(sOrig)
if err != nil {
t.Errorf("marshalling failed for struct containing pointer of another struct: %+v, err=%+v", sOrig, err)
return
}
sDecoded := &S{}
err = Unmarshal(payload, sDecoded)
if err != nil {
t.Errorf("unmarshalling failed for struct containing pointer of another struct: %+v, err=%+v", sOrig, err)
return
}
if sDecoded.T == nil {
t.Errorf("unmarshalling failed for struct containing pointer of another struct: expecting non-nil pointer value")
return
}
if sDecoded.T.V != sOrig.T.V {
t.Errorf(
"unmarshalling failed for struct containing pointer of another struct: expected: %d, actual: %d",
sOrig.T.V, sDecoded.T.V,
)
}
}

func TestMarshalNonPointer(t *testing.T) {
type S struct {
A int
Expand Down
9 changes: 9 additions & 0 deletions scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ func scanType(t reflect.Type) (Codec, error) {
}

switch t.Kind() {
case reflect.Ptr:
elemCodec, err := scanType(t.Elem())
if err != nil {
return nil, err
}

return &reflectPointerCodec{
elemCodec: elemCodec,
}, nil
case reflect.Array:
elemCodec, err := scanType(t.Elem())
if err != nil {
Expand Down

0 comments on commit 1f7663f

Please sign in to comment.