From 0fdc98b4b1625316c708121983d992fe0fee04f3 Mon Sep 17 00:00:00 2001 From: thiagodeev Date: Mon, 21 Oct 2024 12:37:55 -0300 Subject: [PATCH] Adds revision.go file, new Revision field of TypedData --- typedData/revision.go | 101 +++++++++++++++++++++++++++++++++++++++++ typedData/typedData.go | 31 ++++++++----- typedData/types.go | 10 ++-- 3 files changed, 125 insertions(+), 17 deletions(-) create mode 100644 typedData/revision.go diff --git a/typedData/revision.go b/typedData/revision.go new file mode 100644 index 00000000..9db67fb6 --- /dev/null +++ b/typedData/revision.go @@ -0,0 +1,101 @@ +package typedData + +import ( + "encoding/json" + "fmt" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/starknet.go/curve" +) + +var ( + // There is also an array version of each type. The array is defined like this: 'type' + '*' (e.g.: "felt*", "bool*", "string*"...) + revision_0_basic_types []string = []string{ + "felt", + "bool", + "string", //up to 31 ASCII characters + "selector", + "merkletree", + } + + // Revision 1 includes all types from Revision 0 plus these. The only difference is that for Revision 1 "string" represents an + // arbitrary size string instead of having a 31 ASCII characters limit in Revision 0; for this limit, use the new type "shortstring" instead. + // + // There is also an array version of each type. The array is defined like this: 'type' + '*' (e.g.: "ClassHash*", "timestamp*", "shortstring*"...) + revision_1_basic_types []string = []string{ + //TODO: enum? + "u128", + "i128", + "ContractAddress", + "ClassHash", + "timestamp", + "shortstring", + } +) + +type Revision struct { + Domain String + HashMethod func(felts ...*felt.Felt) *felt.Felt + //TODO: hashMerkleMethod ? + Types RevisionTypes +} + +type RevisionTypes struct { + Basic []string + Preset map[string]any +} + +func NewRevision(version uint8) (rev Revision, err error) { + preset := make(map[string]any) + + switch version { + case 0: + rev = Revision{ + Domain: "StarkNetDomain", + HashMethod: curve.PedersenArray, + Types: RevisionTypes{ + Basic: revision_0_basic_types, + Preset: preset, + }, + } + return rev, nil + case 1: + preset, err = getRevisionV1PresetTypes() + if err != nil { + return rev, fmt.Errorf("error getting revision 1 preset types: %w", err) + } + rev = Revision{ + Domain: "StarknetDomain", + HashMethod: curve.PoseidonArray, + Types: RevisionTypes{ + Basic: append(revision_1_basic_types, revision_0_basic_types...), + Preset: preset, + }, + } + return rev, nil + default: + return rev, fmt.Errorf("invalid revision version") + } +} + +func getRevisionV1PresetTypes() (result map[string]any, err error) { + type RevV1PresetTypes struct { + NftId NftId + TokenAmount TokenAmount + U256 U256 + } + + var preset RevV1PresetTypes + + bytes, err := json.Marshal(preset) + if err != nil { + return result, err + } + + err = json.Unmarshal(bytes, &result) + if err != nil { + return result, err + } + + return result, err +} diff --git a/typedData/typedData.go b/typedData/typedData.go index 3a59e240..45bc36b7 100644 --- a/typedData/typedData.go +++ b/typedData/typedData.go @@ -43,28 +43,29 @@ var ( ) type TypedData struct { - Types map[string]TypeDefinition - PrimaryType string - Domain Domain - Message map[string]any + Types map[string]TypeDefinition `json:"types"` + PrimaryType string `json:"primaryType"` + Domain Domain `json:"domain"` + Message map[string]any `json:"message"` + Revision Revision } type Domain struct { - Name string - Version json.Number - ChainId json.Number - Revision uint8 `json:"contains,omitempty"` + Name string `json:"name"` + Version json.Number `json:"version"` + ChainId json.Number `json:"chainId"` + Revision uint8 `json:"revision,omitempty"` } type TypeDefinition struct { - Name string `json:"-"` - Encoding *big.Int + Name string `json:"-"` + Encoding *big.Int //TODO: maybe remove this Parameters []TypeParameter } type TypeParameter struct { - Name string - Type string + Name string `json:"name"` + Type string `json:"type"` Contains string `json:"contains,omitempty"` } @@ -147,11 +148,17 @@ func NewTypedData(types []TypeDefinition, primaryType string, domain Domain, mes return td, fmt.Errorf("error unmarshalling the message: %w", err) } + revision, err := NewRevision(domain.Revision) + if err != nil { + return td, fmt.Errorf("error getting revision: %w", err) + } + td = TypedData{ Types: typesMap, PrimaryType: primaryType, Domain: domain, Message: messageMap, + Revision: revision, } if _, ok := td.Types[primaryType]; !ok { return td, fmt.Errorf("invalid primary type: %s", primaryType) diff --git a/typedData/types.go b/typedData/types.go index 619c059e..c5852b49 100644 --- a/typedData/types.go +++ b/typedData/types.go @@ -7,8 +7,8 @@ type ( Bool bool String string Selector string - U128 big.Int - I128 big.Int + U128 *big.Int + I128 *big.Int ContractAddress string ClassHash string Timestamp U128 @@ -16,13 +16,13 @@ type ( ) type U256 struct { - Low U128 - High U128 + Low U128 `json:"low"` + High U128 `json:"high"` } type TokenAmount struct { TokenAddress ContractAddress `json:"token_address"` - Amount U256 + Amount U256 `json:"amount"` } type NftId struct {