From 3be80685c794e829c89adb46026d35ed3b35ffca Mon Sep 17 00:00:00 2001 From: Nevio Vesic <0x19@users.noreply.github.com> Date: Sun, 19 May 2024 19:39:52 +0200 Subject: [PATCH] Overall improvements (#214) --- .gitignore | 1 + ast/builder.go | 6 +- ast/contract.go | 1 + ast/reference.go | 1 + ast/source_unit.go | 20 ++++++ ast/src.go | 12 ++-- ast/state_variable.go | 100 ++++++++++++++++++++++++++++- bindings/otterscan.go | 26 +++++++- bindings/trace.go | 26 ++++++++ contracts/descriptor.go | 12 +++- contracts/parser.go | 4 ++ contracts/source.go | 12 ++++ ir/contract.go | 6 +- ir/function_call.go | 4 +- opcode/decompiler.go | 15 +++-- opcode/events.go | 55 ++++++++++++++++ opcode/instructions.go | 1 + opcode/log.go | 131 ++++++++++++++++++++++++++++++++++++++ opcode/matcher.go | 49 --------------- opcode/opcodes.go | 20 ++++++ sources.go | 135 +++++++++++++++++++++++++++++++++++++++- utils/networks.go | 4 ++ utils/version.go | 12 +++- validation/verifier.go | 52 +++++++++++++++- 24 files changed, 625 insertions(+), 80 deletions(-) create mode 100644 bindings/trace.go create mode 100644 opcode/events.go create mode 100644 opcode/log.go diff --git a/.gitignore b/.gitignore index 20d273c3..b77b174b 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ go.work solgo playground/* bin/* +.idea \ No newline at end of file diff --git a/ast/builder.go b/ast/builder.go index 1dff4517..bb810711 100644 --- a/ast/builder.go +++ b/ast/builder.go @@ -94,7 +94,7 @@ func (b *ASTBuilder) ToJSON() ([]byte, error) { return b.InterfaceToJSON(b.tree.GetRoot()) } -// ToPrettyJSON converts the provided data to a JSON byte array. +// InterfaceToJSON converts the provided data to a JSON byte array. func (b *ASTBuilder) InterfaceToJSON(data interface{}) ([]byte, error) { return json.Marshal(data) } @@ -129,6 +129,10 @@ func (b *ASTBuilder) ImportFromJSON(ctx context.Context, jsonBytes []byte) (*Roo return toReturn, nil } +func (b *ASTBuilder) GetImports() []Node[NodeType] { + return b.currentImports +} + // GarbageCollect cleans up the ASTBuilder after resolving references. func (b *ASTBuilder) GarbageCollect() { b.currentEnums = nil diff --git a/ast/contract.go b/ast/contract.go index e2c824d5..3c1842ac 100644 --- a/ast/contract.go +++ b/ast/contract.go @@ -390,6 +390,7 @@ func (c *Contract) Parse(unitCtx *parser.SourceUnitContext, ctx *parser.Contract ) contractId := c.GetNextID() + contractNode := &Contract{ Id: contractId, Name: ctx.Identifier().GetText(), diff --git a/ast/reference.go b/ast/reference.go index 43150822..e6c2ed60 100644 --- a/ast/reference.go +++ b/ast/reference.go @@ -294,6 +294,7 @@ func (r *Resolver) resolveEntrySourceUnit() { for _, entry := range node.GetExportedSymbols() { if len(r.sources.EntrySourceUnitName) > 0 && r.sources.EntrySourceUnitName == entry.GetName() { + r.tree.astRoot.SetEntrySourceUnit(entry.GetId()) return } diff --git a/ast/source_unit.go b/ast/source_unit.go index 65cc7f6e..a2a33c72 100644 --- a/ast/source_unit.go +++ b/ast/source_unit.go @@ -24,6 +24,7 @@ type SourceUnit[T NodeType] struct { AbsolutePath string `json:"absolutePath"` // AbsolutePath is the absolute path of the source unit. Name string `json:"name"` // Name is the name of the source unit. This is going to be one of the following: contract, interface or library name. It's here for convenience. NodeType ast_pb.NodeType `json:"nodeType"` // NodeType is the type of the AST node. + Kind ast_pb.NodeType `json:"kind"` // Kind is the type of the AST node (contract, library, interface). Nodes []Node[NodeType] `json:"nodes"` // Nodes is the list of AST nodes. Src SrcNode `json:"src"` // Src is the source code location. } @@ -107,6 +108,11 @@ func (s *SourceUnit[T]) GetType() ast_pb.NodeType { return s.NodeType } +// GetKind returns the type of the source unit. +func (s *SourceUnit[T]) GetKind() ast_pb.NodeType { + return s.Kind +} + // GetSrc returns the source code location of the source unit. func (s *SourceUnit[T]) GetSrc() SrcNode { return s.Src @@ -302,6 +308,7 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) { if interfaceCtx, ok := child.(*parser.InterfaceDefinitionContext); ok { license := getLicenseFromSources(b.sources, b.comments, interfaceCtx.Identifier().GetText()) sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, interfaceCtx.Identifier().GetText(), license) + sourceUnit.Kind = ast_pb.NodeType_KIND_INTERFACE interfaceNode := NewInterfaceDefinition(b) interfaceNode.Parse(ctx, interfaceCtx, rootNode, sourceUnit) b.sourceUnits = append(b.sourceUnits, sourceUnit) @@ -310,6 +317,7 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) { if libraryCtx, ok := child.(*parser.LibraryDefinitionContext); ok { license := getLicenseFromSources(b.sources, b.comments, libraryCtx.Identifier().GetText()) sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, libraryCtx.Identifier().GetText(), license) + sourceUnit.Kind = ast_pb.NodeType_KIND_LIBRARY libraryNode := NewLibraryDefinition(b) libraryNode.Parse(ctx, libraryCtx, rootNode, sourceUnit) b.sourceUnits = append(b.sourceUnits, sourceUnit) @@ -318,11 +326,23 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) { if contractCtx, ok := child.(*parser.ContractDefinitionContext); ok { license := getLicenseFromSources(b.sources, b.comments, contractCtx.Identifier().GetText()) sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, contractCtx.Identifier().GetText(), license) + sourceUnit.Kind = ast_pb.NodeType_KIND_CONTRACT contractNode := NewContractDefinition(b) contractNode.Parse(ctx, contractCtx, rootNode, sourceUnit) b.sourceUnits = append(b.sourceUnits, sourceUnit) } } + + // Idea here is to basically set the source unit entry name as soon as we have parsed all of the classes. + // Now this won't be possible always but nevertheless. (In rest of the cases, resolver will take care of it) + if b.sources.EntrySourceUnitName != "" { + for _, sourceUnit := range b.sourceUnits { + if b.sources.EntrySourceUnitName == sourceUnit.GetName() { + rootNode.SetEntrySourceUnit(sourceUnit.GetId()) + return + } + } + } } // ExitSourceUnit is called when the ASTBuilder exits a source unit context. diff --git a/ast/src.go b/ast/src.go index 4ee6d75f..7ae2b030 100644 --- a/ast/src.go +++ b/ast/src.go @@ -6,12 +6,12 @@ import ( // SrcNode represents a node in the source code. type SrcNode struct { - Line int64 `json:"line"` // Line number of the source node in the source code. - Column int64 `json:"column"` // Column number of the source node in the source code. - Start int64 `json:"start"` // Start position of the source node in the source code. - End int64 `json:"end"` // End position of the source node in the source code. - Length int64 `json:"length"` // Length of the source node in the source code. - ParentIndex int64 `json:"parent_index,omitempty"` // Index of the parent node in the source code. + Line int64 `json:"line"` // Line number of the source node in the source code. + Column int64 `json:"column"` // Column number of the source node in the source code. + Start int64 `json:"start"` // Start position of the source node in the source code. + End int64 `json:"end"` // End position of the source node in the source code. + Length int64 `json:"length"` // Length of the source node in the source code. + ParentIndex int64 `json:"parentIndex,omitempty"` // Index of the parent node in the source code. } // GetLine returns the line number of the source node in the source code. diff --git a/ast/state_variable.go b/ast/state_variable.go index ca5cab07..b07a24ba 100644 --- a/ast/state_variable.go +++ b/ast/state_variable.go @@ -1,11 +1,11 @@ package ast import ( - "strings" - v3 "github.com/cncf/xds/go/xds/type/v3" + "github.com/goccy/go-json" ast_pb "github.com/unpackdev/protos/dist/go/ast" "github.com/unpackdev/solgo/parser" + "strings" ) // StateVariableDeclaration represents a state variable declaration in the Solidity abstract syntax tree (AST). @@ -145,6 +145,102 @@ func (v *StateVariableDeclaration) GetInitialValue() Node[NodeType] { return v.InitialValue } +// UnmarshalJSON customizes the JSON unmarshaling for StateVariableDeclaration. +func (v *StateVariableDeclaration) UnmarshalJSON(data []byte) error { + var tempMap map[string]json.RawMessage + if err := json.Unmarshal(data, &tempMap); err != nil { + return err + } + + if id, ok := tempMap["id"]; ok { + if err := json.Unmarshal(id, &v.Id); err != nil { + return err + } + } + + if name, ok := tempMap["name"]; ok { + if err := json.Unmarshal(name, &v.Name); err != nil { + return err + } + } + + if isConstant, ok := tempMap["isConstant"]; ok { + if err := json.Unmarshal(isConstant, &v.Constant); err != nil { + return err + } + } + + if isStateVariable, ok := tempMap["isStateVariable"]; ok { + if err := json.Unmarshal(isStateVariable, &v.StateVariable); err != nil { + return err + } + } + + if nodeType, ok := tempMap["nodeType"]; ok { + if err := json.Unmarshal(nodeType, &v.NodeType); err != nil { + return err + } + } + + if visibility, ok := tempMap["visibility"]; ok { + if err := json.Unmarshal(visibility, &v.Visibility); err != nil { + return err + } + } + + if storageLocation, ok := tempMap["storageLocation"]; ok { + if err := json.Unmarshal(storageLocation, &v.StorageLocation); err != nil { + return err + } + } + + if mutability, ok := tempMap["mutability"]; ok { + if err := json.Unmarshal(mutability, &v.StateMutability); err != nil { + return err + } + } + + if src, ok := tempMap["src"]; ok { + if err := json.Unmarshal(src, &v.Src); err != nil { + return err + } + } + + if scope, ok := tempMap["scope"]; ok { + if err := json.Unmarshal(scope, &v.Scope); err != nil { + return err + } + } + + if expression, ok := tempMap["initialValue"]; ok { + if err := json.Unmarshal(expression, &v.InitialValue); err != nil { + var tempNodeMap map[string]json.RawMessage + if err := json.Unmarshal(expression, &tempNodeMap); err != nil { + return err + } + + var tempNodeType ast_pb.NodeType + if err := json.Unmarshal(tempNodeMap["nodeType"], &tempNodeType); err != nil { + return err + } + + node, err := unmarshalNode(expression, tempNodeType) + if err != nil { + return err + } + v.InitialValue = node + } + } + + if typeDescription, ok := tempMap["typeDescription"]; ok { + if err := json.Unmarshal(typeDescription, &v.TypeDescription); err != nil { + return err + } + } + + return nil +} + // ToProto returns the protobuf representation of the state variable declaration. func (v *StateVariableDeclaration) ToProto() NodeType { proto := ast_pb.StateVariable{ diff --git a/bindings/otterscan.go b/bindings/otterscan.go index 664f5928..b7a8c7ec 100644 --- a/bindings/otterscan.go +++ b/bindings/otterscan.go @@ -3,6 +3,7 @@ package bindings import ( "context" "fmt" + "github.com/pkg/errors" "github.com/ethereum/go-ethereum/common" "github.com/unpackdev/solgo/utils" @@ -19,9 +20,9 @@ type CreatorInformation struct { // GetContractCreator queries the Ethereum blockchain to find the creator of a specified smart contract. This method // utilizes the Ethereum JSON-RPC API to request creator information, which includes both the creator's address and // the transaction hash of the contract's creation. It's a valuable tool for auditing and tracking the origins of -// contracts on the network. WORKS ONLY WITH ERIGON NODE - OR NODES THAT SUPPORT OTTERSCAN! +// contracts on the network. WORKS ONLY WITH ERIGON NODE OR QUICKNODE PROVIDER - OR NODES THAT SUPPORT OTTERSCAN! func (m *Manager) GetContractCreator(ctx context.Context, network utils.Network, contract common.Address) (*CreatorInformation, error) { - client := m.clientPool.GetClientByGroup(string(network)) + client := m.clientPool.GetClientByGroup(network.String()) if client == nil { return nil, fmt.Errorf("client not found for network %s", network) } @@ -30,7 +31,26 @@ func (m *Manager) GetContractCreator(ctx context.Context, network utils.Network, var result *CreatorInformation if err := rpcClient.CallContext(ctx, &result, "ots_getContractCreator", contract.Hex()); err != nil { - return nil, fmt.Errorf("failed to fetch otterscan creator information: %v", err) + return nil, errors.Wrap(err, "failed to fetch otterscan creator information") + } + + return result, nil +} + +// GetTransactionBySenderAndNonce retrieves a transaction hash based on a specific sender's address and nonce. +// This function also utilizes the Ethereum JSON-RPC API and requires a node that supports specific transaction lookup +// by sender and nonce, which is particularly useful for tracking transaction sequences and debugging transaction flows. +func (m *Manager) GetTransactionBySenderAndNonce(ctx context.Context, network utils.Network, sender common.Address, nonce int64) (*common.Hash, error) { + client := m.clientPool.GetClientByGroup(network.String()) + if client == nil { + return nil, fmt.Errorf("client not found for network %s", network) + } + + rpcClient := client.GetRpcClient() + var result *common.Hash + + if err := rpcClient.CallContext(ctx, &result, "ots_getTransactionBySenderAndNonce", sender.Hex(), nonce); err != nil { + return nil, errors.Wrap(err, "failed to fetch otterscan get transaction by sender and nonce information") } return result, nil diff --git a/bindings/trace.go b/bindings/trace.go new file mode 100644 index 00000000..e1a751b8 --- /dev/null +++ b/bindings/trace.go @@ -0,0 +1,26 @@ +package bindings + +import ( + "context" + "fmt" + "github.com/pkg/errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/unpackdev/solgo/utils" +) + +func (m *Manager) TraceCallMany(ctx context.Context, network utils.Network, sender common.Address, nonce int64) (*common.Hash, error) { + client := m.clientPool.GetClientByGroup(network.String()) + if client == nil { + return nil, fmt.Errorf("client not found for network %s", network) + } + + rpcClient := client.GetRpcClient() + var result *common.Hash + + if err := rpcClient.CallContext(ctx, &result, "trace_callMany", sender.Hex(), nonce); err != nil { + return nil, errors.Wrap(err, "failed to execute trace_callMany") + } + + return result, nil +} diff --git a/contracts/descriptor.go b/contracts/descriptor.go index 25157e63..46ba31e8 100644 --- a/contracts/descriptor.go +++ b/contracts/descriptor.go @@ -40,9 +40,10 @@ type Descriptor struct { // SourcesRaw is the raw sources from Etherscan|BscScan|etc. Should not be used anywhere except in // the contract discovery process. - SourcesRaw *etherscan.Contract `json:"-"` - Sources *solgo.Sources `json:"sources,omitempty"` - SourceProvider string `json:"source_provider,omitempty"` + SourcesRaw *etherscan.Contract `json:"-"` + Sources *solgo.Sources `json:"sources,omitempty"` + SourcesUnsorted *solgo.Sources `json:"-"` + SourceProvider string `json:"source_provider,omitempty"` // Source detection related fields. Detector *detector.Detector `json:"-"` @@ -144,6 +145,11 @@ func (d *Descriptor) GetSources() *solgo.Sources { return d.Sources } +// GetUnsortedSources returns the parsed sources of the contract, providing a structured view of the contract's code. +func (d *Descriptor) GetUnsortedSources() *solgo.Sources { + return d.SourcesUnsorted +} + // GetSourcesRaw returns the raw contract source as obtained from external providers like Etherscan. func (d *Descriptor) GetSourcesRaw() *etherscan.Contract { return d.SourcesRaw diff --git a/contracts/parser.go b/contracts/parser.go index 4cb99803..9603592d 100644 --- a/contracts/parser.go +++ b/contracts/parser.go @@ -38,6 +38,10 @@ func (c *Contract) Parse(ctx context.Context) error { ) return err } + + // Sets the address for more understanding when we need to troubleshoot contract parsing + parser.GetIR().SetAddress(c.addr) + c.descriptor.Detector = parser c.descriptor.SolgoVersion = utils.GetBuildVersionByModule("github.com/unpackdev/solgo") diff --git a/contracts/source.go b/contracts/source.go index 8b7ec50f..dab56d44 100644 --- a/contracts/source.go +++ b/contracts/source.go @@ -75,7 +75,19 @@ func (c *Contract) DiscoverSourceCode(ctx context.Context) error { return fmt.Errorf("failed to create new sources from etherscan response: %s", err) } + unsortedSources, err := solgo.NewUnsortedSourcesFromEtherScan(response.Name, response.SourceCode) + if err != nil { + zap.L().Error( + "failed to create new unsorted sources from etherscan response", + zap.Error(err), + zap.String("network", c.network.String()), + zap.String("contract_address", c.addr.String()), + ) + return fmt.Errorf("failed to create new unsorted sources from etherscan response: %s", err) + } + c.descriptor.Sources = sources + c.descriptor.SourcesUnsorted = unsortedSources license := strings.ReplaceAll(c.descriptor.SourcesRaw.LicenseType, "\r", "") license = strings.ReplaceAll(license, "\n", "") diff --git a/ir/contract.go b/ir/contract.go index 2d9fcef2..e5dec24d 100644 --- a/ir/contract.go +++ b/ir/contract.go @@ -248,7 +248,7 @@ func (c *Contract) ToProto() *ir_pb.Contract { // processContract processes the contract unit and returns the Contract. func (b *Builder) processContract(unit *ast.SourceUnit[ast.Node[ast_pb.SourceUnit]]) *Contract { - contract := getContractByNodeType(unit.GetContract()) + contract := GetContractByNodeType(unit.GetContract()) contractNode := &Contract{ Unit: unit, @@ -360,8 +360,8 @@ func (b *Builder) processContract(unit *ast.SourceUnit[ast.Node[ast_pb.SourceUni return contractNode } -// getContractByNodeType returns the ContractNode based on the node type. -func getContractByNodeType(c ast.Node[ast.NodeType]) ContractNode { +// GetContractByNodeType returns the ContractNode based on the node type. +func GetContractByNodeType(c ast.Node[ast.NodeType]) ContractNode { switch contract := c.(type) { case *ast.Library: return contract diff --git a/ir/function_call.go b/ir/function_call.go index f037aa7e..c185ce28 100644 --- a/ir/function_call.go +++ b/ir/function_call.go @@ -161,7 +161,7 @@ func (b *Builder) processFunctionCall(fn *Function, unit *ast.FunctionCall) *Fun if fn != nil { toReturn.ExternalContractId = fn.GetAST().GetScope() sourceContract := b.astBuilder.GetTree().GetById(fn.GetAST().GetScope()) - toReturn.referencedContract = getContractByNodeType(sourceContract) + toReturn.referencedContract = GetContractByNodeType(sourceContract) } } } @@ -177,7 +177,7 @@ func (b *Builder) processFunctionCall(fn *Function, unit *ast.FunctionCall) *Fun if fn != nil { toReturn.ExternalContractId = fn.GetAST().GetScope() sourceContract := b.astBuilder.GetTree().GetById(fn.GetAST().GetScope()) - toReturn.referencedContract = getContractByNodeType(sourceContract) + toReturn.referencedContract = GetContractByNodeType(sourceContract) toReturn.ExternalContractName = toReturn.referencedContract.GetName() } diff --git a/opcode/decompiler.go b/opcode/decompiler.go index 3c93edf9..3a80cac0 100644 --- a/opcode/decompiler.go +++ b/opcode/decompiler.go @@ -43,7 +43,7 @@ func (d *Decompiler) GetBytecodeSize() uint64 { // Decompile processes the bytecode, populates the instructions slice, and identifies function entry points. func (d *Decompiler) Decompile() error { if d.bytecodeSize < 1 { - return ErrEmptyBytecode + return fmt.Errorf("bytecode is empty") } offset := 0 @@ -58,11 +58,14 @@ func (d *Decompiler) Decompile() error { if op.IsPush() { argSize := int(op) - int(PUSH1) + 1 - if offset+argSize >= len(d.bytecode) { - break + if offset+argSize < len(d.bytecode) { + instruction.Args = d.bytecode[offset+1 : offset+argSize+1] + offset += argSize + } else { + // If we don't have enough bytes for PUSH arguments, use the remaining bytes + instruction.Args = d.bytecode[offset+1:] + offset = len(d.bytecode) - 1 } - instruction.Args = d.bytecode[offset+1 : offset+argSize+1] - offset += argSize } d.instructions = append(d.instructions, instruction) @@ -74,6 +77,8 @@ func (d *Decompiler) Decompile() error { offset++ } + + fmt.Printf("Total instructions processed: %d\n", len(d.instructions)) return nil } diff --git a/opcode/events.go b/opcode/events.go new file mode 100644 index 00000000..d59a0a6d --- /dev/null +++ b/opcode/events.go @@ -0,0 +1,55 @@ +package opcode + +import ( + "fmt" + "github.com/ethereum/go-ethereum/common" +) + +// EventTreeNode represents a node in the opcode execution tree that represents an event. +type EventTreeNode struct { + EventSignatureHex string `json:"eventSignatureHex"` + EventSignature string `json:"event_signature"` + EventBytesHex string `json:"event_bytes_hex"` + EventBytes []byte `json:"event_bytes"` + HasEventSignature bool `json:"has_event_signature"` + *TreeNode +} + +// GetEvents iterates through LOG1 to LOG4 instructions, decodes their arguments, and collects them into EventTreeNode structures. +func (d *Decompiler) GetEvents() []*EventTreeNode { + logInstructions := map[OpCode]int{ + LOG1: 1, + LOG2: 2, + LOG3: 3, + LOG4: 4, + } + + events := make([]*EventTreeNode, 0) + + for opCode, topicCount := range logInstructions { + instructions := d.GetInstructionsByOpCode(opCode) + for _, instruction := range instructions { + _, topics := d.decodeLogArgs(instruction.Offset, topicCount) + + if len(topics) < 1 { + continue + } + + eventSignature := topics[0] + eventSignatureBytes := common.Hex2Bytes(eventSignature) + + eventNode := &EventTreeNode{ + EventSignature: eventSignature, + EventSignatureHex: fmt.Sprintf("0x%s", eventSignature), + EventBytesHex: common.Bytes2Hex(eventSignatureBytes), + EventBytes: eventSignatureBytes, + HasEventSignature: len(eventSignatureBytes) == 32, + TreeNode: &TreeNode{Instruction: instruction, Children: make([]*TreeNode, 0)}, + } + + events = append(events, eventNode) + } + } + + return events +} diff --git a/opcode/instructions.go b/opcode/instructions.go index 7d61b6ae..142fe333 100644 --- a/opcode/instructions.go +++ b/opcode/instructions.go @@ -8,6 +8,7 @@ type Instruction struct { OpCode OpCode `json:"opcode"` Args []byte `json:"args"` Description string `json:"description"` + Data []byte `json:"data"` } // GetOffset returns the offset of the instruction. diff --git a/opcode/log.go b/opcode/log.go new file mode 100644 index 00000000..03ced8d1 --- /dev/null +++ b/opcode/log.go @@ -0,0 +1,131 @@ +package opcode + +import ( + "github.com/ethereum/go-ethereum/common" +) + +// DecodeLOG1 returns all instructions with the LOG1 OpCode and decodes their arguments. +func (d *Decompiler) DecodeLOG1() []Instruction { + return d.decodeLogInstructions(LOG1, 1) +} + +// DecodeLOG2 returns all instructions with the LOG2 OpCode and decodes their arguments. +func (d *Decompiler) DecodeLOG2() []Instruction { + return d.decodeLogInstructions(LOG2, 2) +} + +// DecodeLOG3 returns all instructions with the LOG3 OpCode and decodes their arguments. +func (d *Decompiler) DecodeLOG3() []Instruction { + return d.decodeLogInstructions(LOG3, 3) +} + +// DecodeLOG4 returns all instructions with the LOG4 OpCode and decodes their arguments. +func (d *Decompiler) DecodeLOG4() []Instruction { + return d.decodeLogInstructions(LOG4, 4) +} + +// decodeLogInstructions processes the LOG instructions and decodes their arguments. +func (d *Decompiler) decodeLogInstructions(opCode OpCode, topicCount int) []Instruction { + instructions := d.GetInstructionsByOpCode(opCode) + + for _, instruction := range instructions { + data, topics := d.decodeLogArgs(instruction.Offset, topicCount) + instruction.Data = data + _ = topics + } + + return instructions +} + +// decodeLogArgs decodes the arguments for a LOG instruction. +func (d *Decompiler) decodeLogArgs(offset, topicCount int) ([]byte, []string) { + var stack [][]byte + currentIndex := -1 + + // Find the index of the LOG instruction + for i, instr := range d.instructions { + if instr.Offset == offset { + currentIndex = i + break + } + } + + if currentIndex == -1 { + return nil, nil + } + + // Process instructions to reconstruct the stack state + for i := 0; i < currentIndex; i++ { + instr := d.instructions[i] + switch { + case instr.OpCode.IsPush(): + stack = append(stack, instr.Args) + case instr.OpCode == SWAP1: + if len(stack) >= 2 { + stack[len(stack)-1], stack[len(stack)-2] = stack[len(stack)-2], stack[len(stack)-1] + } + case instr.OpCode == SWAP2: + if len(stack) >= 3 { + stack[len(stack)-1], stack[len(stack)-3] = stack[len(stack)-3], stack[len(stack)-1] + } + case instr.OpCode == SWAP3: + if len(stack) >= 4 { + stack[len(stack)-1], stack[len(stack)-4] = stack[len(stack)-4], stack[len(stack)-1] + } + case instr.OpCode == SWAP4: + if len(stack) >= 5 { + stack[len(stack)-1], stack[len(stack)-5] = stack[len(stack)-5], stack[len(stack)-1] + } + case instr.OpCode == DUP1: + if len(stack) >= 1 { + stack = append(stack, stack[len(stack)-1]) + } + case instr.OpCode == DUP2: + if len(stack) >= 2 { + stack = append(stack, stack[len(stack)-2]) + } + case instr.OpCode == DUP3: + if len(stack) >= 3 { + stack = append(stack, stack[len(stack)-3]) + } + case instr.OpCode == DUP4: + if len(stack) >= 4 { + stack = append(stack, stack[len(stack)-4]) + } + case instr.OpCode == POP: + if len(stack) > 0 { + stack = stack[:len(stack)-1] + } + } + } + + // Collect the data and topics for the LOG instruction + if len(stack) < topicCount+1 { + return nil, nil + } + + topics := make([]string, 0) + var data []byte + topicCountFound := 0 + + // Iterate from the end of the stack to find topics and data + for i := len(stack) - 1; i >= 0; i-- { + if len(stack[i]) == 32 && topicCountFound < topicCount { + topics = append([]string{common.Bytes2Hex(stack[i])}, topics...) // Prepend to maintain order + topicCountFound++ + } else if data == nil { + data = stack[i] + } + + // Break if we have found all topics and data + if topicCountFound == topicCount && data != nil { + break + } + } + + if data == nil { + return nil, nil + } + + return data, topics +} diff --git a/opcode/matcher.go b/opcode/matcher.go index 0ae90dd4..ef63ff87 100644 --- a/opcode/matcher.go +++ b/opcode/matcher.go @@ -2,7 +2,6 @@ package opcode import ( "bytes" - "fmt" "strings" "github.com/ethereum/go-ethereum/common" @@ -62,51 +61,3 @@ func (d *Decompiler) isStateVariableDeclaration(instruction Instruction) bool { } return false } - -// GetEvents returns the instruction trees for all events declared in the bytecode. -func (d *Decompiler) GetEvents() ([]*InstructionTree, error) { - // Initialize slice to hold event trees with estimated capacity - events := make([]*InstructionTree, 0) - - // Iterate through instructions to find event definitions - for _, instruction := range d.instructions { - // Check if the instruction is an event definition - if isEventDefinition(instruction) { - // Parse event arguments - args, err := parseEventArguments(instruction) - if err != nil { - return nil, fmt.Errorf("error parsing event arguments: %v", err) - } - - // Build instruction tree for the event definition with arguments - tree := NewInstructionTree(instruction) - tree.Instruction.Args = args // Append event arguments - events = append(events, tree) - } - } - - return events, nil -} - -// Function to determine if an instruction declares an event -func isEventDefinition(instruction Instruction) bool { - // Check if the instruction is a LOG operation - return instruction.OpCode == LOG0 || instruction.OpCode == LOG1 || - instruction.OpCode == LOG2 || instruction.OpCode == LOG3 || instruction.OpCode == LOG4 -} - -// Function to parse event arguments from instruction -func parseEventArguments(instruction Instruction) ([]byte, error) { - // Check if the instruction has arguments - if len(instruction.Args) > 0 { - // For LOG0, no additional argument is needed - if instruction.OpCode == LOG0 { - return []byte{}, nil - } - - // For other LOG operations, extract event arguments - eventArgs := instruction.Args[32:] // Skip the first 32 bytes (topic) - return eventArgs, nil - } - return nil, nil -} diff --git a/opcode/opcodes.go b/opcode/opcodes.go index a542b151..306fea90 100644 --- a/opcode/opcodes.go +++ b/opcode/opcodes.go @@ -159,6 +159,26 @@ func (op OpCode) IsFunctionEnd() bool { return op == RETURN || op == STOP } +// IsEvent checks if the given opcode corresponds to an event logging operation. +// The LOG0 to LOG4 opcodes are used to log events in the EVM. +func (op OpCode) IsEvent() bool { + switch op { + case LOG0, LOG1, LOG2, LOG3, LOG4: + return true + default: + return false + } +} + +// OpCode extensions for identifying PUSH32 and LOG opcodes +func (op OpCode) IsPush32() bool { + return op == PUSH32 +} + +func (op OpCode) IsLog() bool { + return op >= LOG0 && op <= LOG4 +} + // IsSelfDestruct checks if the given opcode corresponds to the SELFDESTRUCT operation. // The SELFDESTRUCT opcode is used in the EVM to destroy the current contract, sending its funds to the provided address. func (op OpCode) IsSelfDestruct() bool { diff --git a/sources.go b/sources.go index b4ad4807..596a386a 100644 --- a/sources.go +++ b/sources.go @@ -181,6 +181,67 @@ func NewSourcesFromPath(entrySourceUnitName, path string) (*Sources, error) { return sources, nil } +func NewUnsortedSourcesFromPath(entrySourceUnitName, path string) (*Sources, error) { + info, err := os.Stat(path) + if err != nil { + return nil, err // Return the error if the path does not exist or cannot be accessed + } + + if !info.IsDir() { + return nil, fmt.Errorf("path is not a directory: %s", path) + } + + var sourcesDir string + + if GetLocalSourcesPath() == "" { + _, filename, _, _ := runtime.Caller(0) + dir := filepath.Dir(filename) + sourcesDir = filepath.Clean(filepath.Join(dir, "sources")) + } else { + sourcesDir = GetLocalSourcesPath() + } + + sources := &Sources{ + MaskLocalSourcesPath: true, + LocalSourcesPath: sourcesDir, + LocalSources: false, + EntrySourceUnitName: entrySourceUnitName, + } + + files, err := os.ReadDir(path) + if err != nil { + return nil, err + } + + for _, file := range files { + if file.IsDir() { + continue // Skip directories + } + + // Check if the file has a .sol extension + if filepath.Ext(file.Name()) == ".sol" { + filePath := filepath.Join(path, file.Name()) + + content, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } + + sources.SourceUnits = append(sources.SourceUnits, &SourceUnit{ + Name: strings.TrimSuffix(file.Name(), ".sol"), + Path: filePath, + Content: string(content), + }) + } + } + + if err := sources.SortContracts(); err != nil { + return nil, fmt.Errorf("failure while doing topological contract sorting: %s", err.Error()) + } + + return sources, nil +} + // NewSourcesFromMetadata creates a Sources from a metadata package ContractMetadata. // This is a helper function that ensures easier integration when working with the metadata package. func NewSourcesFromMetadata(md *metadata.ContractMetadata) *Sources { @@ -322,6 +383,72 @@ func NewSourcesFromEtherScan(entryContractName string, sc interface{}) (*Sources return sources, nil } +// NewUnsortedSourcesFromEtherScan creates a Sources from an EtherScan response. +// This is a helper function that ensures easier integration when working with the EtherScan provider. +// This includes BscScan, and other equivalent from the same family. +func NewUnsortedSourcesFromEtherScan(entryContractName string, sc interface{}) (*Sources, error) { + var sourcesDir string + + if GetLocalSourcesPath() == "" { + _, filename, _, _ := runtime.Caller(0) + dir := filepath.Dir(filename) + sourcesDir = filepath.Clean(filepath.Join(dir, "sources")) + } else { + sourcesDir = GetLocalSourcesPath() + } + + sources := &Sources{ + MaskLocalSourcesPath: true, + LocalSourcesPath: sourcesDir, + EntrySourceUnitName: entryContractName, + LocalSources: false, + } + + switch sourceCode := sc.(type) { + case string: + sources.AppendSource(&SourceUnit{ + Name: entryContractName, + Path: fmt.Sprintf("%s.sol", entryContractName), + Content: sourceCode, + }) + case map[string]interface{}: + // Create an instance of ContractMetadata + var contractMetadata metadata.ContractMetadata + + // Marshal the map into JSON, then Unmarshal it into the ContractMetadata struct + jsonBytes, err := json.Marshal(sourceCode) + if err != nil { + return nil, fmt.Errorf("error marshalling to json: %v", err) + } + + if err := json.Unmarshal(jsonBytes, &contractMetadata); err != nil { + return nil, fmt.Errorf("error unmarshalling to contract metadata: %v", err) + } + + for name, source := range contractMetadata.Sources { + sources.AppendSource(&SourceUnit{ + Name: strings.TrimSuffix(filepath.Base(name), ".sol"), + Path: name, + Content: source.Content, + }) + } + + case metadata.ContractMetadata: + for name, source := range sourceCode.Sources { + sources.AppendSource(&SourceUnit{ + Name: strings.TrimSuffix(filepath.Base(name), ".sol"), + Path: name, + Content: source.Content, + }) + } + + default: + return nil, fmt.Errorf("unknown source code type: %T", sourceCode) + } + + return sources, nil +} + // AppendSource appends a SourceUnit to the Sources. // If a SourceUnit with the same name already exists, it replaces it unless the new SourceUnit has less content. func (s *Sources) AppendSource(source *SourceUnit) { @@ -381,11 +508,13 @@ func (s *Sources) Validate() error { // in one file provided back to us. In that case, we should check if // specific file contains `contract {entrySourceUnitName}`. found := false + contractRegex := regexp.MustCompile(fmt.Sprintf(`\bcontract\s+%s\b`, regexp.QuoteMeta(s.EntrySourceUnitName))) + libraryRegex := regexp.MustCompile(fmt.Sprintf(`\blibrary\s+%s\b`, regexp.QuoteMeta(s.EntrySourceUnitName))) + for _, sourceUnit := range s.SourceUnits { - if strings.Contains(sourceUnit.Content, fmt.Sprintf("contract %s", s.EntrySourceUnitName)) { - found = true - } else if strings.Contains(sourceUnit.Content, fmt.Sprintf("library %s", s.EntrySourceUnitName)) { + if contractRegex.MatchString(sourceUnit.Content) || libraryRegex.MatchString(sourceUnit.Content) { found = true + break } } diff --git a/utils/networks.go b/utils/networks.go index 2d969d65..4c19610f 100644 --- a/utils/networks.go +++ b/utils/networks.go @@ -12,6 +12,10 @@ func (n NetworkID) String() string { return n.ToBig().String() } +func (n NetworkID) IsValid() bool { + return n != 0 +} + func (n NetworkID) ToNetwork() Network { switch n { case EthereumNetworkID: diff --git a/utils/version.go b/utils/version.go index 5481b09d..3c1033cc 100644 --- a/utils/version.go +++ b/utils/version.go @@ -12,12 +12,22 @@ type SemanticVersion struct { Major int `json:"major"` // Major version, incremented for incompatible API changes. Minor int `json:"minor"` // Minor version, incremented for backwards-compatible enhancements. Patch int `json:"patch"` // Patch version, incremented for backwards-compatible bug fixes. - Commit string `json:"revision"` // Optional commit revision for tracking specific builds. + Commit string `json:"revision,omitempty"` // Optional commit revision for tracking specific builds. } // String returns the string representation of the SemanticVersion, excluding the // commit revision. It adheres to the "Major.Minor.Patch" format. func (v SemanticVersion) String() string { + if v.Commit != "" { + return strconv.Itoa(v.Major) + "." + strconv.Itoa(v.Minor) + "." + strconv.Itoa(v.Patch) + "+" + v.Commit + } + + return strconv.Itoa(v.Major) + "." + strconv.Itoa(v.Minor) + "." + strconv.Itoa(v.Patch) +} + +// String returns the string representation of the SemanticVersion, excluding the +// commit revision. It adheres to the "Major.Minor.Patch" format. +func (v SemanticVersion) StringVersion() string { return strconv.Itoa(v.Major) + "." + strconv.Itoa(v.Minor) + "." + strconv.Itoa(v.Patch) } diff --git a/validation/verifier.go b/validation/verifier.go index d65cd067..14a36049 100644 --- a/validation/verifier.go +++ b/validation/verifier.go @@ -3,8 +3,10 @@ package validation import ( "context" "encoding/hex" - "errors" "fmt" + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" + "github.com/unpackdev/solgo/bytecode" "strings" "github.com/0x19/solc-switch" @@ -106,7 +108,53 @@ func (v *Verifier) VerifyFromResults(bytecode []byte, results *solc.CompilerResu LevenshteinDistance: dmp.DiffLevenshtein(diffs), } - return toReturn, errors.New("bytecode missmatch, failed to verify") + return toReturn, errors.New("contract bytecode mismatch, failed to verify") + } + + toReturn := &VerifyResult{ + Verified: true, + ExpectedBytecode: encoded, + CompilerResult: result, + Diffs: make([]diffmatchpatch.Diff, 0), + } + + return toReturn, nil +} + +// VerifyAuxFromResults compiles the sources using the solc compiler and then verifies the bytecode. +// If the bytecode does not match the compiled result, it returns a diff of the two. +// Returns true if the bytecode matches, otherwise returns false. +// Also returns an error if there's any issue in the compilation or verification process. +func (v *Verifier) VerifyAuxFromResults(bCode []byte, results *solc.CompilerResults) (*VerifyResult, error) { + result := results.GetEntryContract() + + if result == nil { + zap.L().Error( + "no appropriate compilation results found (compiled but missing entry contract)", + zap.Any("results", results), + ) + return nil, errors.New("no appropriate compilation results found (compiled but missing entry contract)") + } + + dBytecode, dBytecodeErr := bytecode.DecodeContractMetadata(common.Hex2Bytes(result.GetDeployedBytecode())) + if dBytecodeErr != nil { + return nil, errors.Wrap(dBytecodeErr, "failure to decode contract metadata while verifying contract aux bytecode") + } + + encoded := common.Bytes2Hex(bCode) + if !strings.Contains(common.Bytes2Hex(dBytecode.GetAuxBytecode()), encoded) { + dmp := diffmatchpatch.New() + diffs := dmp.DiffMain(encoded, common.Bytes2Hex(dBytecode.GetAuxBytecode()), false) + toReturn := &VerifyResult{ + Verified: false, + CompilerResult: result, + ExpectedBytecode: encoded, + Diffs: diffs, + DiffPretty: dmp.DiffPrettyText(diffs), + LevenshteinDistance: dmp.DiffLevenshtein(diffs), + } + + return toReturn, errors.New("aux bytecode mismatch, failed to verify") } toReturn := &VerifyResult{