diff --git a/athena_abi/abi_version_parsing_test.go b/athena_abi/abi_version_parsing_test.go new file mode 100644 index 00000000..70b1a0d0 --- /dev/null +++ b/athena_abi/abi_version_parsing_test.go @@ -0,0 +1,105 @@ +package athena_abi + +import ( + "encoding/hex" + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestParseV1AbisToStarknetAbi(t *testing.T) { + + v1Abis, err := GetAbisForVersion("v1") + if err != nil { + t.Fatalf("Failed to get v1 ABIs: %v", err) + } + + for abiName, abi := range v1Abis { + t.Logf("Testing ABI Name: %s, ", abiName) + + _, err := StarknetAbiFromJSON(abi, abiName, []byte{}) + if err != nil { + t.Errorf("Failed to parse v1 ABI for %s: %v", abiName, err) + } + } +} + +func TestParseV2AbisToStarknetAbi(t *testing.T) { + v2Abis, err := GetAbisForVersion("v2") + if err != nil { + t.Fatalf("Failed to get v2 ABIs: %v", err) + } + + for abiName, abi := range v2Abis { + // Replace with actual ABI name from data + decoder, err := StarknetAbiFromJSON(abi, abiName, []byte{}) + if err != nil { + t.Errorf("Failed to parse v2 ABI: %v", err) + } + + // Add specific assertions based on ABI name + if abiName == "starknet_eth" { + if _, ok := decoder.Functions["transfer"]; !ok { + t.Errorf("Expected function 'transfer' in starknet_eth ABI") + } + } + + if abiName == "argent_account_v3" { + funcDef := decoder.Functions["change_guardian_backup"] + if len(funcDef.inputs) > 0 && funcDef.inputs[0].Type.idStr() != "StarknetStruct" { + t.Errorf("Expected first input to be a StarknetStruct") + } + } + } +} + +func TestNamedTupleParsing(t *testing.T) { + abiFile := filepath.Join("abis", "v1", "legacy_named_tuple.json") + abiJson, err := os.ReadFile(abiFile) + if err != nil { + t.Fatalf("Failed to read ABI file: %v", err) + } + + var abi []map[string]interface{} + if err := json.Unmarshal(abiJson, &abi); err != nil { + t.Fatalf("Failed to unmarshal ABI JSON: %v", err) + } + + classHash, _ := hex.DecodeString("0484c163658bcce5f9916f486171ac60143a92897533aa7ff7ac800b16c63311") + parsedAbi, err := StarknetAbiFromJSON(abi, "legacy_named_tuple", classHash) + if err != nil { + t.Fatalf("Failed to parse named tuple ABI: %v", err) + } + + // Assertions based on the parsed ABI + funcDef := parsedAbi.Functions["xor_counters"] + if len(funcDef.inputs) == 0 || funcDef.inputs[0].Name != "index_and_x" { + t.Errorf("Expected input 'index_and_x' in xor_counters function") + } +} + +func TestStorageAddressParsing(t *testing.T) { + abiFile := filepath.Join("abis", "v2", "storage_address.json") + abiJson, err := os.ReadFile(abiFile) + if err != nil { + t.Fatalf("Failed to read ABI file: %v", err) + } + + var abi []map[string]interface{} + if err := json.Unmarshal(abiJson, &abi); err != nil { + t.Fatalf("Failed to unmarshal ABI JSON: %v", err) + } + + classHash, _ := hex.DecodeString("0484c163658bcce5f9916f486171ac60143a92897533aa7ff7ac800b16c63311") + parsedAbi, err := StarknetAbiFromJSON(abi, "storage_address", classHash) + if err != nil { + t.Fatalf("Failed to parse storage address ABI: %v", err) + } + + // Assertions based on parsed ABI + storageFunction := parsedAbi.Functions["storage_read"] + if len(storageFunction.inputs) != 2 || storageFunction.inputs[0].Name != "address_domain" { + t.Errorf("Expected two inputs with first input named 'address_domain'") + } +} diff --git a/athena_abi/core.go b/athena_abi/core.go index 93f98fcd..ba3fc9a0 100644 --- a/athena_abi/core.go +++ b/athena_abi/core.go @@ -37,6 +37,7 @@ func StarknetAbiFromJSON(abiJson []map[string]interface{}, abiName string, class definedTypes, err := ParseEnumsAndStructs(groupedAbi["type_def"]) if err != nil { sortedDefs, errDef := TopoSortTypeDefs(groupedAbi["type_def"]) + if errDef == nil { defineTypes, errDtypes := ParseEnumsAndStructs(sortedDefs) definedTypes = defineTypes diff --git a/athena_abi/parse.go b/athena_abi/parse.go index dacc189e..5dd4d150 100644 --- a/athena_abi/parse.go +++ b/athena_abi/parse.go @@ -49,23 +49,44 @@ func extractInnerType(abiType string) string { return abiType[start+1 : end] } -// The function takes in a list of type definitions (dict) and returns a dict of sets (map[string]bool) func BuildTypeGraph(typeDefs []map[string]interface{}) map[string]map[string]bool { outputGraph := make(map[string]map[string]bool) + for _, typeDef := range typeDefs { referencedTypes := []string{} + + // Check if the type is a struct if typeDef["type"] == "struct" { - for _, member := range typeDef["members"].([]map[string]interface{}) { - referencedTypes = append(referencedTypes, member["type"].(string)) + // Handle if "members" is []map[string]interface{} + if membersMap, ok := typeDef["members"].([]map[string]interface{}); ok { + for _, member := range membersMap { + referencedTypes = append(referencedTypes, member["type"].(string)) + } + } else if members, ok := typeDef["members"].([]interface{}); ok { + // Handle if "members" is []interface{} + for _, member := range members { + if memberMap, ok := member.(map[string]interface{}); ok { + referencedTypes = append(referencedTypes, memberMap["type"].(string)) + } + } } } else { - for _, variant := range typeDef["variants"].([]map[string]interface{}) { - referencedTypes = append(referencedTypes, variant["type"].(string)) + // Handle variants + if variants, ok := typeDef["variants"].([]map[string]interface{}); ok { + for _, variant := range variants { + referencedTypes = append(referencedTypes, variant["type"].(string)) + } + } else if variants, ok := typeDef["variants"].([]interface{}); ok { + for _, variant := range variants { + if variantMap, ok := variant.(map[string]interface{}); ok { + referencedTypes = append(referencedTypes, variantMap["type"].(string)) + } + } } } + // Collect referenced types, excluding core types refTypes := make(map[string]bool) - for _, typeStr := range referencedTypes { if _, ok := StarknetCoreTypes[typeStr]; ok { continue @@ -80,7 +101,10 @@ func BuildTypeGraph(typeDefs []map[string]interface{}) map[string]map[string]boo refTypes[typeStr] = true } - outputGraph[typeDef["name"].(string)] = refTypes + // Safely assert the name of the type + if name, ok := typeDef["name"].(string); ok { + outputGraph[name] = refTypes + } } return outputGraph diff --git a/athena_abi/utils.go b/athena_abi/utils.go index 3c1e6e3e..3a50e339 100644 --- a/athena_abi/utils.go +++ b/athena_abi/utils.go @@ -1,6 +1,7 @@ package athena_abi import ( + "io/fs" "math/big" "golang.org/x/crypto/sha3" @@ -100,3 +101,54 @@ func loadABI(abiName string, abiVersion int) (map[string]interface{}, error) { return abiData, nil } + +func GetAbisForVersion(abiVersion string) (map[string][]map[string]interface{}, error) { + abiDir := filepath.Join("abis", abiVersion) + abis := make(map[string][]map[string]interface{}) + + err := filepath.Walk(abiDir, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + + // Check for JSON files only + if filepath.Ext(path) == ".json" { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + var rawData interface{} + if err := json.NewDecoder(file).Decode(&rawData); err != nil { + return err + } + + // Check if rawData is of type []interface{} + abiList, ok := rawData.([]interface{}) + if !ok { + return fmt.Errorf("expected ABI data to be []interface{}, got %T", rawData) + } + + // Convert []interface{} to []map[string]interface{} + var abiData []map[string]interface{} + for _, item := range abiList { + abiMap, ok := item.(map[string]interface{}) + if !ok { + return fmt.Errorf("expected item to be map[string]interface{}, got %T", item) + } + abiData = append(abiData, abiMap) + } + + abis[info.Name()] = abiData + } + + return nil + }) + + if err != nil { + return nil, fmt.Errorf("failed to load ABIs: %w", err) + } + + return abis, nil +} diff --git a/athena_abi/utils_test.go b/athena_abi/utils_test.go index 32e138b9..55615cce 100644 --- a/athena_abi/utils_test.go +++ b/athena_abi/utils_test.go @@ -1,10 +1,11 @@ package athena_abi import ( - "github.com/stretchr/testify/assert" "math/big" "strconv" "testing" + + "github.com/stretchr/testify/assert" ) func TestBigIntToBytes(t *testing.T) {