diff --git a/core/state/access_list.go b/core/state/access_list.go index 718bf17cf742..b0effbeadc49 100644 --- a/core/state/access_list.go +++ b/core/state/access_list.go @@ -17,7 +17,10 @@ package state import ( + "fmt" "maps" + "slices" + "strings" "github.com/ethereum/go-ethereum/common" ) @@ -130,3 +133,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) { func (al *accessList) DeleteAddress(address common.Address) { delete(al.addresses, address) } + +// Equal returns true if the two access lists are identical +func (al *accessList) Equal(other *accessList) bool { + if !maps.Equal(al.addresses, other.addresses) { + return false + } + return slices.EqualFunc(al.slots, other.slots, + func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool { + return maps.Equal(m, m2) + }) +} + +// PrettyPrint prints the contents of the access list in a human-readable form +func (al *accessList) PrettyPrint() string { + out := new(strings.Builder) + var sortedAddrs []common.Address + for addr := range al.addresses { + sortedAddrs = append(sortedAddrs, addr) + } + slices.SortFunc(sortedAddrs, common.Address.Cmp) + for _, addr := range sortedAddrs { + idx := al.addresses[addr] + fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx) + if idx >= 0 { + slotmap := al.slots[idx] + for h := range slotmap { + fmt.Fprintf(out, " %#x\n", h) + } + } + } + return out.String() +} diff --git a/core/state/state_object.go b/core/state/state_object.go index aa748f08ac3c..d3d20c3dc481 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -459,22 +459,22 @@ func (s *stateObject) setBalance(amount *uint256.Int) { func (s *stateObject) deepCopy(db *StateDB) *stateObject { obj := &stateObject{ - db: db, - address: s.address, - addrHash: s.addrHash, - origin: s.origin, - data: s.data, + db: db, + address: s.address, + addrHash: s.addrHash, + origin: s.origin, + data: s.data, + code: s.code, + originStorage: s.originStorage.Copy(), + pendingStorage: s.pendingStorage.Copy(), + dirtyStorage: s.dirtyStorage.Copy(), + dirtyCode: s.dirtyCode, + selfDestructed: s.selfDestructed, + newContract: s.newContract, } if s.trie != nil { obj.trie = db.db.CopyTrie(s.trie) } - obj.code = s.code - obj.originStorage = s.originStorage.Copy() - obj.pendingStorage = s.pendingStorage.Copy() - obj.dirtyStorage = s.dirtyStorage.Copy() - obj.dirtyCode = s.dirtyCode - obj.selfDestructed = s.selfDestructed - obj.newContract = s.newContract return obj } diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 1a3eccfe10b7..71d64f562898 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -21,9 +21,11 @@ import ( "encoding/binary" "errors" "fmt" + "maps" "math" "math/rand" "reflect" + "slices" "strings" "sync" "testing" @@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H if err != nil { return err } - it := trie.NewIterator(trieIt) + var ( + it = trie.NewIterator(trieIt) + visited = make(map[common.Hash]bool) + ) for it.Next() { key := common.BytesToHash(s.trie.GetKey(it.Key)) + visited[key] = true if value, dirty := so.dirtyStorage[key]; dirty { if !cb(key, value) { return nil @@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) + // Check newContract-flag + if obj := state.getStateObject(addr); obj != nil { + checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract) + } // Check storage. if obj := state.getStateObject(addr); obj != nil { forEachStorage(state, addr, func(key, value common.Hash) bool { @@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { forEachStorage(checkstate, addr, func(key, value common.Hash) bool { return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) }) + other := checkstate.getStateObject(addr) + // Check dirty storage which is not in trie + if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) { + print := func(dirty map[common.Hash]common.Hash) string { + var keys []common.Hash + out := new(strings.Builder) + for key := range dirty { + keys = append(keys, key) + } + slices.SortFunc(keys, common.Hash.Cmp) + for i, key := range keys { + fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key]) + } + return out.String() + } + return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v", + print(obj.dirtyStorage), + print(other.dirtyStorage)) + } + } + // Check transient storage. + { + have := state.transientStorage + want := checkstate.transientStorage + eq := maps.EqualFunc(have, want, + func(a Storage, b Storage) bool { + return maps.Equal(a, b) + }) + if !eq { + return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v", + have.PrettyPrint(), + want.PrettyPrint()) + } } if err != nil { return err } } - + if !checkstate.accessList.Equal(state.accessList) { // Check access lists + return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v", + checkstate.accessList.PrettyPrint(), + state.accessList.PrettyPrint()) + } if state.GetRefund() != checkstate.GetRefund() { return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", state.GetRefund(), checkstate.GetRefund()) @@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) } + if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) { + getKeys := func(dirty map[common.Address]int) string { + var keys []common.Address + out := new(strings.Builder) + for key := range dirty { + keys = append(keys, key) + } + slices.SortFunc(keys, common.Address.Cmp) + for i, key := range keys { + fmt.Fprintf(out, " %d. %v\n", i, key) + } + return out.String() + } + have := getKeys(state.journal.dirties) + want := getKeys(checkstate.journal.dirties) + return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want) + } return nil } diff --git a/core/state/transient_storage.go b/core/state/transient_storage.go index 66e563efa732..e63db39ebab6 100644 --- a/core/state/transient_storage.go +++ b/core/state/transient_storage.go @@ -17,6 +17,10 @@ package state import ( + "fmt" + "slices" + "strings" + "github.com/ethereum/go-ethereum/common" ) @@ -30,10 +34,19 @@ func newTransientStorage() transientStorage { // Set sets the transient-storage `value` for `key` at the given `addr`. func (t transientStorage) Set(addr common.Address, key, value common.Hash) { - if _, ok := t[addr]; !ok { - t[addr] = make(Storage) + if value == (common.Hash{}) { // this is a 'delete' + if _, ok := t[addr]; ok { + delete(t[addr], key) + if len(t[addr]) == 0 { + delete(t, addr) + } + } + } else { + if _, ok := t[addr]; !ok { + t[addr] = make(Storage) + } + t[addr][key] = value } - t[addr][key] = value } // Get gets the transient storage for `key` at the given `addr`. @@ -53,3 +66,27 @@ func (t transientStorage) Copy() transientStorage { } return storage } + +// PrettyPrint prints the contents of the access list in a human-readable form +func (t transientStorage) PrettyPrint() string { + out := new(strings.Builder) + var sortedAddrs []common.Address + for addr := range t { + sortedAddrs = append(sortedAddrs, addr) + slices.SortFunc(sortedAddrs, common.Address.Cmp) + } + + for _, addr := range sortedAddrs { + fmt.Fprintf(out, "%#x:", addr) + var sortedKeys []common.Hash + storage := t[addr] + for key := range storage { + sortedKeys = append(sortedKeys, key) + } + slices.SortFunc(sortedKeys, common.Hash.Cmp) + for _, key := range sortedKeys { + fmt.Fprintf(out, " %X : %X\n", key, storage[key]) + } + } + return out.String() +}