diff --git a/cl/beacon/handler/block_production.go b/cl/beacon/handler/block_production.go index 710fd009c4c..9d9b4bf239c 100644 --- a/cl/beacon/handler/block_production.go +++ b/cl/beacon/handler/block_production.go @@ -300,11 +300,16 @@ func (a *ApiHandler) GetEthV3ValidatorBlock( ) } - // make a simple copy to the current head state - baseState, err := a.forkchoiceStore.GetStateAtBlockRoot( - baseBlockRoot, - true, - ) // we start the block production from this state + var baseState *state.CachingBeaconState + if err := a.syncedData.ViewHeadState(func(headState *state.CachingBeaconState) error { + baseState, err = headState.Copy() + if err != nil { + return err + } + return nil + }); err != nil { + return nil, err + } if err != nil { return nil, err @@ -1177,7 +1182,7 @@ func (a *ApiHandler) storeBlockAndBlobs( return err } - if err := a.forkchoiceStore.OnBlock(ctx, block, true, false, false); err != nil { + if err := a.forkchoiceStore.OnBlock(ctx, block, true, true, false); err != nil { return err } finalizedBlockRoot := a.forkchoiceStore.FinalizedCheckpoint().Root diff --git a/cl/cltypes/solid/hash_list.go b/cl/cltypes/solid/hash_list.go index 1b8bc09fcd2..cca7e5c435c 100644 --- a/cl/cltypes/solid/hash_list.go +++ b/cl/cltypes/solid/hash_list.go @@ -109,6 +109,12 @@ func (h *hashList) CopyTo(t IterableSSZ[libcommon.Hash]) { tu.MerkleTree = &merkle_tree.MerkleTree{} } h.MerkleTree.CopyInto(tu.MerkleTree) + // make the leaf function on the new buffer + tu.MerkleTree.SetComputeLeafFn(func(idx int, out []byte) { + copy(out, tu.u[idx*length.Hash:]) + }) + } else { + tu.MerkleTree = nil } copy(tu.u, h.u) } diff --git a/cl/cltypes/solid/uint64slice_byte.go b/cl/cltypes/solid/uint64slice_byte.go index a251d683c8f..750e9378ffb 100644 --- a/cl/cltypes/solid/uint64slice_byte.go +++ b/cl/cltypes/solid/uint64slice_byte.go @@ -81,6 +81,11 @@ func (arr *byteBasedUint64Slice) CopyTo(target *byteBasedUint64Slice) { target.MerkleTree = &merkle_tree.MerkleTree{} } arr.MerkleTree.CopyInto(target.MerkleTree) + target.SetComputeLeafFn(func(idx int, out []byte) { + copy(out, target.u[idx*length.Hash:]) + }) + } else { + target.MerkleTree = nil } target.u = target.u[:len(arr.u)] diff --git a/cl/cltypes/solid/validator_set.go b/cl/cltypes/solid/validator_set.go index 6c439c5d53c..79e94d67b4c 100644 --- a/cl/cltypes/solid/validator_set.go +++ b/cl/cltypes/solid/validator_set.go @@ -141,6 +141,11 @@ func (v *ValidatorSet) CopyTo(t *ValidatorSet) { t.MerkleTree = &merkle_tree.MerkleTree{} } v.MerkleTree.CopyInto(t.MerkleTree) + t.MerkleTree.SetComputeLeafFn(func(idx int, out []byte) { + copy(out, t.buffer[idx*validatorSize:]) + }) + } else { + t.MerkleTree = nil } // skip copying (unsupported for phase0) t.phase0Data = make([]Phase0Data, t.l) diff --git a/cl/merkle_tree/merkle_tree.go b/cl/merkle_tree/merkle_tree.go index e7480fb6b4b..819ab224dda 100644 --- a/cl/merkle_tree/merkle_tree.go +++ b/cl/merkle_tree/merkle_tree.go @@ -58,6 +58,10 @@ func (m *MerkleTree) Initialize(leavesCount, maxTreeCacheDepth int, computeLeaf m.dirtyLeaves = make([]atomic.Bool, leavesCount) } +func (m *MerkleTree) SetComputeLeafFn(computeLeaf func(idx int, out []byte)) { + m.computeLeaf = computeLeaf +} + func (m *MerkleTree) MarkLeafAsDirty(idx int) { m.mu.RLock() defer m.mu.RUnlock() @@ -207,7 +211,7 @@ func (m *MerkleTree) CopyInto(other *MerkleTree) { m.mu.RLock() defer m.mu.RUnlock() defer other.mu.Unlock() - other.computeLeaf = m.computeLeaf + //other.computeLeaf = m.computeLeaf if len(other.layers) > len(m.layers) { // reset the internal layers for i := len(m.layers); i < len(other.layers); i++ { diff --git a/cl/phase1/core/state/cache.go b/cl/phase1/core/state/cache.go index a5180932831..e3b76a4916b 100644 --- a/cl/phase1/core/state/cache.go +++ b/cl/phase1/core/state/cache.go @@ -267,7 +267,7 @@ func (b *CachingBeaconState) initCaches() error { func (b *CachingBeaconState) InitBeaconState() error { b.totalActiveBalanceCache = nil b._refreshActiveBalancesIfNeeded() - + b.previousStateRoot = common.Hash{} b.publicKeyIndicies = make(map[[48]byte]uint64) b.ForEachValidator(func(validator solid.Validator, i, total int) bool { b.publicKeyIndicies[validator.PublicKey()] = uint64(i)