From 2d747f984b07c355e48c9bc567aa74cc1f94e124 Mon Sep 17 00:00:00 2001 From: Kanishka Date: Mon, 29 May 2023 14:45:13 -0600 Subject: [PATCH] feat/scale: add `BitVec` (#3253) --- pkg/scale/README.md | 33 +++++++ pkg/scale/bitvec.go | 87 ++++++++++++++++++ pkg/scale/bitvec_test.go | 186 +++++++++++++++++++++++++++++++++++++++ pkg/scale/decode.go | 30 +++++++ pkg/scale/decode_test.go | 24 +++++ pkg/scale/encode.go | 19 ++++ pkg/scale/encode_test.go | 56 +++++++++++- pkg/scale/errors.go | 1 + 8 files changed, 435 insertions(+), 1 deletion(-) create mode 100644 pkg/scale/bitvec.go create mode 100644 pkg/scale/bitvec_test.go diff --git a/pkg/scale/README.md b/pkg/scale/README.md index edb7f9db67..253b3574bd 100644 --- a/pkg/scale/README.md +++ b/pkg/scale/README.md @@ -77,6 +77,39 @@ SCALE uses a compact encoding for variable width unsigned integers. | `Compact` | `uint` | | `Compact` | `*big.Int` | +### BitVec + +SCALE uses a bit vector to encode a sequence of booleans. The bit vector is encoded as a compact length followed by a byte array. +The byte array is a sequence of bytes where each bit represents a boolean value. + +**Note: This is a work in progress.** +The current implementation of BitVec is just bare bones. It does not implement any of the methods of the `BitVec` type in Rust. 
const (
	// maxLen is the largest number of bits a BitVec may hold. It is the
	// equivalent of `ARCH32BIT_BITSLICE_MAX_BITS` in parity-scale-codec.
	maxLen = 268435455
	// byteSize is the number of bits in a byte.
	byteSize = 8
)

// BitVec is a bare-bones bit vector: an ordered sequence of booleans.
type BitVec struct {
	bits []bool
}

// NewBitVec returns a new BitVec holding the given bits. The slice is stored
// as-is (it is not copied).
//
// This isn't a complete implementation of the bit vector; it is only used
// for ParachainHost runtime exports.
// TODO: Implement the full bit vector
// https://github.com/ChainSafe/gossamer/issues/3248
func NewBitVec(bits []bool) BitVec {
	return BitVec{bits: bits}
}

// Bits returns the bits held by the BitVec. The returned slice is the
// BitVec's own slice, not a copy.
func (bv *BitVec) Bits() []bool {
	return bv.bits
}

// Bytes returns the bits packed into bytes, eight bits per byte, with the
// first bit of each group stored in the least significant position.
func (bv *BitVec) Bytes() []byte {
	return bitsToBytes(bv.bits)
}

// Size returns the number of bits in the BitVec.
func (bv *BitVec) Size() uint {
	return uint(len(bv.bits))
}

// bitsToBytes packs a slice of bits into ceil(len(bits)/8) bytes using lsb
// ordering: bit i lands in byte i/8 at bit position i%8. Unused high bits of
// the final byte are left as zero. The result is always non-nil, even for
// empty input.
//
// The input slice is only read. No padding is appended to it, because
// appending to a slice with spare capacity would overwrite elements of the
// caller's backing array.
// TODO: Implement msb ordering
// https://github.com/ChainSafe/gossamer/issues/3248
func bitsToBytes(bits []bool) []byte {
	numOfBytes := (len(bits) + (byteSize - 1)) / byteSize
	out := make([]byte, numOfBytes)

	for i, bit := range bits {
		if bit {
			out[i/byteSize] |= 1 << (i % byteSize)
		}
	}

	return out
}

// bytesToBits unpacks up to size bits from b using lsb ordering, the inverse
// of bitsToBytes. If b holds fewer than size bits, only the available
// len(b)*8 bits are returned. A nil slice is returned when there is nothing
// to unpack.
func bytesToBits(b []byte, size uint) []bool {
	total := uint(len(b)) * byteSize
	if size < total {
		total = size
	}
	if total == 0 {
		return nil
	}

	// The final length is known up front, so allocate once.
	bits := make([]bool, total)
	for i := uint(0); i < total; i++ {
		bits[i] = (b[i/byteSize]>>(i%byteSize))&1 == 1
	}

	return bits
}
Unmarshal(resultBytes, &bv) + require.NoError(t, err) + + require.Equal(t, tt.wantBitVec.Size(), bv.Size()) + require.Equal(t, tt.wantBitVec.Size(), bv.Size()) + + b, err := Marshal(bv) + require.NoError(t, err) + require.Equal(t, resultBytes, b) + }) + } +} + +func TestBitVecBytes(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in BitVec + want []byte + wantErr bool + }{ + { + name: "empty_bitvec", + in: NewBitVec(nil), + want: []byte(nil), + wantErr: false, + }, + { + name: "1_byte", + in: NewBitVec([]bool{true, false, true, false, true, false, true, false}), + want: []byte{0x55}, + wantErr: false, + }, + { + name: "4_bytes", + in: NewBitVec([]bool{ + true, false, true, false, true, false, true, false, + false, true, true, false, true, true, false, false, + false, true, false, true, false, true, false, true, + true, + }), + want: []byte{0x55, 0x36, 0xaa, 0x1}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, tt.in.Bytes()) + }) + } +} + +func TestBitVecBytesToBits(t *testing.T) { + t.Parallel() + tests := []struct { + name string + in []byte + want []bool + wantErr bool + }{ + { + name: "empty", + in: []byte(nil), + want: []bool(nil), + wantErr: false, + }, + { + name: "1_byte", + in: []byte{0x55}, + want: []bool{true, false, true, false, true, false, true, false}, + wantErr: false, + }, + { + name: "4_bytes", + in: []byte{0x55, 0x36, 0xaa, 0x1}, + want: []bool{ + true, false, true, false, true, false, true, false, + false, true, true, false, true, true, false, false, + false, true, false, true, false, true, false, true, + true, false, false, false, false, false, false, false, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, bytesToBits(tt.in, uint(len(tt.in)*byteSize))) + }) + } +} + +func TestBitVecBitsToBytes(t *testing.T) { + t.Parallel() + tests := 
[]struct { + name string + in []bool + want []byte + wantErr bool + }{ + { + name: "empty", + in: []bool(nil), + want: []byte{}, + wantErr: false, + }, + { + name: "1_byte", + in: []bool{true, false, true, false, true, false, true, false}, + want: []byte{0x55}, + wantErr: false, + }, + { + name: "4_bytes", + in: []bool{ + true, false, true, false, true, false, true, false, + false, true, true, false, true, true, false, false, + false, true, false, true, false, true, false, true, + true, + }, + want: []byte{0x55, 0x36, 0xaa, 0x1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, bitsToBytes(tt.in)) + }) + } +} diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 45a527f0b8..77a1476551 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -114,6 +114,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { err = ds.decodeBigInt(dstv) case *Uint128: err = ds.decodeUint128(dstv) + case BitVec: + err = ds.decodeBitVec(dstv) case int, uint: err = ds.decodeUint(dstv) case int8, uint8, int16, uint16, int32, uint32, int64, uint64: @@ -768,3 +770,31 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) { dstv.Set(reflect.ValueOf(ui128)) return } + +// decodeBitVec accepts a byte array representing a SCALE encoded +// BitVec and performs SCALE decoding of the BitVec +func (ds *decodeState) decodeBitVec(dstv reflect.Value) error { + var size uint + if err := ds.decodeUint(reflect.ValueOf(&size).Elem()); err != nil { + return err + } + + if size > maxLen { + return fmt.Errorf("%w: %d", errBitVecTooLong, size) + } + + numBytes := (size + (byteSize - 1)) / byteSize + b := make([]byte, numBytes) + _, err := ds.Read(b) + if err != nil { + return err + } + + bitvec := NewBitVec(bytesToBits(b, size)) + if len(bitvec.bits) > int(size) { + return fmt.Errorf("bitvec length mismatch: expected %d, got %d", size, len(bitvec.bits)) + } + + 
dstv.Set(reflect.ValueOf(bitvec)) + return nil +} diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index 3309c58a9b..68a1a4b767 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -9,6 +9,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/require" + "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/assert" @@ -251,6 +253,28 @@ func Test_decodeState_decodeMap(t *testing.T) { } } +func Test_decodeState_decodeBitVec(t *testing.T) { + for _, tt := range bitVecTests { + t.Run(tt.name, func(t *testing.T) { + dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface() + if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr { + t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(dst, tt.in) { + t.Errorf("decodeState.unmarshal() = %v, want %v", dst, tt.in) + } + }) + } +} + +func Test_decodeState_decodeBitVecMaxLen(t *testing.T) { + t.Parallel() + bitvec := NewBitVec(nil) + maxLen10 := []byte{38, 0, 0, 64, 0} // maxLen + 10 + err := Unmarshal(maxLen10, &bitvec) + require.Error(t, err, errBitVecTooLong) +} + func Test_unmarshal_optionality(t *testing.T) { var ptrTests tests for _, t := range append(tests{}, allTests...) 
{ diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index d312b85f91..5a07919136 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -73,6 +73,8 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeBigInt(in) case *Uint128: err = es.encodeUint128(in) + case BitVec: + err = es.encodeBitVec(in) case []byte: err = es.encodeBytes(in) case string: @@ -423,3 +425,20 @@ func (es *encodeState) encodeUint128(i *Uint128) (err error) { err = binary.Write(es, binary.LittleEndian, padBytes(i.Bytes(), binary.LittleEndian)) return } + +// encodeBitVec encodes a BitVec +func (es *encodeState) encodeBitVec(bitvec BitVec) (err error) { + if bitvec.Size() > maxLen { + err = fmt.Errorf("%w: %d", errBitVecTooLong, bitvec.Size()) + return + } + + err = es.encodeUint(bitvec.Size()) + if err != nil { + return + } + + data := bitvec.Bytes() + _, err = es.Write(data) + return +} diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index 92de411919..2d8c578671 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -225,7 +225,7 @@ var ( uintTests = tests{ { name: "uint(1)", - in: int(1), + in: 1, want: []byte{0x04}, }, { @@ -938,6 +938,29 @@ var ( }, } + bitVecTests = tests{ + { + name: "BitVec{Size:__0,_Bits:__nil}", + in: NewBitVec(nil), + want: []byte{0}, + }, + { + name: "BitVec{Size:_8}", + in: NewBitVec([]bool{true, false, true, false, true, false, true, false}), + want: []byte{0x20, 0x55}, + }, + { + name: "BitVec{Size:_25}", + in: NewBitVec([]bool{ + true, false, true, false, true, false, true, false, + false, true, true, false, true, true, false, false, + false, true, false, true, false, true, false, true, + true, + }), + want: []byte{0x64, 0x55, 0x36, 0xaa, 0x1}, + }, + } + allTests = newTests( fixedWidthIntegerTests, variableWidthIntegerTests, stringTests, boolTests, structTests, sliceTests, arrayTests, @@ -1172,6 +1195,37 @@ func Test_encodeState_encodeMap(t *testing.T) { } } +func 
Test_encodeState_encodeBitVec(t *testing.T) { + for _, tt := range bitVecTests { + t.Run(tt.name, func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + es := &encodeState{ + Writer: buffer, + fieldScaleIndicesCache: cache, + } + if err := es.marshal(tt.in); (err != nil) != tt.wantErr { + t.Errorf("encodeState.encodeBitVec() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(buffer.Bytes(), tt.want) { + t.Errorf("encodeState.encodeBitVec() = %v, want %v", buffer.Bytes(), tt.want) + } + }) + } +} + +func Test_encodeState_encodeBitVecMaxLen(t *testing.T) { + t.Parallel() + + var bits []bool + for i := 0; i < maxLen+1; i++ { + bits = append(bits, true) + } + + bitvec := NewBitVec(bits) + _, err := Marshal(bitvec) + require.Error(t, err, errBitVecTooLong) +} + func Test_marshal_optionality(t *testing.T) { var ptrTests tests for i := range allTests { diff --git a/pkg/scale/errors.go b/pkg/scale/errors.go index 5b317979dd..9930cebfbc 100644 --- a/pkg/scale/errors.go +++ b/pkg/scale/errors.go @@ -13,6 +13,7 @@ var ( errUnsupportedOption = errors.New("unsupported option") errUnknownVaryingDataTypeValue = errors.New("unable to find VaryingDataTypeValue with index") errUint128IsNil = errors.New("uint128 in nil") + errBitVecTooLong = errors.New("bitvec too long") ErrResultNotSet = errors.New("result not set") ErrResultAlreadySet = errors.New("result already has an assigned value") ErrUnsupportedVaryingDataTypeValue = errors.New("unsupported VaryingDataTypeValue")