Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overall improvements #214

Merged
merged 2 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ go.work
solgo
playground/*
bin/*
.idea
6 changes: 5 additions & 1 deletion ast/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (b *ASTBuilder) ToJSON() ([]byte, error) {
return b.InterfaceToJSON(b.tree.GetRoot())
}

// ToPrettyJSON converts the provided data to a JSON byte array.
// InterfaceToJSON converts the provided data to a JSON byte array.
func (b *ASTBuilder) InterfaceToJSON(data interface{}) ([]byte, error) {
return json.Marshal(data)
}
Expand Down Expand Up @@ -129,6 +129,10 @@ func (b *ASTBuilder) ImportFromJSON(ctx context.Context, jsonBytes []byte) (*Roo
return toReturn, nil
}

func (b *ASTBuilder) GetImports() []Node[NodeType] {
return b.currentImports
}

// GarbageCollect cleans up the ASTBuilder after resolving references.
func (b *ASTBuilder) GarbageCollect() {
b.currentEnums = nil
Expand Down
1 change: 1 addition & 0 deletions ast/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ func (c *Contract) Parse(unitCtx *parser.SourceUnitContext, ctx *parser.Contract
)

contractId := c.GetNextID()

contractNode := &Contract{
Id: contractId,
Name: ctx.Identifier().GetText(),
Expand Down
1 change: 1 addition & 0 deletions ast/reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ func (r *Resolver) resolveEntrySourceUnit() {
for _, entry := range node.GetExportedSymbols() {
if len(r.sources.EntrySourceUnitName) > 0 &&
r.sources.EntrySourceUnitName == entry.GetName() {

r.tree.astRoot.SetEntrySourceUnit(entry.GetId())
return
}
Expand Down
20 changes: 20 additions & 0 deletions ast/source_unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type SourceUnit[T NodeType] struct {
AbsolutePath string `json:"absolutePath"` // AbsolutePath is the absolute path of the source unit.
Name string `json:"name"` // Name is the name of the source unit. This is going to be one of the following: contract, interface or library name. It's here for convenience.
NodeType ast_pb.NodeType `json:"nodeType"` // NodeType is the type of the AST node.
Kind ast_pb.NodeType `json:"kind"` // Kind is the type of the AST node (contract, library, interface).
Nodes []Node[NodeType] `json:"nodes"` // Nodes is the list of AST nodes.
Src SrcNode `json:"src"` // Src is the source code location.
}
Expand Down Expand Up @@ -107,6 +108,11 @@ func (s *SourceUnit[T]) GetType() ast_pb.NodeType {
return s.NodeType
}

// GetKind returns the type of the source unit.
func (s *SourceUnit[T]) GetKind() ast_pb.NodeType {
return s.Kind
}

// GetSrc returns the source code location of the source unit.
func (s *SourceUnit[T]) GetSrc() SrcNode {
return s.Src
Expand Down Expand Up @@ -302,6 +308,7 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) {
if interfaceCtx, ok := child.(*parser.InterfaceDefinitionContext); ok {
license := getLicenseFromSources(b.sources, b.comments, interfaceCtx.Identifier().GetText())
sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, interfaceCtx.Identifier().GetText(), license)
sourceUnit.Kind = ast_pb.NodeType_KIND_INTERFACE
interfaceNode := NewInterfaceDefinition(b)
interfaceNode.Parse(ctx, interfaceCtx, rootNode, sourceUnit)
b.sourceUnits = append(b.sourceUnits, sourceUnit)
Expand All @@ -310,6 +317,7 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) {
if libraryCtx, ok := child.(*parser.LibraryDefinitionContext); ok {
license := getLicenseFromSources(b.sources, b.comments, libraryCtx.Identifier().GetText())
sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, libraryCtx.Identifier().GetText(), license)
sourceUnit.Kind = ast_pb.NodeType_KIND_LIBRARY
libraryNode := NewLibraryDefinition(b)
libraryNode.Parse(ctx, libraryCtx, rootNode, sourceUnit)
b.sourceUnits = append(b.sourceUnits, sourceUnit)
Expand All @@ -318,11 +326,23 @@ func (b *ASTBuilder) EnterSourceUnit(ctx *parser.SourceUnitContext) {
if contractCtx, ok := child.(*parser.ContractDefinitionContext); ok {
license := getLicenseFromSources(b.sources, b.comments, contractCtx.Identifier().GetText())
sourceUnit := NewSourceUnit[Node[ast_pb.SourceUnit]](b, contractCtx.Identifier().GetText(), license)
sourceUnit.Kind = ast_pb.NodeType_KIND_CONTRACT
contractNode := NewContractDefinition(b)
contractNode.Parse(ctx, contractCtx, rootNode, sourceUnit)
b.sourceUnits = append(b.sourceUnits, sourceUnit)
}
}

// Idea here is to basically set the source unit entry name as soon as we have parsed all of the classes.
// Now this won't be possible always but nevertheless. (In rest of the cases, resolver will take care of it)
if b.sources.EntrySourceUnitName != "" {
for _, sourceUnit := range b.sourceUnits {
if b.sources.EntrySourceUnitName == sourceUnit.GetName() {
rootNode.SetEntrySourceUnit(sourceUnit.GetId())
return
}
}
}
}

// ExitSourceUnit is called when the ASTBuilder exits a source unit context.
Expand Down
12 changes: 6 additions & 6 deletions ast/src.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import (

// SrcNode represents a node in the source code.
type SrcNode struct {
Line int64 `json:"line"` // Line number of the source node in the source code.
Column int64 `json:"column"` // Column number of the source node in the source code.
Start int64 `json:"start"` // Start position of the source node in the source code.
End int64 `json:"end"` // End position of the source node in the source code.
Length int64 `json:"length"` // Length of the source node in the source code.
ParentIndex int64 `json:"parent_index,omitempty"` // Index of the parent node in the source code.
Line int64 `json:"line"` // Line number of the source node in the source code.
Column int64 `json:"column"` // Column number of the source node in the source code.
Start int64 `json:"start"` // Start position of the source node in the source code.
End int64 `json:"end"` // End position of the source node in the source code.
Length int64 `json:"length"` // Length of the source node in the source code.
ParentIndex int64 `json:"parentIndex,omitempty"` // Index of the parent node in the source code.
}

// GetLine returns the line number of the source node in the source code.
Expand Down
100 changes: 98 additions & 2 deletions ast/state_variable.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package ast

import (
"strings"

v3 "github.com/cncf/xds/go/xds/type/v3"
"github.com/goccy/go-json"
ast_pb "github.com/unpackdev/protos/dist/go/ast"
"github.com/unpackdev/solgo/parser"
"strings"
)

// StateVariableDeclaration represents a state variable declaration in the Solidity abstract syntax tree (AST).
Expand Down Expand Up @@ -145,6 +145,102 @@ func (v *StateVariableDeclaration) GetInitialValue() Node[NodeType] {
return v.InitialValue
}

// UnmarshalJSON customizes the JSON unmarshaling for StateVariableDeclaration.
func (v *StateVariableDeclaration) UnmarshalJSON(data []byte) error {
var tempMap map[string]json.RawMessage
if err := json.Unmarshal(data, &tempMap); err != nil {
return err
}

if id, ok := tempMap["id"]; ok {
if err := json.Unmarshal(id, &v.Id); err != nil {
return err
}
}

if name, ok := tempMap["name"]; ok {
if err := json.Unmarshal(name, &v.Name); err != nil {
return err
}
}

if isConstant, ok := tempMap["isConstant"]; ok {
if err := json.Unmarshal(isConstant, &v.Constant); err != nil {
return err
}
}

if isStateVariable, ok := tempMap["isStateVariable"]; ok {
if err := json.Unmarshal(isStateVariable, &v.StateVariable); err != nil {
return err
}
}

if nodeType, ok := tempMap["nodeType"]; ok {
if err := json.Unmarshal(nodeType, &v.NodeType); err != nil {
return err
}
}

if visibility, ok := tempMap["visibility"]; ok {
if err := json.Unmarshal(visibility, &v.Visibility); err != nil {
return err
}
}

if storageLocation, ok := tempMap["storageLocation"]; ok {
if err := json.Unmarshal(storageLocation, &v.StorageLocation); err != nil {
return err
}
}

if mutability, ok := tempMap["mutability"]; ok {
if err := json.Unmarshal(mutability, &v.StateMutability); err != nil {
return err
}
}

if src, ok := tempMap["src"]; ok {
if err := json.Unmarshal(src, &v.Src); err != nil {
return err
}
}

if scope, ok := tempMap["scope"]; ok {
if err := json.Unmarshal(scope, &v.Scope); err != nil {
return err
}
}

if expression, ok := tempMap["initialValue"]; ok {
if err := json.Unmarshal(expression, &v.InitialValue); err != nil {
var tempNodeMap map[string]json.RawMessage
if err := json.Unmarshal(expression, &tempNodeMap); err != nil {
return err
}

var tempNodeType ast_pb.NodeType
if err := json.Unmarshal(tempNodeMap["nodeType"], &tempNodeType); err != nil {
return err
}

node, err := unmarshalNode(expression, tempNodeType)
if err != nil {
return err
}
v.InitialValue = node
}
}

if typeDescription, ok := tempMap["typeDescription"]; ok {
if err := json.Unmarshal(typeDescription, &v.TypeDescription); err != nil {
return err
}
}

return nil
}

// ToProto returns the protobuf representation of the state variable declaration.
func (v *StateVariableDeclaration) ToProto() NodeType {
proto := ast_pb.StateVariable{
Expand Down
26 changes: 23 additions & 3 deletions bindings/otterscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bindings
import (
"context"
"fmt"
"github.com/pkg/errors"

"github.com/ethereum/go-ethereum/common"
"github.com/unpackdev/solgo/utils"
Expand All @@ -19,9 +20,9 @@ type CreatorInformation struct {
// GetContractCreator queries the Ethereum blockchain to find the creator of a specified smart contract. This method
// utilizes the Ethereum JSON-RPC API to request creator information, which includes both the creator's address and
// the transaction hash of the contract's creation. It's a valuable tool for auditing and tracking the origins of
// contracts on the network. WORKS ONLY WITH ERIGON NODE - OR NODES THAT SUPPORT OTTERSCAN!
// contracts on the network. WORKS ONLY WITH ERIGON NODE OR QUICKNODE PROVIDER - OR NODES THAT SUPPORT OTTERSCAN!
func (m *Manager) GetContractCreator(ctx context.Context, network utils.Network, contract common.Address) (*CreatorInformation, error) {
client := m.clientPool.GetClientByGroup(string(network))
client := m.clientPool.GetClientByGroup(network.String())
if client == nil {
return nil, fmt.Errorf("client not found for network %s", network)
}
Expand All @@ -30,7 +31,26 @@ func (m *Manager) GetContractCreator(ctx context.Context, network utils.Network,
var result *CreatorInformation

if err := rpcClient.CallContext(ctx, &result, "ots_getContractCreator", contract.Hex()); err != nil {
return nil, fmt.Errorf("failed to fetch otterscan creator information: %v", err)
return nil, errors.Wrap(err, "failed to fetch otterscan creator information")
}

return result, nil
}

// GetTransactionBySenderAndNonce retrieves a transaction hash based on a specific sender's address and nonce.
// This function also utilizes the Ethereum JSON-RPC API and requires a node that supports specific transaction lookup
// by sender and nonce, which is particularly useful for tracking transaction sequences and debugging transaction flows.
func (m *Manager) GetTransactionBySenderAndNonce(ctx context.Context, network utils.Network, sender common.Address, nonce int64) (*common.Hash, error) {
client := m.clientPool.GetClientByGroup(network.String())
if client == nil {
return nil, fmt.Errorf("client not found for network %s", network)
}

rpcClient := client.GetRpcClient()
var result *common.Hash

if err := rpcClient.CallContext(ctx, &result, "ots_getTransactionBySenderAndNonce", sender.Hex(), nonce); err != nil {
return nil, errors.Wrap(err, "failed to fetch otterscan get transaction by sender and nonce information")
}

return result, nil
Expand Down
26 changes: 26 additions & 0 deletions bindings/trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package bindings

import (
"context"
"fmt"
"github.com/pkg/errors"

"github.com/ethereum/go-ethereum/common"
"github.com/unpackdev/solgo/utils"
)

func (m *Manager) TraceCallMany(ctx context.Context, network utils.Network, sender common.Address, nonce int64) (*common.Hash, error) {
client := m.clientPool.GetClientByGroup(network.String())
if client == nil {
return nil, fmt.Errorf("client not found for network %s", network)
}

rpcClient := client.GetRpcClient()
var result *common.Hash

if err := rpcClient.CallContext(ctx, &result, "trace_callMany", sender.Hex(), nonce); err != nil {
return nil, errors.Wrap(err, "failed to execute trace_callMany")
}

return result, nil
}
12 changes: 9 additions & 3 deletions contracts/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ type Descriptor struct {

// SourcesRaw is the raw sources from Etherscan|BscScan|etc. Should not be used anywhere except in
// the contract discovery process.
SourcesRaw *etherscan.Contract `json:"-"`
Sources *solgo.Sources `json:"sources,omitempty"`
SourceProvider string `json:"source_provider,omitempty"`
SourcesRaw *etherscan.Contract `json:"-"`
Sources *solgo.Sources `json:"sources,omitempty"`
SourcesUnsorted *solgo.Sources `json:"-"`
SourceProvider string `json:"source_provider,omitempty"`

// Source detection related fields.
Detector *detector.Detector `json:"-"`
Expand Down Expand Up @@ -144,6 +145,11 @@ func (d *Descriptor) GetSources() *solgo.Sources {
return d.Sources
}

// GetUnsortedSources returns the parsed sources of the contract, providing a structured view of the contract's code.
func (d *Descriptor) GetUnsortedSources() *solgo.Sources {
return d.SourcesUnsorted
}

// GetSourcesRaw returns the raw contract source as obtained from external providers like Etherscan.
func (d *Descriptor) GetSourcesRaw() *etherscan.Contract {
return d.SourcesRaw
Expand Down
4 changes: 4 additions & 0 deletions contracts/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ func (c *Contract) Parse(ctx context.Context) error {
)
return err
}

// Sets the address for more understanding when we need to troubleshoot contract parsing
parser.GetIR().SetAddress(c.addr)

c.descriptor.Detector = parser
c.descriptor.SolgoVersion = utils.GetBuildVersionByModule("github.com/unpackdev/solgo")

Expand Down
12 changes: 12 additions & 0 deletions contracts/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,19 @@ func (c *Contract) DiscoverSourceCode(ctx context.Context) error {
return fmt.Errorf("failed to create new sources from etherscan response: %s", err)
}

unsortedSources, err := solgo.NewUnsortedSourcesFromEtherScan(response.Name, response.SourceCode)
if err != nil {
zap.L().Error(
"failed to create new unsorted sources from etherscan response",
zap.Error(err),
zap.String("network", c.network.String()),
zap.String("contract_address", c.addr.String()),
)
return fmt.Errorf("failed to create new unsorted sources from etherscan response: %s", err)
}

c.descriptor.Sources = sources
c.descriptor.SourcesUnsorted = unsortedSources

license := strings.ReplaceAll(c.descriptor.SourcesRaw.LicenseType, "\r", "")
license = strings.ReplaceAll(license, "\n", "")
Expand Down
6 changes: 3 additions & 3 deletions ir/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func (c *Contract) ToProto() *ir_pb.Contract {

// processContract processes the contract unit and returns the Contract.
func (b *Builder) processContract(unit *ast.SourceUnit[ast.Node[ast_pb.SourceUnit]]) *Contract {
contract := getContractByNodeType(unit.GetContract())
contract := GetContractByNodeType(unit.GetContract())
contractNode := &Contract{
Unit: unit,

Expand Down Expand Up @@ -360,8 +360,8 @@ func (b *Builder) processContract(unit *ast.SourceUnit[ast.Node[ast_pb.SourceUni
return contractNode
}

// getContractByNodeType returns the ContractNode based on the node type.
func getContractByNodeType(c ast.Node[ast.NodeType]) ContractNode {
// GetContractByNodeType returns the ContractNode based on the node type.
func GetContractByNodeType(c ast.Node[ast.NodeType]) ContractNode {
switch contract := c.(type) {
case *ast.Library:
return contract
Expand Down
Loading
Loading