diff --git a/amino.go b/amino.go index 4f33d681..356a5888 100644 --- a/amino.go +++ b/amino.go @@ -2,6 +2,7 @@ package amino import ( "bytes" + "encoding" "fmt" "io" "reflect" @@ -205,13 +206,31 @@ func (cdc *Codec) MarshalBinaryBare(o interface{}) ([]byte, error) { } // Encode Amino:binary bytes. - var bz []byte buf := new(bytes.Buffer) rt := rv.Type() info, err := cdc.getTypeInfoWlock(rt) if err != nil { return nil, err } + + if info.Registered && rt.Implements(binaryMarshalerType) { + pb := info.Prefix.Bytes() + buf.Write(pb) + + bz, err := rv.Interface().(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + return nil, err + } + + _, err = buf.Write(bz) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil + } + + var bz []byte // in the case of of a repeated struct (e.g. type Alias []SomeStruct), // we do not need to prepend with `(field_number << 3) | wire_type` as this // would need to be done for each struct and not only for the first. @@ -398,6 +417,31 @@ func (cdc *Codec) UnmarshalBinaryBare(bz []byte, ptr interface{}) error { return err } + if info.Registered && rv.CanAddr() { + addr := rv.Addr() + if addr.Type().Implements(binaryUnmarshalerType) { + pb := info.Prefix.Bytes() + l := len(pb) + if len(bz) < l { + return fmt.Errorf( + "unmarshalBinaryBare expected to read prefix bytes %X (since it is registered concrete) but got %X", + pb, bz, + ) + } + + pb2 := bz[:l] + bz = bz[l:] + if !bytes.Equal(pb2, pb) { + return fmt.Errorf( + "unmarshalBinaryBare expected to read prefix bytes %X (since it is registered concrete) but got %X", + pb, pb2, + ) + } + + return addr.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(bz) + } + } + // If registered concrete, consume and verify prefix bytes. if info.Registered { aminoAny := &RegisteredAny{} diff --git a/binary_encode_override_test.go b/binary_encode_override_test.go new file mode 100644 index 00000000..e93748ad --- /dev/null +++ b/binary_encode_override_test.go @@ -0,0 +1,83 @@ +package amino_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tendermint/go-amino" +) + +type Thing struct { + Name string +} + +func (thing Thing) MarshalBinary() ([]byte, error) { + return []byte(thing.Name), nil +} + +func (thing *Thing) UnmarshalBinary(bz []byte) error { + thing.Name = string(bz) + return nil +} + +func TestMarshalBinaryOverrideBare(t *testing.T) { + var cdc = amino.NewCodec() + cdc.RegisterConcrete(&Thing{}, "amino/thing", nil) + + thing1 := Thing{Name: "a"} + + bz, err := cdc.MarshalBinaryBare(thing1) + assert.Nil(t, err) + assert.Equal(t, bz, []byte{140, 74, 30, 175, 97}) + + var thing2 Thing + err = cdc.UnmarshalBinaryBare(bz, &thing2) + assert.Nil(t, err) + assert.Equal(t, thing1, thing2) +} + +func TestMarshalBinaryOverrideLengthPrefixed(t *testing.T) { + var cdc = amino.NewCodec() + cdc.RegisterConcrete(&Thing{}, "amino/thing", nil) + + thing1 := Thing{Name: "a"} + + bz, err := cdc.MarshalBinaryLengthPrefixed(thing1) + assert.Nil(t, err) + assert.Equal(t, bz, []byte{5, 140, 74, 30, 175, 97}) + + var thing2 Thing + err = cdc.UnmarshalBinaryLengthPrefixed(bz, &thing2) + assert.Nil(t, err) + assert.Equal(t, thing1, thing2) +} + +type Bytes [16]byte + +func (bytes Bytes) MarshalBinary() ([]byte, error) { + bz := make([]byte, 17) + copy(bz[:1], []byte{16}) + copy(bz[1:], bytes[:]) + return bz, nil +} + +func (bytes *Bytes) UnmarshalBinary(bz []byte) error { + copy(bytes[:], bz[1:]) + return nil +} + +func TestMarshalBinaryOverrideBytes(t *testing.T) { + var cdc = amino.NewCodec() + cdc.RegisterConcrete(&Bytes{}, "amino/bytes", nil) + + bytes1 := Bytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + + bz, err := cdc.MarshalBinaryBare(bytes1) + assert.Nil(t, err) + assert.Equal(t, bz, []byte{207, 109, 94, 111, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + + var bytes2 Bytes + err = cdc.UnmarshalBinaryBare(bz, &bytes2) + assert.Nil(t, err) + assert.Equal(t, bytes1, bytes2) +} diff --git a/reflect.go b/reflect.go index ab95603f..b1c369e9 100644 --- a/reflect.go +++ b/reflect.go @@ -1,6 +1,7 @@ package amino import ( + "encoding" "encoding/json" "fmt" "reflect" @@ -13,10 +14,12 @@ import ( const printLog = false var ( - timeType = reflect.TypeOf(time.Time{}) - jsonMarshalerType = reflect.TypeOf(new(json.Marshaler)).Elem() - jsonUnmarshalerType = reflect.TypeOf(new(json.Unmarshaler)).Elem() - errorType = reflect.TypeOf(new(error)).Elem() + timeType = reflect.TypeOf(time.Time{}) + jsonMarshalerType = reflect.TypeOf(new(json.Marshaler)).Elem() + jsonUnmarshalerType = reflect.TypeOf(new(json.Unmarshaler)).Elem() + binaryMarshalerType = reflect.TypeOf(new(encoding.BinaryMarshaler)).Elem() + binaryUnmarshalerType = reflect.TypeOf(new(encoding.BinaryUnmarshaler)).Elem() + errorType = reflect.TypeOf(new(error)).Elem() ) //----------------------------------------