Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pkg/scale): encoding and decoding of maps in scale #2894

Merged
merged 11 commits into from
Oct 21, 2022
30 changes: 30 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
err = ds.decodeArray(dstv)
case reflect.Slice:
err = ds.decodeSlice(dstv)
case reflect.Map:
err = ds.decodeMap(dstv)
default:
err = fmt.Errorf("unsupported type: %T", in)
}
Expand Down Expand Up @@ -426,6 +428,34 @@ func (ds *decodeState) decodeArray(dstv reflect.Value) (err error) {
return
}

func (ds *decodeState) decodeMap(dstv reflect.Value) (err error) {
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
l, err := ds.decodeLength()
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}
in := dstv.Interface()

for i := uint(0); i < l; i++ {
tempKeyType := reflect.TypeOf(in).Key()
tempKey := reflect.New(tempKeyType).Elem()
err = ds.unmarshal(tempKey)
if err != nil {
return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}

tempElemType := reflect.TypeOf(in).Elem()
tempElem := reflect.New(tempElemType).Elem()
err = ds.unmarshal(tempElem)
if err != nil {
return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}

dstv.SetMapIndex(tempKey, tempElem)
}

return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}

// decodeStruct decodes a byte array representing a SCALE tuple. The order of data is
// determined by the source tuple in rust, or the struct field order in a go struct
func (ds *decodeState) decodeStruct(dstv reflect.Value) (err error) {
Expand Down
34 changes: 34 additions & 0 deletions pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,40 @@ func Test_decodeState_decodeSlice(t *testing.T) {
}
}

func Test_decodeState_decodeMap(t *testing.T) {
mapTests = tests{
{
name: "testMap1",
in: map[int8][]byte{2: []byte("some string")},
want: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103},
},
{
name: "testMap2",
in: map[int8][]byte{
2: []byte("some string"),
16: []byte("lorem ipsum"),
},
want: []byte{
8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32,
105, 112, 115, 117, 109,
},
},
}

for _, tt := range mapTests {
t.Run(tt.name, func(t *testing.T) {
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
dst := map[int8][]byte{}
if err := Unmarshal(tt.want, &dst); (err != nil) != tt.wantErr {
t.Errorf("decodeState.unmarshal() error = %v, wantErr %v", err, tt.wantErr)
}

if !reflect.DeepEqual(dst, tt.in) {
t.Errorf("decodeState.unmarshal() = %v, want %v", dst, tt.in)
}
})
}
}

func Test_unmarshal_optionality(t *testing.T) {
var ptrTests tests
for _, t := range append(tests{}, allTests...) {
Expand Down
31 changes: 31 additions & 0 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ func (es *encodeState) marshal(in interface{}) (err error) {
err = es.encodeArray(in)
case reflect.Slice:
err = es.encodeSlice(in)
case reflect.Map:
err = es.encodeMap(in)
default:
err = fmt.Errorf("unsupported type: %T", in)
}
Expand Down Expand Up @@ -223,6 +225,35 @@ func (es *encodeState) encodeArray(in interface{}) (err error) {
return
}

func (es *encodeState) encodeMap(in interface{}) (err error) {
v := reflect.ValueOf(in)
err = es.encodeLength(v.Len())
if err != nil {
return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}

for i := v.MapRange(); i.Next(); {
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
fmt.Println(i.Key(), "\t:", i.Value())
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved

axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
key := i.Key()
err = es.marshal(key.Interface())
if err != nil {
return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}

mapValue := i.Value()
if !mapValue.CanInterface() {
continue
}
Comment on lines +251 to +253
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we return an error in that case, to avoid silently discarding a map value when encoding?

@timwu20

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Def wait for Tims answer, but this happens at other places in this file as well (i.e. structs)

Copy link
Contributor

@timwu20 timwu20 Oct 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CanInterface is essentially determining if this is a public attribute of a struct or public method. I think it should be fine in this case. This will almost always return true in the case we're trying to decode into a map value. I wonder what cases this would return false though, maybe we can provide a test case.

@kishansagathiya please don't merge the PR with unresolved conversations. My bad for not getting to this earlier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realised about this comment few moments after merging it. Wasn't deliberate.


err = es.marshal(mapValue.Interface())
if err != nil {
return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}
}
return
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
}

// encodeBigInt performs the same encoding as encodeInteger, except on a big.Int.
// if 2^30 <= n < 2^536 write
// [lower 2 bits of first byte = 11] [upper 6 bits of first byte = # of bytes following less 4]
Expand Down
38 changes: 37 additions & 1 deletion pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,9 +909,27 @@ var (
},
}

mapTests = tests{
{
name: "testMap1",
in: map[int8][]byte{2: []byte("some string")},
want: []byte{4, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103},
jimjbrettj marked this conversation as resolved.
Show resolved Hide resolved
},
{
name: "testMap2",
in: map[int8][]byte{
2: []byte("some string"),
16: []byte("lorem ipsum"),
},
want: []byte{
8, 2, 44, 115, 111, 109, 101, 32, 115, 116, 114, 105, 110, 103, 16, 44, 108, 111, 114, 101, 109, 32,
105, 112, 115, 117, 109,
}},
}

allTests = newTests(
fixedWidthIntegerTests, variableWidthIntegerTests, stringTests,
boolTests, structTests, sliceTests, arrayTests,
boolTests, structTests, sliceTests, arrayTests, mapTests,
varyingDataTypeTests,
)
)
Expand Down Expand Up @@ -1096,6 +1114,24 @@ func Test_encodeState_encodeArray(t *testing.T) {
}
}

func Test_encodeState_encodeMap(t *testing.T) {
for _, tt := range mapTests {
t.Run(tt.name, func(t *testing.T) {
axaysagathiya marked this conversation as resolved.
Show resolved Hide resolved
buffer := bytes.NewBuffer(nil)
es := &encodeState{
Writer: buffer,
fieldScaleIndicesCache: cache,
}
if err := es.marshal(tt.in); (err != nil) != tt.wantErr {
t.Errorf("encodeState.encodeMap() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(buffer.Bytes(), tt.want) {
t.Errorf("encodeState.encodeMap() = %v, want %v", buffer.Bytes(), tt.want)
}
})
}
}

func Test_marshal_optionality(t *testing.T) {
var ptrTests tests
for i := range allTests {
Expand Down