Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix alias typing and tests #788

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pkg/values/big_int.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,23 @@ func (b *BigInt) Unwrap() (any, error) {

func (b *BigInt) UnwrapTo(to any) error {
if b == nil || b.Underlying == nil {
return errors.New("could not unwrap nil values.BigInt")
return fmt.Errorf("could not unwrap nil")
nolag marked this conversation as resolved.
Show resolved Hide resolved
}

// check any here because unwrap to will make the *any point to a big.Int instead of *big.Int
switch tb := to.(type) {
case *big.Int:
if tb == nil {
return fmt.Errorf("cannot unwrap to nil pointer")
return errors.New("cannot unwrap to nil pointer")
}
*tb = *b.Underlying
case *any:
if tb == nil {
return fmt.Errorf("cannot unwrap to nil pointer")
return errors.New("cannot unwrap to nil pointer")
}

*tb = b.Underlying
return nil
default:
rto := reflect.ValueOf(to)
if rto.CanConvert(reflect.TypeOf(new(big.Int))) {
Expand Down
122 changes: 106 additions & 16 deletions pkg/values/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package values
import (
"errors"
"fmt"
"math"
"reflect"

"github.com/smartcontractkit/chainlink-common/pkg/values/pb"
Expand Down Expand Up @@ -41,27 +42,116 @@ func (i *Int64) UnwrapTo(to any) error {
return fmt.Errorf("cannot unwrap to nil pointer: %+v", to)
}

if reflect.ValueOf(to).Kind() != reflect.Pointer {
return fmt.Errorf("cannot unwrap to non-pointer value: %+v", to)
}
switch tv := to.(type) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nolag I'm less keen on this: IMO we should be trying to move as much logic as possible to using unwrapTo so that we keep the behaviour of UnwrapTo consistent across types. With this implementation we have to reimplement support for aliases, which unwrapTo could handle for us in a common trunk of code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline: we're going to go with this for now given the time crunch

case *int64:
*tv = i.Underlying
return nil
case *int32:
if err := verifyBounds(math.MinInt32, math.MaxInt32, i.Underlying, "int32"); err != nil {
return err
}

*tv = int32(i.Underlying)
return nil
case *int16:
if err := verifyBounds(math.MinInt16, math.MaxInt16, i.Underlying, "int16"); err != nil {
return err
}

*tv = int16(i.Underlying)
return nil
case *int8:
if err := verifyBounds(math.MinInt8, math.MaxInt8, i.Underlying, "int8"); err != nil {
return err
}

rToVal := reflect.Indirect(reflect.ValueOf(to))
switch rToVal.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if rToVal.OverflowInt(i.Underlying) {
return fmt.Errorf("cannot unwrap int64 to %T: overflow", to)
*tv = int8(i.Underlying)
return nil
case *int:
if err := verifyBounds(math.MinInt, math.MaxInt, i.Underlying, "int"); err != nil {
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:

*tv = int(i.Underlying)
return nil
case *uint64:
if i.Underlying < 0 {
return fmt.Errorf("cannot unwrap int64 to %T: underflow", to)
return fmt.Errorf("value %d is too small for uint64", i.Underlying)
}

*tv = uint64(i.Underlying)
return nil
case *uint32:
if err := verifyBounds(0, math.MaxUint32, i.Underlying, "uint32"); err != nil {
return err
}

*tv = uint32(i.Underlying)
return nil
case *uint16:
if err := verifyBounds(0, math.MaxUint16, i.Underlying, "uint16"); err != nil {
return err
}

*tv = uint16(i.Underlying)
return nil
case *uint8:
if err := verifyBounds(0, math.MaxUint8, i.Underlying, "uint8"); err != nil {
return err
}

*tv = uint8(i.Underlying)
return nil
case *uint:
if math.MaxUint == math.MaxUint64 {
if i.Underlying < 0 {
return fmt.Errorf("value %d is too small for uint64", i.Underlying)
}
}
if rToVal.OverflowUint(uint64(i.Underlying)) {
return fmt.Errorf("cannot unwrap int64 to %T: overflow", to)

*tv = uint(i.Underlying)
return nil
case *any:
*tv = i.Underlying
return nil
}

rv := reflect.ValueOf(to)
if rv.Kind() == reflect.Ptr {
switch rv.Elem().Kind() {
case reflect.Int64:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int64(0)))).Interface())
case reflect.Int32:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int32(0)))).Interface())
case reflect.Int16:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int16(0)))).Interface())
case reflect.Int8:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(int8(0)))).Interface())
case reflect.Int:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(0))).Interface())
case reflect.Uint64:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint64(0)))).Interface())
case reflect.Uint32:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint32(0)))).Interface())
case reflect.Uint16:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint16(0)))).Interface())
case reflect.Uint8:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint8(0)))).Interface())
case reflect.Uint:
return i.UnwrapTo(rv.Convert(reflect.PointerTo(reflect.TypeOf(uint(0)))).Interface())
default:
// fall through to the error, default is required by lint
}
case reflect.Interface:
default:
return fmt.Errorf("cannot unwrap to type %T", to)
}

return unwrapTo(i.Underlying, to)
return fmt.Errorf("cannot unwrap to type %T", to)
}

func verifyBounds(min, max, value int64, tpe string) error {
if value < min {
return fmt.Errorf("value %d is too large for %s", value, tpe)
} else if value > max {
return fmt.Errorf("value %d is too small for %s", value, tpe)
}
return nil
}
44 changes: 28 additions & 16 deletions pkg/values/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,15 @@ func (l *List) UnwrapTo(to any) error {
switch ptrVal.Kind() {
case reflect.Slice:
newList := reflect.MakeSlice(ptrVal.Type(), len(l.Underlying), len(l.Underlying))
for i, el := range l.Underlying {
newElm := newList.Index(i)
if newElm.Kind() == reflect.Pointer {
newElm.Set(reflect.New(newElm.Type().Elem()))
} else {
newElm = newElm.Addr()
}

if el == nil {
continue
}
if err := el.UnwrapTo(newElm.Interface()); err != nil {
return err
}
return l.unwrapToSliceOrArray(newList, val)
case reflect.Array:
if ptrVal.Len() < len(l.Underlying) {
return fmt.Errorf("too many elements to unwrap")
} else if ptrVal.Len() > len(l.Underlying) {
return fmt.Errorf("too few elements to unwrap")
}
reflect.Indirect(val).Set(newList)
return nil
arr := reflect.New(ptrVal.Type()).Elem()
return l.unwrapToSliceOrArray(arr, val)
default:
dl := []any{}
err := l.UnwrapTo(&dl)
Expand All @@ -104,3 +96,23 @@ func (l *List) UnwrapTo(to any) error {
return fmt.Errorf("cannot unwrap to type %T", to)
}
}

func (l *List) unwrapToSliceOrArray(newList reflect.Value, val reflect.Value) error {
for i, el := range l.Underlying {
newElm := newList.Index(i)
if newElm.Kind() == reflect.Pointer {
newElm.Set(reflect.New(newElm.Type().Elem()))
} else {
newElm = newElm.Addr()
}

if el == nil {
continue
}
if err := el.UnwrapTo(newElm.Interface()); err != nil {
return err
}
}
reflect.Indirect(val).Set(newList)
return nil
}
27 changes: 27 additions & 0 deletions pkg/values/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,33 @@ func Test_ListUnwrapTo(t *testing.T) {
sliceTest[any](t, expected, got)
})

t.Run("arrays", func(t *testing.T) {
v, err := Wrap([2]string{"foo", "bar"})
require.NoError(t, err)

var got [2]string
err = v.UnwrapTo(&got)
require.NoError(t, err)

require.Equal(t, [2]string{"foo", "bar"}, got)
})

t.Run("arrays too many elements return error", func(t *testing.T) {
wrapped, err := Wrap([]string{"foo", "bar", "baz"})
require.NoError(t, err)
to := [2]string{}
err = wrapped.UnwrapTo(&to)
assert.ErrorContains(t, err, "too many elements to unwrap")
})

t.Run("arrays too few elements return error", func(t *testing.T) {
wrapped, err := Wrap([]string{"foo", "bar", "baz"})
require.NoError(t, err)
to := [4]string{}
err = wrapped.UnwrapTo(&to)
assert.ErrorContains(t, err, "too few elements to unwrap")
})

t.Run("cant be assigned to passed in var", func(t *testing.T) {
a := struct{}{}
l, err := Wrap([]int{1, 2, 3})
Expand Down
49 changes: 36 additions & 13 deletions pkg/values/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,24 @@ func Wrap(v any) (Value, error) {
return NewDecimal(tv), nil
case int64:
return NewInt64(tv), nil
case int32:
return NewInt64(int64(tv)), nil
case int16:
return NewInt64(int64(tv)), nil
case int8:
return NewInt64(int64(tv)), nil
case int:
return NewInt64(int64(tv)), nil
case uint64:
return NewInt64(int64(tv)), nil
case uint:
return NewInt64(int64(tv)), nil
case uint32:
return NewInt64(int64(tv)), nil
case uint16:
return NewInt64(int64(tv)), nil
case uint8:
return NewInt64(int64(tv)), nil
case uint:
return NewInt64(int64(tv)), nil
case *big.Int:
return NewBigInt(tv), nil
case nil:
Expand Down Expand Up @@ -103,20 +113,25 @@ func Wrap(v any) (Value, error) {
return NewMap(m)
// Better complex type support for slices
case reflect.Slice:
s := make([]any, val.Len())
for i := 0; i < val.Len(); i++ {
item := val.Index(i).Interface()
s[i] = item
if val.Type().Elem().Kind() == reflect.Uint8 {
return NewBytes(val.Bytes()), nil
}
return createListFromSlice(val)
case reflect.Array:
arrayLen := val.Len()
slice := reflect.MakeSlice(reflect.SliceOf(val.Type().Elem()), arrayLen, arrayLen)
for i := 0; i < arrayLen; i++ {
slice.Index(i).Set(val.Index(i))
}
return NewList(s)
return Wrap(slice.Interface())
case reflect.Struct:
return CreateMapFromStruct(v)
case reflect.Pointer:
if reflect.Indirect(reflect.ValueOf(v)).Kind() == reflect.Struct {
return CreateMapFromStruct(reflect.Indirect(reflect.ValueOf(v)).Interface())
}
// pointer can't be null or the switch statement above would catch it.
return Wrap(val.Elem().Interface())
case reflect.String:
return Wrap(val.Convert(reflect.TypeOf("")).Interface())

case reflect.Bool:
return Wrap(val.Convert(reflect.TypeOf(true)).Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
Expand All @@ -126,6 +141,15 @@ func Wrap(v any) (Value, error) {
return nil, fmt.Errorf("could not wrap into value: %+v", v)
}

func createListFromSlice(val reflect.Value) (Value, error) {
s := make([]any, val.Len())
for i := 0; i < val.Len(); i++ {
item := val.Index(i).Interface()
s[i] = item
}
return NewList(s)
}

func WrapMap(a any) (*Map, error) {
v, err := Wrap(a)
if err != nil {
Expand Down Expand Up @@ -281,9 +305,8 @@ func unwrapTo[T any](underlying T, to any) error {
return fmt.Errorf("cannot unwrap to value of type: %T", to)
}

if rUnderlying.CanConvert(reflect.Indirect(rTo).Type()) {
conv := rUnderlying.Convert(reflect.Indirect(rTo).Type())
reflect.Indirect(rTo).Set(conv)
if rUnderlying.Type().ConvertibleTo(rTo.Type().Elem()) {
reflect.Indirect(rTo).Set(rUnderlying.Convert(rTo.Type().Elem()))
return nil
}

Expand Down
Loading
Loading