Skip to content

Commit

Permalink
Rename BaseType to TypeKind (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonpaulos authored Aug 15, 2022
1 parent 2b016ea commit fb2e94d
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 98 deletions.
10 changes: 5 additions & 5 deletions abi/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func (t Type) typeCastToTuple(tupLen ...int) (Type, error) {
var childT []Type

switch t.abiTypeID {
switch t.kind {
case String:
if len(tupLen) != 1 {
return Type{}, fmt.Errorf("string type conversion to tuple need 1 length argument")
Expand Down Expand Up @@ -52,7 +52,7 @@ func (t Type) typeCastToTuple(tupLen ...int) (Type, error) {

// Encode is an ABI type method to encode go values into bytes following ABI encoding rules
func (t Type) Encode(value interface{}) ([]byte, error) {
switch t.abiTypeID {
switch t.kind {
case Uint, Ufixed:
return encodeInt(value, t.bitSize)
case Bool:
Expand Down Expand Up @@ -211,7 +211,7 @@ func encodeTuple(value interface{}, childT []Type) ([]byte, error) {
}
tails[i] = tailEncoding
isDynamicIndex[i] = true
} else if childT[i].abiTypeID == Bool {
} else if childT[i].kind == Bool {
// search previous bool
before := findBoolLR(childT, i, -1)
// search after bool
Expand Down Expand Up @@ -317,7 +317,7 @@ func decodeUint(encoded []byte, bitSize uint16) (interface{}, error) {

// Decode is an ABI type method to decode bytes to go values from ABI encoding rules
func (t Type) Decode(encoded []byte) (interface{}, error) {
switch t.abiTypeID {
switch t.kind {
case Uint, Ufixed:
return decodeUint(encoded, t.bitSize)
case Bool:
Expand Down Expand Up @@ -389,7 +389,7 @@ func decodeTuple(encoded []byte, childT []Type) ([]interface{}, error) {
dynamicSegments = append(dynamicSegments, int(dynamicIndex))
valuePartition = append(valuePartition, nil)
iterIndex += lengthEncodeByteSize
} else if childT[i].abiTypeID == Bool {
} else if childT[i].kind == Bool {
// search previous bool
before := findBoolLR(childT, i, -1)
// search after bool
Expand Down
14 changes: 7 additions & 7 deletions abi/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ func categorySelfRoundTripTest(t *testing.T, category []testUnit) {
}
}

func addPrimitiveRandomValues(t *testing.T, pool *map[BaseType][]testUnit) {
func addPrimitiveRandomValues(t *testing.T, pool *map[TypeKind][]testUnit) {
(*pool)[Uint] = make([]testUnit, uintTestCaseCount*uintEnd/uintStepLength)
(*pool)[Ufixed] = make([]testUnit, ufixedPrecision*uintEnd/uintStepLength)

Expand Down Expand Up @@ -964,7 +964,7 @@ func addPrimitiveRandomValues(t *testing.T, pool *map[BaseType][]testUnit) {
}

func takeSomeFromCategoryAndGenerateArray(
t *testing.T, abiT BaseType, srtIndex int, takeNum uint16, pool *map[BaseType][]testUnit) {
t *testing.T, abiT TypeKind, srtIndex int, takeNum uint16, pool *map[TypeKind][]testUnit) {

tempArray := make([]interface{}, takeNum)
for i := 0; i < int(takeNum); i++ {
Expand All @@ -986,7 +986,7 @@ func takeSomeFromCategoryAndGenerateArray(
})
}

func addArrayRandomValues(t *testing.T, pool *map[BaseType][]testUnit) {
func addArrayRandomValues(t *testing.T, pool *map[TypeKind][]testUnit) {
for intIndex := 0; intIndex < len((*pool)[Uint]); intIndex += uintTestCaseCount {
takeSomeFromCategoryAndGenerateArray(t, Uint, intIndex, takeNum, pool)
}
Expand All @@ -999,16 +999,16 @@ func addArrayRandomValues(t *testing.T, pool *map[BaseType][]testUnit) {
categorySelfRoundTripTest(t, (*pool)[ArrayDynamic])
}

func addTupleRandomValues(t *testing.T, slotRange BaseType, pool *map[BaseType][]testUnit) {
func addTupleRandomValues(t *testing.T, slotRange TypeKind, pool *map[TypeKind][]testUnit) {
for i := 0; i < tupleTestCaseCount; i++ {
tupleLenBig, err := rand.Int(rand.Reader, big.NewInt(tupleMaxLength))
require.NoError(t, err, "generate random tuple length should not return error")
tupleLen := tupleLenBig.Int64() + 1
testUnits := make([]testUnit, tupleLen)
for index := 0; index < int(tupleLen); index++ {
tupleTypeIndexBig, err := rand.Int(rand.Reader, big.NewInt(int64(slotRange)+1))
tupleTypeIndexBig, err := rand.Int(rand.Reader, big.NewInt(int64(slotRange)))
require.NoError(t, err, "generate random tuple element type index should not return error")
tupleTypeIndex := BaseType(tupleTypeIndexBig.Int64())
tupleTypeIndex := TypeKind(tupleTypeIndexBig.Int64() + 1)
tupleElemChoiceRange := len((*pool)[tupleTypeIndex])

tupleElemRangeIndexBig, err := rand.Int(rand.Reader, big.NewInt(int64(tupleElemChoiceRange)))
Expand Down Expand Up @@ -1036,7 +1036,7 @@ func addTupleRandomValues(t *testing.T, slotRange BaseType, pool *map[BaseType][

func TestRandomABIEncodeDecodeRoundTrip(t *testing.T) {
t.Parallel()
testValuePool := make(map[BaseType][]testUnit)
testValuePool := make(map[TypeKind][]testUnit)
addPrimitiveRandomValues(t, &testValuePool)
addArrayRandomValues(t, &testValuePool)
addTupleRandomValues(t, String, &testValuePool)
Expand Down
14 changes: 7 additions & 7 deletions abi/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func castBigIntToNearestPrimitive(num *big.Int, bitSize uint16) (interface{}, er

// MarshalToJSON convert golang value to JSON format from ABI type
func (t Type) MarshalToJSON(value interface{}) ([]byte, error) {
switch t.abiTypeID {
switch t.kind {
case Uint:
bytesUint, err := encodeInt(value, t.bitSize)
if err != nil {
Expand Down Expand Up @@ -122,10 +122,10 @@ func (t Type) MarshalToJSON(value interface{}) ([]byte, error) {
if err != nil {
return nil, err
}
if t.abiTypeID == ArrayStatic && int(t.staticLength) != len(values) {
if t.kind == ArrayStatic && int(t.staticLength) != len(values) {
return nil, fmt.Errorf("length of slice %d != type specific length %d", len(values), t.staticLength)
}
if t.childTypes[0].abiTypeID == Byte {
if t.childTypes[0].kind == Byte {
byteArr := make([]byte, len(values))
for i := 0; i < len(values); i++ {
tempByte, ok := values[i].(byte)
Expand Down Expand Up @@ -173,7 +173,7 @@ func (t Type) MarshalToJSON(value interface{}) ([]byte, error) {

// UnmarshalFromJSON convert bytes to golang value following ABI type and encoding rules
func (t Type) UnmarshalFromJSON(jsonEncoded []byte) (interface{}, error) {
switch t.abiTypeID {
switch t.kind {
case Uint:
num := new(big.Int)
if err := num.UnmarshalJSON(jsonEncoded); err != nil {
Expand Down Expand Up @@ -217,13 +217,13 @@ func (t Type) UnmarshalFromJSON(jsonEncoded []byte) (interface{}, error) {

return addrBytes[:], nil
case ArrayStatic, ArrayDynamic:
if t.childTypes[0].abiTypeID == Byte && bytes.HasPrefix(jsonEncoded, []byte{'"'}) {
if t.childTypes[0].kind == Byte && bytes.HasPrefix(jsonEncoded, []byte{'"'}) {
var byteArr []byte
err := json.Unmarshal(jsonEncoded, &byteArr)
if err != nil {
return nil, fmt.Errorf("cannot cast JSON encoded (%s) to bytes: %w", string(jsonEncoded), err)
}
if t.abiTypeID == ArrayStatic && len(byteArr) != int(t.staticLength) {
if t.kind == ArrayStatic && len(byteArr) != int(t.staticLength) {
return nil, fmt.Errorf("length of slice %d != type specific length %d", len(byteArr), t.staticLength)
}
outInterface := make([]interface{}, len(byteArr))
Expand All @@ -236,7 +236,7 @@ func (t Type) UnmarshalFromJSON(jsonEncoded []byte) (interface{}, error) {
if err := json.Unmarshal(jsonEncoded, &elems); err != nil {
return nil, fmt.Errorf("cannot cast JSON encoded (%s) to array: %w", string(jsonEncoded), err)
}
if t.abiTypeID == ArrayStatic && len(elems) != int(t.staticLength) {
if t.kind == ArrayStatic && len(elems) != int(t.staticLength) {
return nil, fmt.Errorf("JSON array element number != ABI array elem number")
}
values := make([]interface{}, len(elems))
Expand Down
81 changes: 37 additions & 44 deletions abi/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,29 @@ import (
"strings"
)

/*
ABI-Types: uint<N>: An N-bit unsigned integer (8 <= N <= 512 and N % 8 = 0).
| byte (alias for uint8)
| ufixed <N> x <M> (8 <= N <= 512, N % 8 = 0, and 0 < M <= 160)
| bool
| address (alias for byte[32])
| <type> [<N>]
| <type> []
| string
| (T1, ..., Tn)
*/

// BaseType is an type-alias for uint32. A BaseType value indicates the type of an ABI value.
type BaseType uint32
// TypeKind is an enum value which indicates the kind of an ABI type.
type TypeKind uint32

const (
// Uint is the index (0) for `Uint` type in ABI encoding.
Uint BaseType = iota
// Byte is the index (1) for `Byte` type in ABI encoding.
// InvalidType represents an invalid and unused kind.
InvalidType = iota
// Uint is kind for ABI unsigned integer types, i.e. `uint<N>`.
Uint
// Byte is kind for the ABI `byte` type.
Byte
// Ufixed is the index (2) for `UFixed` type in ABI encoding.
// Ufixed is the kind for ABI unsigned fixed point decimal types, i.e. `ufixed<N>x<M>`.
Ufixed
// Bool is the index (3) for `Bool` type in ABI encoding.
// Bool is the kind for the ABI `bool` type.
Bool
// ArrayStatic is the index (4) for static length array (<type>[length]) type in ABI encoding.
// ArrayStatic is the kind for ABI static array types, i.e. `<type>[<length>]`.
ArrayStatic
// Address is the index (5) for `Address` type in ABI encoding (an type alias of Byte[32]).
// Address is the the kind for the ABI `address` type.
Address
// ArrayDynamic is the index (6) for dynamic length array (<type>[]) type in ABI encoding.
// ArrayDynamic is the kind for ABI dynamic array types, i.e. `<type>[]`.
ArrayDynamic
// String is the index (7) for `String` type in ABI encoding (an type alias of Byte[]).
// String is the kind for the ABI `string` type.
String
// Tuple is the index (8) for tuple `(<type 0>, ..., <type k>)` in ABI encoding.
// Tuple is the kind for ABI tuple types, i.e. `(<type 0>,...,<type k>)`.
Tuple
)

Expand All @@ -53,9 +43,12 @@ const (
abiEncodingLengthLimit = 1 << 16
)

// Type is the struct that stores information about an ABI value's type.
// Type is the struct that represents an ABI type.
//
// Do not use the zero value of this struct. Use the `TypeOf` function to create an instance of an
// ABI type.
type Type struct {
abiTypeID BaseType
kind TypeKind
childTypes []Type

// only can be applied to `uint` bitSize <N> or `ufixed` bitSize <N>
Expand All @@ -75,7 +68,7 @@ type Type struct {

// String serialize an ABI Type to a string in ABI encoding.
func (t Type) String() string {
switch t.abiTypeID {
switch t.kind {
case Uint:
return fmt.Sprintf("uint%d", t.bitSize)
case Byte:
Expand All @@ -99,7 +92,7 @@ func (t Type) String() string {
}
return "(" + strings.Join(typeStrings, ",") + ")"
default:
panic("Type Serialization Error, fail to infer from abiTypeID (bruh you shouldn't be here)")
return "<invalid type>"
}
}

Expand Down Expand Up @@ -275,23 +268,23 @@ func makeUintType(typeSize int) (Type, error) {
return Type{}, fmt.Errorf("unsupported uint type bitSize: %d", typeSize)
}
return Type{
abiTypeID: Uint,
bitSize: uint16(typeSize),
kind: Uint,
bitSize: uint16(typeSize),
}, nil
}

var (
// byteType is ABI type constant for byte
byteType = Type{abiTypeID: Byte}
byteType = Type{kind: Byte}

// boolType is ABI type constant for bool
boolType = Type{abiTypeID: Bool}
boolType = Type{kind: Bool}

// addressType is ABI type constant for address
addressType = Type{abiTypeID: Address}
addressType = Type{kind: Address}

// stringType is ABI type constant for string
stringType = Type{abiTypeID: String}
stringType = Type{kind: String}
)

// makeUfixedType makes `UFixed` ABI type by taking type bitSize and type precision as arguments.
Expand All @@ -305,7 +298,7 @@ func makeUfixedType(typeSize int, typePrecision int) (Type, error) {
return Type{}, fmt.Errorf("unsupported ufixed type precision: %d", typePrecision)
}
return Type{
abiTypeID: Ufixed,
kind: Ufixed,
bitSize: uint16(typeSize),
precision: uint16(typePrecision),
}, nil
Expand All @@ -315,7 +308,7 @@ func makeUfixedType(typeSize int, typePrecision int) (Type, error) {
// array element type and array length as arguments.
func makeStaticArrayType(argumentType Type, arrayLength uint16) Type {
return Type{
abiTypeID: ArrayStatic,
kind: ArrayStatic,
childTypes: []Type{argumentType},
staticLength: arrayLength,
}
Expand All @@ -324,7 +317,7 @@ func makeStaticArrayType(argumentType Type, arrayLength uint16) Type {
// makeDynamicArrayType makes dynamic length array by taking array element type as argument.
func makeDynamicArrayType(argumentType Type) Type {
return Type{
abiTypeID: ArrayDynamic,
kind: ArrayDynamic,
childTypes: []Type{argumentType},
}
}
Expand All @@ -335,15 +328,15 @@ func MakeTupleType(argumentTypes []Type) (Type, error) {
return Type{}, fmt.Errorf("tuple type child type number larger than maximum uint16 error")
}
return Type{
abiTypeID: Tuple,
kind: Tuple,
childTypes: argumentTypes,
staticLength: uint16(len(argumentTypes)),
}, nil
}

// Equal method decides the equality of two types: t == t0.
func (t Type) Equal(t0 Type) bool {
if t.abiTypeID != t0.abiTypeID {
if t.kind != t0.kind {
return false
}
if t.precision != t0.precision || t.bitSize != t0.bitSize {
Expand All @@ -366,7 +359,7 @@ func (t Type) Equal(t0 Type) bool {

// IsDynamic method decides if an ABI type is dynamic or static.
func (t Type) IsDynamic() bool {
switch t.abiTypeID {
switch t.kind {
case ArrayDynamic, String:
return true
default:
Expand All @@ -385,7 +378,7 @@ func findBoolLR(typeList []Type, index int, delta int) int {
until := 0
for {
curr := index + delta*until
if typeList[curr].abiTypeID == Bool {
if typeList[curr].kind == Bool {
if curr != len(typeList)-1 && delta > 0 {
until++
} else if curr > 0 && delta < 0 {
Expand All @@ -403,7 +396,7 @@ func findBoolLR(typeList []Type, index int, delta int) int {

// ByteLen method calculates the byte length of a static ABI type.
func (t Type) ByteLen() (int, error) {
switch t.abiTypeID {
switch t.kind {
case Address:
return addressByteSize, nil
case Byte:
Expand All @@ -413,7 +406,7 @@ func (t Type) ByteLen() (int, error) {
case Bool:
return singleBoolSize, nil
case ArrayStatic:
if t.childTypes[0].abiTypeID == Bool {
if t.childTypes[0].kind == Bool {
byteLen := int(t.staticLength+7) / 8
return byteLen, nil
}
Expand All @@ -425,7 +418,7 @@ func (t Type) ByteLen() (int, error) {
case Tuple:
size := 0
for i := 0; i < len(t.childTypes); i++ {
if t.childTypes[i].abiTypeID == Bool {
if t.childTypes[i].kind == Bool {
// search after bool
after := findBoolLR(t.childTypes, i, 1)
// shift the index
Expand Down
Loading

0 comments on commit fb2e94d

Please sign in to comment.