From 5e422156d5968a893cdaa8855bc7e3482745c396 Mon Sep 17 00:00:00 2001 From: noot <36753753+noot@users.noreply.github.com> Date: Thu, 21 Oct 2021 10:15:44 +0100 Subject: [PATCH] refactor(dot/state): only store finalised blocks in database (#1833) --- dot/core/service_test.go | 2 +- dot/digest/digest_test.go | 6 +- dot/rpc/modules/chain.go | 2 +- dot/rpc/modules/chain_test.go | 55 ++--- dot/rpc/modules/dev_test.go | 2 +- dot/rpc/modules/state_test.go | 4 +- dot/rpc/modules/system_test.go | 2 +- dot/services.go | 12 - dot/state/base.go | 30 --- dot/state/base_test.go | 50 ----- dot/state/block.go | 324 +++++++++++---------------- dot/state/block_finalisation.go | 82 ++++--- dot/state/block_finalisation_test.go | 65 +++--- dot/state/block_race_test.go | 14 +- dot/state/block_test.go | 108 +-------- dot/state/epoch.go | 8 +- dot/state/epoch_test.go | 2 +- dot/state/initialize.go | 24 +- dot/state/offline_pruner.go | 18 +- dot/state/service.go | 118 +++------- dot/state/service_test.go | 8 + dot/state/test_helpers.go | 2 +- dot/sync/test_helpers.go | 45 ++-- dot/types/block.go | 10 + dot/types/extrinsic.go | 4 + lib/babe/build.go | 9 +- lib/babe/build_test.go | 5 +- lib/babe/verify_test.go | 3 +- lib/blocktree/blocktree.go | 109 +++++---- lib/blocktree/blocktree_test.go | 236 +++++++++++-------- lib/blocktree/database.go | 139 ------------ lib/blocktree/database_test.go | 112 --------- lib/blocktree/errors.go | 8 + lib/blocktree/leaves.go | 10 +- lib/blocktree/node.go | 27 +-- lib/blocktree/node_test.go | 4 +- lib/grandpa/message_handler_test.go | 6 +- 37 files changed, 612 insertions(+), 1053 deletions(-) delete mode 100644 lib/blocktree/database.go delete mode 100644 lib/blocktree/database_test.go diff --git a/dot/core/service_test.go b/dot/core/service_test.go index 9ce906c5d7..c625eba2a7 100644 --- a/dot/core/service_test.go +++ b/dot/core/service_test.go @@ -201,7 +201,7 @@ func TestHandleChainReorg_WithReorg_Trans(t *testing.T) { nonce := uint64(0) - // Add extrinsic to block `block31` + // Add extrinsic to block `block41` ext := createExtrinsic(t, rt, bs.GenesisHash(), nonce) block41 := sync.BuildBlock(t, rt, &block31.Header, ext) diff --git a/dot/digest/digest_test.go b/dot/digest/digest_test.go index 86e4301dc3..2ed55f1951 100644 --- a/dot/digest/digest_test.go +++ b/dot/digest/digest_test.go @@ -94,13 +94,15 @@ func TestHandler_GrandpaScheduledChange(t *testing.T) { headers, _ := state.AddBlocksToState(t, handler.blockState.(*state.BlockState), 2, false) for i, h := range headers { - handler.blockState.(*state.BlockState).SetFinalisedHash(h.Hash(), uint64(i), 0) + err = handler.blockState.(*state.BlockState).SetFinalisedHash(h.Hash(), uint64(i), 0) + require.NoError(t, err) } // authorities should change on start of block 3 from start headers, _ = state.AddBlocksToState(t, handler.blockState.(*state.BlockState), 1, false) for _, h := range headers { - handler.blockState.(*state.BlockState).SetFinalisedHash(h.Hash(), 3, 0) + err = handler.blockState.(*state.BlockState).SetFinalisedHash(h.Hash(), 3, 0) + require.NoError(t, err) } time.Sleep(time.Millisecond * 500) diff --git a/dot/rpc/modules/chain.go b/dot/rpc/modules/chain.go index 782a078d28..345afce814 100644 --- a/dot/rpc/modules/chain.go +++ b/dot/rpc/modules/chain.go @@ -103,7 +103,7 @@ func (cm *ChainModule) GetBlock(r *http.Request, req *ChainHashRequest, res *Cha return err } for _, e := range ext { - res.Block.Body = append(res.Block.Body, fmt.Sprintf("0x%x", e)) + res.Block.Body = append(res.Block.Body, e.String()) } } return nil diff --git a/dot/rpc/modules/chain_test.go b/dot/rpc/modules/chain_test.go index 67795a9a6c..5e183f111b 100644 --- a/dot/rpc/modules/chain_test.go +++ b/dot/rpc/modules/chain_test.go @@ -314,13 +314,14 @@ func TestChainGetFinalizedHeadByRound(t *testing.T) { digest := types.NewDigest() digest.Add(*types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest()) - header := &types.Header{ + + header := types.Header{ ParentHash: genesisHeader.Hash(), Number: big.NewInt(1), Digest: digest, } err = state.Block.AddBlock(&types.Block{ - Header: *header, + Header: header, Body: types.Body{}, }) require.NoError(t, err) @@ -369,8 +370,7 @@ func newTestStateService(t *testing.T) *state.Service { rt, err := wasmer.NewRuntimeFromGenesis(rtCfg) require.NoError(t, err) - err = loadTestBlocks(t, genesisHeader.Hash(), stateSrvc.Block, rt) - require.NoError(t, err) + loadTestBlocks(t, genesisHeader.Hash(), stateSrvc.Block, rt) t.Cleanup(func() { stateSrvc.Stop() @@ -378,51 +378,40 @@ func newTestStateService(t *testing.T) *state.Service { return stateSrvc } -func loadTestBlocks(t *testing.T, gh common.Hash, bs *state.BlockState, rt runtime.Instance) error { - // Create header - header0 := &types.Header{ - Number: big.NewInt(0), +func loadTestBlocks(t *testing.T, gh common.Hash, bs *state.BlockState, rt runtime.Instance) { + header1 := &types.Header{ + Number: big.NewInt(1), Digest: types.NewDigest(), ParentHash: gh, StateRoot: trie.EmptyHash, } - // Create blockHash - blockHash0 := header0.Hash() - block0 := &types.Block{ - Header: *header0, - Body: sampleBodyBytes, - } - err := bs.AddBlock(block0) - if err != nil { - return err + block1 := &types.Block{ + Header: *header1, + Body: sampleBodyBytes, } - bs.StoreRuntime(block0.Header.Hash(), rt) + err := bs.AddBlock(block1) + require.NoError(t, err) + bs.StoreRuntime(header1.Hash(), rt) - // Create header & blockData for block 1 digest := types.NewDigest() err = digest.Add(*types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest()) require.NoError(t, err) - header1 := &types.Header{ - Number: big.NewInt(1), + + header2 := &types.Header{ + Number: big.NewInt(2), Digest: digest, - ParentHash: blockHash0, + ParentHash: header1.Hash(), StateRoot: trie.EmptyHash, } - block1 := &types.Block{ - Header: *header1, + block2 := &types.Block{ + Header: *header2, Body: sampleBodyBytes, } - // Add the block1 to the DB - err = bs.AddBlock(block1) - if err != nil { - return err - } - - bs.StoreRuntime(block1.Header.Hash(), rt) - - return nil + err = bs.AddBlock(block2) + require.NoError(t, err) + bs.StoreRuntime(header2.Hash(), rt) } diff --git a/dot/rpc/modules/dev_test.go b/dot/rpc/modules/dev_test.go index 06fc9ee416..913d71e663 100644 --- a/dot/rpc/modules/dev_test.go +++ b/dot/rpc/modules/dev_test.go @@ -35,7 +35,7 @@ func newState(t *testing.T) (*state.BlockState, *state.EpochState) { _, _, genesisHeader := genesis.NewTestGenesisWithTrieAndHeader(t) bs, err := state.NewBlockStateFromGenesis(db, genesisHeader) require.NoError(t, err) - es, err := state.NewEpochStateFromGenesis(db, genesisBABEConfig) + es, err := state.NewEpochStateFromGenesis(db, bs, genesisBABEConfig) require.NoError(t, err) return bs, es } diff --git a/dot/rpc/modules/state_test.go b/dot/rpc/modules/state_test.go index 5e8c8b9220..cd2c15ca2b 100644 --- a/dot/rpc/modules/state_test.go +++ b/dot/rpc/modules/state_test.go @@ -543,7 +543,7 @@ func setupStateModule(t *testing.T) (*StateModule, *common.Hash, *common.Hash) { b := &types.Block{ Header: types.Header{ ParentHash: chain.Block.BestBlockHash(), - Number: big.NewInt(2), + Number: big.NewInt(3), StateRoot: sr1, }, Body: *types.NewBody([]types.Extrinsic{[]byte{}}), @@ -557,7 +557,7 @@ func setupStateModule(t *testing.T) (*StateModule, *common.Hash, *common.Hash) { chain.Block.StoreRuntime(b.Header.Hash(), rt) - hash, _ := chain.Block.GetBlockHash(big.NewInt(2)) + hash, _ := chain.Block.GetBlockHash(big.NewInt(3)) core := newCoreService(t, chain) return NewStateModule(net, chain.Storage, core), &hash, &sr1 } diff --git a/dot/rpc/modules/system_test.go b/dot/rpc/modules/system_test.go index 0d726dd242..b7ac869b2e 100644 --- a/dot/rpc/modules/system_test.go +++ b/dot/rpc/modules/system_test.go @@ -308,7 +308,7 @@ func setupSystemModule(t *testing.T) *SystemModule { require.NoError(t, err) err = chain.Block.AddBlock(&types.Block{ Header: types.Header{ - Number: big.NewInt(1), + Number: big.NewInt(3), ParentHash: chain.Block.BestBlockHash(), StateRoot: ts.MustRoot(), }, diff --git a/dot/services.go b/dot/services.go index c1f12dd409..0af544837f 100644 --- a/dot/services.go +++ b/dot/services.go @@ -74,18 +74,6 @@ func createStateService(cfg *Config) (*state.Service, error) { } } - // load most recent state from database - latestState, err := stateSrvc.Base.LoadLatestStorageHash() - if err != nil { - return nil, fmt.Errorf("failed to load latest state root hash: %s", err) - } - - // load most recent state from database - _, err = stateSrvc.Storage.LoadFromDB(latestState) - if err != nil { - return nil, fmt.Errorf("failed to load latest state from database: %s", err) - } - return stateSrvc, nil } diff --git a/dot/state/base.go b/dot/state/base.go index 9a2d64dfb2..2858c14db9 100644 --- a/dot/state/base.go +++ b/dot/state/base.go @@ -54,21 +54,6 @@ func (s *BaseState) LoadNodeGlobalName() (string, error) { return string(nodeName), nil } -// StoreBestBlockHash stores the hash at the BestBlockHashKey -func (s *BaseState) StoreBestBlockHash(hash common.Hash) error { - return s.db.Put(common.BestBlockHashKey, hash[:]) -} - -// LoadBestBlockHash loads the hash stored at BestBlockHashKey -func (s *BaseState) LoadBestBlockHash() (common.Hash, error) { - hash, err := s.db.Get(common.BestBlockHashKey) - if err != nil { - return common.Hash{}, err - } - - return common.NewHash(hash), nil -} - // StoreGenesisData stores the given genesis data at the known GenesisDataKey. func (s *BaseState) StoreGenesisData(gen *genesis.Data) error { enc, err := json.Marshal(gen) @@ -95,21 +80,6 @@ func (s *BaseState) LoadGenesisData() (*genesis.Data, error) { return data, nil } -// StoreLatestStorageHash stores the current root hash in the database at LatestStorageHashKey -func (s *BaseState) StoreLatestStorageHash(root common.Hash) error { - return s.db.Put(common.LatestStorageHashKey, root[:]) -} - -// LoadLatestStorageHash retrieves the hash stored at LatestStorageHashKey from the DB -func (s *BaseState) LoadLatestStorageHash() (common.Hash, error) { - hashbytes, err := s.db.Get(common.LatestStorageHashKey) - if err != nil { - return common.Hash{}, err - } - - return common.NewHash(hashbytes), nil -} - // StoreCodeSubstitutedBlockHash stores the hash at the CodeSubstitutedBlock key func (s *BaseState) StoreCodeSubstitutedBlockHash(hash common.Hash) error { return s.db.Put(common.CodeSubstitutedBlock, hash[:]) diff --git a/dot/state/base_test.go b/dot/state/base_test.go index 957a81323a..e4f68b62c0 100644 --- a/dot/state/base_test.go +++ b/dot/state/base_test.go @@ -39,42 +39,6 @@ func TestTrie_StoreAndLoadFromDB(t *testing.T) { require.Equal(t, expected, tt.MustHash()) } -type test struct { - key []byte - value []byte -} - -func TestStoreAndLoadLatestStorageHash(t *testing.T) { - db := NewInMemoryDB(t) - base := NewBaseState(db) - tt := trie.NewEmptyTrie() - - tests := []test{ - {key: []byte{0x01, 0x35}, value: []byte("pen")}, - {key: []byte{0x01, 0x35, 0x79}, value: []byte("penguin")}, - {key: []byte{0x01, 0x35, 0x7}, value: []byte("g")}, - {key: []byte{0xf2}, value: []byte("feather")}, - {key: []byte{0xf2, 0x3}, value: []byte("f")}, - {key: []byte{0x09, 0xd3}, value: []byte("noot")}, - {key: []byte{0x07}, value: []byte("ramen")}, - {key: []byte{0}, value: nil}, - } - - for _, test := range tests { - tt.Put(test.key, test.value) - } - - expected, err := tt.Hash() - require.NoError(t, err) - - err = base.StoreLatestStorageHash(expected) - require.NoError(t, err) - - hash, err := base.LoadLatestStorageHash() - require.NoError(t, err) - require.Equal(t, expected, hash) -} - func TestStoreAndLoadGenesisData(t *testing.T) { db := NewInMemoryDB(t) base := NewBaseState(db) @@ -99,20 +63,6 @@ func TestStoreAndLoadGenesisData(t *testing.T) { require.Equal(t, expected, gen) } -func TestStoreAndLoadBestBlockHash(t *testing.T) { - db := NewInMemoryDB(t) - base := NewBaseState(db) - - hash, _ := common.HexToHash("0x3f5a19b9e9507e05276216f3877bb289e47885f8184010c65d0e41580d3663cc") - - err := base.StoreBestBlockHash(hash) - require.NoError(t, err) - - res, err := base.LoadBestBlockHash() - require.NoError(t, err) - require.Equal(t, hash, res) -} - func TestLoadStoreEpochLength(t *testing.T) { db := NewInMemoryDB(t) base := NewBaseState(db) diff --git a/dot/state/block.go b/dot/state/block.go index a69b15a523..cbb6b04f6b 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -36,20 +36,34 @@ import ( "github.com/ChainSafe/gossamer/lib/runtime/wasmer" ) -var blockPrefix = "block" +const ( + pruneKeyBufferSize = 1000 + blockPrefix = "block" +) -const pruneKeyBufferSize = 1000 +var ( + headerPrefix = []byte("hdr") // headerPrefix + hash -> header + blockBodyPrefix = []byte("blb") // blockBodyPrefix + hash -> body + headerHashPrefix = []byte("hsh") // headerHashPrefix + encodedBlockNum -> hash + arrivalTimePrefix = []byte("arr") // arrivalTimePrefix || hash -> arrivalTime + receiptPrefix = []byte("rcp") // receiptPrefix + hash -> receipt + messageQueuePrefix = []byte("mqp") // messageQueuePrefix + hash -> message queue + justificationPrefix = []byte("jcp") // justificationPrefix + hash -> justification -// BlockState defines fields for manipulating the state of blocks, such as BlockTree, -// BlockDB and Header + errNilBlockBody = errors.New("block body is nil") +) + +// BlockState contains the historical block data of the blockchain, including block headers and bodies. +// It wraps the blocktree (which contains unfinalised blocks) and the database (which contains finalised blocks). type BlockState struct { bt *blocktree.BlockTree baseState *BaseState dbPath string db chaindb.Database sync.RWMutex - genesisHash common.Hash - lastFinalised common.Hash + genesisHash common.Hash + lastFinalised common.Hash + unfinalisedBlocks *sync.Map // map[common.Hash]*types.Block // block notifiers imported map[chan *types.Block]struct{} @@ -63,42 +77,42 @@ type BlockState struct { } // NewBlockState will create a new BlockState backed by the database located at basePath -func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, error) { - if bt == nil { - return nil, fmt.Errorf("block tree is nil") - } - +func NewBlockState(db chaindb.Database) (*BlockState, error) { bs := &BlockState{ - bt: bt, dbPath: db.Path(), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), + unfinalisedBlocks: new(sync.Map), imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), } - genesisBlock, err := bs.GetBlockByNumber(big.NewInt(0)) + gh, err := bs.db.Get(headerHashKey(0)) if err != nil { - return nil, fmt.Errorf("failed to get genesis header: %w", err) + return nil, fmt.Errorf("cannot get block 0: %w", err) } + genesisHash := common.NewHash(gh) - bs.genesisHash = genesisBlock.Header.Hash() - bs.lastFinalised, err = bs.GetHighestFinalisedHash() + header, err := bs.GetHighestFinalisedHeader() if err != nil { - return nil, fmt.Errorf("failed to get last finalised hash: %w", err) + return nil, fmt.Errorf("failed to get last finalised header: %w", err) } + bs.genesisHash = genesisHash + bs.lastFinalised = header.Hash() + bs.bt = blocktree.NewBlockTreeFromRoot(header) return bs, nil } // NewBlockStateFromGenesis initialises a BlockState from a genesis header, saving it to the database located at basePath func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*BlockState, error) { bs := &BlockState{ - bt: blocktree.NewBlockTreeFromRoot(header, db), + bt: blocktree.NewBlockTreeFromRoot(header), baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), + unfinalisedBlocks: new(sync.Map), imported: make(map[chan *types.Block]struct{}), finalised: make(map[chan *types.FinalisationInfo]struct{}), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), @@ -123,6 +137,9 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block return nil, err } + bs.genesisHash = header.Hash() + bs.lastFinalised = header.Hash() + if err := bs.db.Put(highestRoundAndSetIDKey, roundAndSetIDToBytes(0, 0)); err != nil { return nil, err } @@ -135,17 +152,6 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block return bs, nil } -var ( - // Data prefixes - headerPrefix = []byte("hdr") // headerPrefix + hash -> header - blockBodyPrefix = []byte("blb") // blockBodyPrefix + hash -> body - headerHashPrefix = []byte("hsh") // headerHashPrefix + encodedBlockNum -> hash - arrivalTimePrefix = []byte("arr") // arrivalTimePrefix || hash -> arrivalTime - receiptPrefix = []byte("rcp") // receiptPrefix + hash -> receipt - messageQueuePrefix = []byte("mqp") // messageQueuePrefix + hash -> message queue - justificationPrefix = []byte("jcp") // justificationPrefix + hash -> justification -) - // encodeBlockNumber encodes a block number as big endian uint64 func encodeBlockNumber(number uint64) []byte { enc := make([]byte, 8) // encoding results in 8 bytes @@ -178,60 +184,59 @@ func (bs *BlockState) GenesisHash() common.Hash { return bs.genesisHash } -// DeleteBlock deletes all instances of the block and its related data in the database -func (bs *BlockState) DeleteBlock(hash common.Hash) error { - if has, _ := bs.HasHeader(hash); has { - err := bs.db.Del(headerKey(hash)) - if err != nil { - return err - } - } +func (bs *BlockState) storeUnfinalisedBlock(block *types.Block) { + bs.unfinalisedBlocks.Store(block.Header.Hash(), block) +} - if has, _ := bs.HasBlockBody(hash); has { - err := bs.db.Del(blockBodyKey(hash)) - if err != nil { - return err - } - } +func (bs *BlockState) hasUnfinalisedBlock(hash common.Hash) bool { + _, has := bs.unfinalisedBlocks.Load(hash) + return has +} - if has, _ := bs.HasArrivalTime(hash); has { - err := bs.db.Del(arrivalTimeKey(hash)) - if err != nil { - return err - } +func (bs *BlockState) getUnfinalisedHeader(hash common.Hash) (*types.Header, bool) { + block, has := bs.getUnfinalisedBlock(hash) + if !has { + return nil, false } - if has, _ := bs.HasReceipt(hash); has { - err := bs.db.Del(prefixKey(hash, receiptPrefix)) - if err != nil { - return err - } - } + return &block.Header, true +} - if has, _ := bs.HasMessageQueue(hash); has { - err := bs.db.Del(prefixKey(hash, messageQueuePrefix)) - if err != nil { - return err - } +func (bs *BlockState) getUnfinalisedBlock(hash common.Hash) (*types.Block, bool) { + block, has := bs.unfinalisedBlocks.Load(hash) + if !has { + return nil, false } - if has, _ := bs.HasJustification(hash); has { - err := bs.db.Del(prefixKey(hash, justificationPrefix)) - if err != nil { - return err - } + // TODO: dot/core tx re-org test seems to abort here due to block body being invalid? + return block.(*types.Block), true +} + +func (bs *BlockState) getAndDeleteUnfinalisedBlock(hash common.Hash) (*types.Block, bool) { + block, has := bs.unfinalisedBlocks.LoadAndDelete(hash) + if !has { + return nil, false } - return nil + return block.(*types.Block), true } // HasHeader returns if the db contains a header with the given hash func (bs *BlockState) HasHeader(hash common.Hash) (bool, error) { + if bs.hasUnfinalisedBlock(hash) { + return true, nil + } + return bs.db.Has(headerKey(hash)) } // GetHeader returns a BlockHeader for a given hash func (bs *BlockState) GetHeader(hash common.Hash) (*types.Header, error) { + header, has := bs.getUnfinalisedHeader(hash) + if has { + return header, nil + } + result := types.NewEmptyHeader() if bs.db == nil { @@ -260,8 +265,16 @@ func (bs *BlockState) GetHeader(hash common.Hash) (*types.Header, error) { return result, nil } -// GetHashByNumber returns the block hash given the number +// GetHashByNumber returns the block hash on our best chain with the given number func (bs *BlockState) GetHashByNumber(num *big.Int) (common.Hash, error) { + hash, err := bs.bt.GetHashByNumber(num) + if err == nil { + return hash, nil + } else if !errors.Is(err, blocktree.ErrNumLowerThanRoot) { + return common.Hash{}, fmt.Errorf("failed to get hash from blocktree: %w", err) + } + + // if error is ErrNumLowerThanRoot, number has already been finalised, so check db bh, err := bs.db.Get(headerHashKey(num.Uint64())) if err != nil { return common.Hash{}, fmt.Errorf("cannot get block %d: %w", num, err) @@ -270,83 +283,91 @@ func (bs *BlockState) GetHashByNumber(num *big.Int) (common.Hash, error) { return common.NewHash(bh), nil } -// GetHeaderByNumber returns a block header given a number +// GetHeaderByNumber returns the block header on our best chain with the given number func (bs *BlockState) GetHeaderByNumber(num *big.Int) (*types.Header, error) { - bh, err := bs.db.Get(headerHashKey(num.Uint64())) + hash, err := bs.GetHashByNumber(num) if err != nil { - return nil, fmt.Errorf("cannot get block %d: %w", num, err) + return nil, err } - hash := common.NewHash(bh) return bs.GetHeader(hash) } -// GetBlockByHash returns a block for a given hash -func (bs *BlockState) GetBlockByHash(hash common.Hash) (*types.Block, error) { - header, err := bs.GetHeader(hash) +// GetBlockByNumber returns the block on our best chain with the given number +func (bs *BlockState) GetBlockByNumber(num *big.Int) (*types.Block, error) { + hash, err := bs.GetHashByNumber(num) if err != nil { return nil, err } - blockBody, err := bs.GetBlockBody(hash) + block, err := bs.GetBlockByHash(hash) if err != nil { return nil, err } - return &types.Block{Header: *header, Body: *blockBody}, nil + + return block, nil } -// GetBlockByNumber returns a block for a given blockNumber -func (bs *BlockState) GetBlockByNumber(num *big.Int) (*types.Block, error) { - // First retrieve the block hash in a byte array based on the block number from the database - byteHash, err := bs.db.Get(headerHashKey(num.Uint64())) - if err != nil { - return nil, fmt.Errorf("cannot get block %d: %w", num, err) +// GetBlockByHash returns a block for a given hash +func (bs *BlockState) GetBlockByHash(hash common.Hash) (*types.Block, error) { + bs.RLock() + defer bs.RUnlock() + + block, has := bs.getUnfinalisedBlock(hash) + if has { + return block, nil } - // Then find the block based on the hash - hash := common.NewHash(byteHash) - block, err := bs.GetBlockByHash(hash) + header, err := bs.GetHeader(hash) if err != nil { return nil, err } - return block, nil -} - -// GetBlockHash returns block hash for a given blockNumber -func (bs *BlockState) GetBlockHash(blockNumber *big.Int) (common.Hash, error) { - byteHash, err := bs.db.Get(headerHashKey(blockNumber.Uint64())) + blockBody, err := bs.GetBlockBody(hash) if err != nil { - return common.Hash{}, fmt.Errorf("cannot get block %d: %w", blockNumber, err) + return nil, err } + return &types.Block{Header: *header, Body: *blockBody}, nil +} - return common.NewHash(byteHash), nil +// GetBlockHash returns block hash for a given block number +// TODO: remove in favour of GetHashByNumber +func (bs *BlockState) GetBlockHash(num *big.Int) (common.Hash, error) { + return bs.GetHashByNumber(num) } // SetHeader will set the header into DB func (bs *BlockState) SetHeader(header *types.Header) error { - hash := header.Hash() - // Write the encoded header bh, err := scale.Marshal(*header) if err != nil { return err } - err = bs.db.Put(headerKey(hash), bh) - if err != nil { - return err - } - - return nil + return bs.db.Put(headerKey(header.Hash()), bh) } // HasBlockBody returns true if the db contains the block body func (bs *BlockState) HasBlockBody(hash common.Hash) (bool, error) { + bs.RLock() + defer bs.RUnlock() + + if bs.hasUnfinalisedBlock(hash) { + return true, nil + } + return bs.db.Has(blockBodyKey(hash)) } // GetBlockBody will return Body for a given hash func (bs *BlockState) GetBlockBody(hash common.Hash) (*types.Body, error) { + bs.RLock() + defer bs.RUnlock() + + block, has := bs.getUnfinalisedBlock(hash) + if has { + return &block.Body, nil + } + data, err := bs.db.Get(blockBodyKey(hash)) if err != nil { return nil, err @@ -395,93 +416,22 @@ func (bs *BlockState) AddBlock(block *types.Block) error { // AddBlockWithArrivalTime adds a block to the blocktree and the DB with the given arrival time func (bs *BlockState) AddBlockWithArrivalTime(block *types.Block, arrivalTime time.Time) error { - // add block to blocktree - if err := bs.bt.AddBlock(&block.Header, uint64(arrivalTime.UnixNano())); err != nil { - return err - } - - if err := bs.setArrivalTime(block.Header.Hash(), arrivalTime); err != nil { - return err - } - - prevHead := bs.bt.DeepestBlockHash() - - // add the header to the DB - err := bs.SetHeader(&block.Header) - if err != nil { - return err - } - hash := block.Header.Hash() - - // set best block key if this is the highest block we've seen - if hash == bs.BestBlockHash() { - err = bs.setBestBlockHashKey(hash) - if err != nil { - return err - } - } - - // only set number->hash mapping for our current chain - var onChain bool - if onChain, err = bs.isBlockOnCurrentChain(&block.Header); onChain && err == nil { - err = bs.db.Put(headerHashKey(block.Header.Number.Uint64()), hash.ToBytes()) - if err != nil { - return err - } - } - - err = bs.SetBlockBody(block.Header.Hash(), &block.Body) - if err != nil { - return err + if block.Body == nil { + return errNilBlockBody } - // check if there was a re-org, if so, re-set the canonical number->hash mapping - err = bs.handleAddedBlock(prevHead, bs.bt.DeepestBlockHash()) - if err != nil { + // add block to blocktree + if err := bs.bt.AddBlock(&block.Header, arrivalTime); err != nil { return err } + bs.storeUnfinalisedBlock(block) go bs.notifyImported(block) - return bs.db.Flush() -} - -// handleAddedBlock re-sets the canonical number->hash mapping if there was a chain re-org. -// prev is the previous best block hash before the new block was added to the blocktree. -// curr is the current best block hash. -func (bs *BlockState) handleAddedBlock(prev, curr common.Hash) error { - ancestor, err := bs.HighestCommonAncestor(prev, curr) - if err != nil { - return err - } - - // if the highest common ancestor of the previous chain head and current chain head is the previous chain head, - // then the current chain head is the descendant of the previous and thus are on the same chain - if ancestor == prev { - return nil - } - - subchain, err := bs.SubChain(ancestor, curr) - if err != nil { - return err - } - - batch := bs.db.NewBatch() - for _, hash := range subchain { - header, err := bs.GetHeader(hash) - if err != nil { - return fmt.Errorf("failed to get header in subchain: %w", err) - } - - err = batch.Put(headerHashKey(header.Number.Uint64()), hash.ToBytes()) - if err != nil { - return err - } - } - - return batch.Flush() + return nil } // AddBlockToBlockTree adds the given block to the blocktree. It does not write it to the database. +// TODO: remove this func and usage from sync (after sync refactor?) func (bs *BlockState) AddBlockToBlockTree(header *types.Header) error { bs.Lock() defer bs.Unlock() @@ -491,12 +441,12 @@ func (bs *BlockState) AddBlockToBlockTree(header *types.Header) error { arrivalTime = time.Now() } - return bs.bt.AddBlock(header, uint64(arrivalTime.UnixNano())) + return bs.bt.AddBlock(header, arrivalTime) } // GetAllBlocksAtDepth returns all hashes with the depth of the given hash plus one func (bs *BlockState) GetAllBlocksAtDepth(hash common.Hash) []common.Hash { - return bs.bt.GetAllBlocksAtDepth(hash) + return bs.bt.GetAllBlocksAtNumber(hash) } func (bs *BlockState) isBlockOnCurrentChain(header *types.Header) (bool, error) { @@ -619,17 +569,13 @@ func (bs *BlockState) BlocktreeAsString() string { return bs.bt.String() } -func (bs *BlockState) setBestBlockHashKey(hash common.Hash) error { - return bs.baseState.StoreBestBlockHash(hash) -} - -// HasArrivalTime returns true if the db contains the block's arrival time -func (bs *BlockState) HasArrivalTime(hash common.Hash) (bool, error) { - return bs.db.Has(arrivalTimeKey(hash)) -} - // GetArrivalTime returns the arrival time in nanoseconds since the Unix epoch of a block given its hash func (bs *BlockState) GetArrivalTime(hash common.Hash) (time.Time, error) { + at, err := bs.bt.GetArrivalTime(hash) + if err == nil { + return at, nil + } + arrivalTime, err := bs.db.Get(arrivalTimeKey(hash)) if err != nil { return time.Time{}, err diff --git a/dot/state/block_finalisation.go b/dot/state/block_finalisation.go index 14b05cfacf..ece7ddbd6e 100644 --- a/dot/state/block_finalisation.go +++ b/dot/state/block_finalisation.go @@ -39,7 +39,7 @@ func (bs *BlockState) HasFinalisedBlock(round, setID uint64) (bool, error) { // NumberIsFinalised checks if a block number is finalised or not func (bs *BlockState) NumberIsFinalised(num *big.Int) (bool, error) { - header, err := bs.GetFinalisedHeader(0, 0) + header, err := bs.GetHighestFinalisedHeader() if err != nil { return false, err } @@ -93,7 +93,7 @@ func (bs *BlockState) setHighestRoundAndSetID(round, setID uint64) error { func (bs *BlockState) GetHighestRoundAndSetID() (uint64, uint64, error) { b, err := bs.db.Get(highestRoundAndSetIDKey) if err != nil { - return 0, 0, err + return 0, 0, fmt.Errorf("failed to get highest round and setID: %w", err) } round := binary.LittleEndian.Uint64(b[:8]) @@ -136,13 +136,16 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er return fmt.Errorf("cannot finalise unknown block %s", hash) } - // if nothing was previously finalised, set the first slot of the network to the - // slot number of block 1, which is now being set as final - if bs.lastFinalised.Equal(bs.genesisHash) && !hash.Equal(bs.genesisHash) { - err := bs.setFirstSlotOnFinalisation() - if err != nil { - return err - } + if err := bs.handleFinalisedBlock(hash); err != nil { + return fmt.Errorf("failed to set finalised subchain in db on finalisation: %w", err) + } + + if err := bs.db.Put(finalisedHashKey(round, setID), hash[:]); err != nil { + return fmt.Errorf("failed to set finalised hash key: %w", err) + } + + if err := bs.setHighestRoundAndSetID(round, setID); err != nil { + return fmt.Errorf("failed to set highest round and set ID: %w", err) } if err := bs.handleFinalisedBlock(hash); err != nil { @@ -155,31 +158,28 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er pruned := bs.bt.Prune(hash) for _, hash := range pruned { - header, err := bs.GetHeader(hash) - if err != nil { - logger.Debug("failed to get pruned header", "hash", hash, "error", err) + block, has := bs.getAndDeleteUnfinalisedBlock(hash) + if !has { continue } - err = bs.DeleteBlock(hash) - if err != nil { - logger.Debug("failed to delete block", "hash", hash, "error", err) - continue - } + logger.Trace("pruned block", "hash", hash, "number", block.Header.Number) - logger.Trace("pruned block", "hash", hash, "number", header.Number) go func(header *types.Header) { bs.pruneKeyCh <- header - }(header) + }(&block.Header) } - bs.lastFinalised = hash - - if err := bs.db.Put(finalisedHashKey(round, setID), hash[:]); err != nil { - return err + // if nothing was previously finalised, set the first slot of the network to the + // slot number of block 1, which is now being set as final + if bs.lastFinalised.Equal(bs.genesisHash) && !hash.Equal(bs.genesisHash) { + if err := bs.setFirstSlotOnFinalisation(); err != nil { + return fmt.Errorf("failed to set first slot on finalisation: %w", err) + } } - return bs.setHighestRoundAndSetID(round, setID) + bs.lastFinalised = hash + return nil } func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error { @@ -189,7 +189,7 @@ func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error { prev, err := bs.GetHighestFinalisedHash() if err != nil { - return err + return fmt.Errorf("failed to get highest finalised hash: %w", err) } if prev.Equal(curr) { @@ -202,16 +202,38 @@ func (bs *BlockState) handleFinalisedBlock(curr common.Hash) error { } batch := bs.db.NewBatch() - for _, hash := range subchain { - header, err := bs.GetHeader(hash) - if err != nil { - return fmt.Errorf("failed to get header in subchain: %w", err) + + // root of subchain is previously finalised block, which has already been stored in the db + for _, hash := range subchain[1:] { + if hash.Equal(bs.genesisHash) { + continue + } + + block, has := bs.getAndDeleteUnfinalisedBlock(hash) + if !has { + return fmt.Errorf("failed to find block in unfinalised block map, block=%s", hash) } - err = batch.Put(headerHashKey(header.Number.Uint64()), hash.ToBytes()) + if err = bs.SetHeader(&block.Header); err != nil { + return err + } + + if err = bs.SetBlockBody(hash, &block.Body); err != nil { + return err + } + + arrivalTime, err := bs.bt.GetArrivalTime(hash) if err != nil { return err } + + if err = bs.setArrivalTime(hash, arrivalTime); err != nil { + return err + } + + if err = batch.Put(headerHashKey(block.Header.Number.Uint64()), hash.ToBytes()); err != nil { + return err + } } return batch.Flush() diff --git a/dot/state/block_finalisation_test.go b/dot/state/block_finalisation_test.go index 0acc0960a4..0e047228b4 100644 --- a/dot/state/block_finalisation_test.go +++ b/dot/state/block_finalisation_test.go @@ -75,15 +75,46 @@ func TestHighestRoundAndSetID(t *testing.T) { func TestBlockState_SetFinalisedHash(t *testing.T) { bs := newTestBlockState(t, testGenesisHeader) + h, err := bs.GetFinalisedHash(0, 0) + require.NoError(t, err) + require.Equal(t, testGenesisHeader.Hash(), h) digest := types.NewDigest() - err := digest.Add(*types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest()) + err = digest.Add(*types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest()) require.NoError(t, err) - digest2 := types.NewDigest() - err = digest2.Add(*types.NewBabeSecondaryPlainPreDigest(0, 2).ToPreRuntimeDigest()) + header := &types.Header{ + ParentHash: testGenesisHeader.Hash(), + Number: big.NewInt(1), + Digest: digest, + } + + testhash := header.Hash() + err = bs.db.Put(headerKey(testhash), []byte{}) + require.NoError(t, err) + + err = bs.AddBlock(&types.Block{ + Header: *header, + Body: types.Body{}, + }) + require.NoError(t, err) + + err = bs.SetFinalisedHash(testhash, 1, 1) + require.NoError(t, err) + + h, err = bs.GetFinalisedHash(1, 1) + require.NoError(t, err) + require.Equal(t, testhash, h) +} + +func TestSetFinalisedHash_setFirstSlotOnFinalisation(t *testing.T) { + bs := newTestBlockState(t, testGenesisHeader) + firstSlot := uint64(42069) + + digest := types.NewDigest() + err := digest.Add(*types.NewBabeSecondaryPlainPreDigest(0, firstSlot).ToPreRuntimeDigest()) require.NoError(t, err) - digest3 := types.NewDigest() - err = digest3.Add(*types.NewBabeSecondaryPlainPreDigest(0, 200).ToPreRuntimeDigest()) + digest2 := types.NewDigest() + err = digest2.Add(*types.NewBabeSecondaryPlainPreDigest(0, firstSlot+100).ToPreRuntimeDigest()) require.NoError(t, err) header1 := types.Header{ @@ -98,12 +129,6 @@ func TestBlockState_SetFinalisedHash(t *testing.T) { ParentHash: header1.Hash(), } - header2Again := types.Header{ - Number: big.NewInt(2), - Digest: digest3, - ParentHash: header1.Hash(), - } - err = bs.AddBlock(&types.Block{ Header: header1, Body: types.Body{}, @@ -116,21 +141,11 @@ func TestBlockState_SetFinalisedHash(t *testing.T) { }) require.NoError(t, err) - err = bs.AddBlock(&types.Block{ - Header: header2Again, - Body: types.Body{}, - }) - require.NoError(t, err) - - err = bs.SetFinalisedHash(header2Again.Hash(), 0, 0) - require.NoError(t, err) - require.Equal(t, header2Again.Hash(), bs.lastFinalised) - - h1, err := bs.GetHeaderByNumber(big.NewInt(1)) + err = bs.SetFinalisedHash(header2.Hash(), 1, 1) require.NoError(t, err) - require.Equal(t, &header1, h1) + require.Equal(t, header2.Hash(), bs.lastFinalised) - h2, err := bs.GetHeaderByNumber(big.NewInt(2)) + res, err := bs.baseState.loadFirstSlot() require.NoError(t, err) - require.Equal(t, &header2Again, h2) + require.Equal(t, firstSlot, res) } diff --git a/dot/state/block_race_test.go b/dot/state/block_race_test.go index 567ea67743..4bf323d03e 100644 --- a/dot/state/block_race_test.go +++ b/dot/state/block_race_test.go @@ -42,22 +42,20 @@ func TestConcurrencySetHeader(t *testing.T) { go func(index int) { defer pend.Done() - bs := &BlockState{ - db: dbs[index], - } + bs, err := NewBlockStateFromGenesis(dbs[index], testGenesisHeader) + require.NoError(t, err) header := &types.Header{ - Number: big.NewInt(0), + Number: big.NewInt(1), StateRoot: trie.EmptyHash, Digest: types.NewDigest(), } - err := bs.SetHeader(header) - require.Nil(t, err) + err = bs.SetHeader(header) + require.NoError(t, err) res, err := bs.GetHeader(header.Hash()) - require.Nil(t, err) - + require.NoError(t, err) require.Equal(t, header, res) }(i) diff --git a/dot/state/block_test.go b/dot/state/block_test.go index 871ced5ae7..0dd652dc81 100644 --- a/dot/state/block_test.go +++ b/dot/state/block_test.go @@ -25,7 +25,6 @@ import ( "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" - "github.com/ChainSafe/chaindb" "github.com/stretchr/testify/require" ) @@ -40,9 +39,7 @@ var testGenesisHeader = &types.Header{ func newTestBlockState(t *testing.T, header *types.Header) *BlockState { db := NewInMemoryDB(t) if header == nil { - return &BlockState{ - db: chaindb.NewTable(db, blockPrefix), - } + header = testGenesisHeader } bs, err := NewBlockStateFromGenesis(db, header) @@ -98,7 +95,6 @@ func TestGetBlockByNumber(t *testing.T) { Body: sampleBlockBody, } - // AddBlock also sets mapping [blockNumber : hash] in DB err := bs.AddBlock(block) require.NoError(t, err) @@ -112,7 +108,7 @@ func TestAddBlock(t *testing.T) { // Create header header0 := &types.Header{ - Number: big.NewInt(0), + Number: big.NewInt(1), Digest: types.NewDigest(), ParentHash: testGenesisHeader.Hash(), } @@ -127,9 +123,9 @@ func TestAddBlock(t *testing.T) { err := bs.AddBlock(block0) require.NoError(t, err) - // Create header & blockData for block 1 + // Create header & blockData for block 2 header1 := &types.Header{ - Number: big.NewInt(1), + Number: big.NewInt(2), Digest: types.NewDigest(), ParentHash: blockHash0, } @@ -267,39 +263,6 @@ func TestAddBlock_BlockNumberToHash(t *testing.T) { } } -func TestFinalizedHash(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) - h, err := bs.GetFinalisedHash(0, 0) - require.NoError(t, err) - require.Equal(t, testGenesisHeader.Hash(), h) - - digest := types.NewDigest() - err = digest.Add(*types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest()) - require.NoError(t, err) - header := &types.Header{ - ParentHash: testGenesisHeader.Hash(), - Number: big.NewInt(1), - Digest: digest, - } - - testhash := header.Hash() - err = bs.db.Put(headerKey(testhash), []byte{}) - require.NoError(t, err) - - err = bs.AddBlock(&types.Block{ - Header: *header, - Body: types.Body{}, - }) - require.NoError(t, err) - - err = bs.SetFinalisedHash(testhash, 1, 1) - require.NoError(t, err) - - h, err = bs.GetFinalisedHash(1, 1) - require.NoError(t, err) - require.Equal(t, testhash, h) -} - func TestFinalization_DeleteBlock(t *testing.T) { bs := newTestBlockState(t, testGenesisHeader) AddBlocksToState(t, bs, 5, false) @@ -308,12 +271,6 @@ func TestFinalization_DeleteBlock(t *testing.T) { before := bs.bt.GetAllBlocks() leaves := bs.Leaves() - for _, n := range before { - has, err := bs.HasArrivalTime(n) - require.NoError(t, err) - require.True(t, has, n) - } - // pick block to finalise fin := leaves[len(leaves)-1] err := bs.SetFinalisedHash(fin, 1, 1) @@ -358,14 +315,6 @@ func TestFinalization_DeleteBlock(t *testing.T) { } else { require.False(t, has) } - - // has, err = bs.HasArrivalTime(b) - // require.NoError(t, err) - // if isFinalised && b != bs.genesisHash { - // require.True(t, has, b) - // } else { - // require.False(t, has) - // } } } @@ -557,15 +506,13 @@ func TestNumberIsFinalised(t *testing.T) { Body: types.Body{}, }) require.NoError(t, err) - err = bs.db.Put(headerHashKey(header1.Number.Uint64()), header1.Hash().ToBytes()) - require.NoError(t, err) err = bs.AddBlock(&types.Block{ Header: header2, Body: types.Body{}, }) require.NoError(t, err) - err = bs.SetFinalisedHash(header2.Hash(), 0, 0) + err = bs.SetFinalisedHash(header2.Hash(), 1, 1) require.NoError(t, err) fin, err = bs.NumberIsFinalised(big.NewInt(0)) @@ -580,50 +527,7 @@ func TestNumberIsFinalised(t *testing.T) { require.NoError(t, err) require.True(t, fin) - fin, err = bs.NumberIsFinalised(big.NewInt(3)) + fin, err = bs.NumberIsFinalised(big.NewInt(100)) require.NoError(t, err) require.False(t, fin) } - -func TestSetFinalisedHash_setFirstSlotOnFinalisation(t *testing.T) { - bs := newTestBlockState(t, testGenesisHeader) - firstSlot := uint64(42069) - - digest := types.NewDigest() - err := digest.Add(*types.NewBabeSecondaryPlainPreDigest(0, firstSlot).ToPreRuntimeDigest()) - require.NoError(t, err) - digest2 := types.NewDigest() - err = digest2.Add(*types.NewBabeSecondaryPlainPreDigest(0, firstSlot+100).ToPreRuntimeDigest()) - require.NoError(t, err) - - header1 := types.Header{ - Number: big.NewInt(1), - Digest: digest, - ParentHash: testGenesisHeader.Hash(), - } - - header2 := types.Header{ - Number: big.NewInt(2), - Digest: digest2, - ParentHash: header1.Hash(), - } - - err = bs.AddBlock(&types.Block{ - Header: header1, - Body: types.Body{}, - }) - require.NoError(t, err) - - err = bs.AddBlock(&types.Block{ - Header: header2, - Body: types.Body{}, - }) - require.NoError(t, err) - err = bs.SetFinalisedHash(header2.Hash(), 0, 0) - require.NoError(t, err) - require.Equal(t, header2.Hash(), bs.lastFinalised) - - res, err := bs.baseState.loadFirstSlot() - require.NoError(t, err) - require.Equal(t, firstSlot, res) -} diff --git a/dot/state/epoch.go b/dot/state/epoch.go index edbe6c3530..56998a9b4f 100644 --- a/dot/state/epoch.go +++ b/dot/state/epoch.go @@ -63,7 +63,7 @@ type EpochState struct { } // NewEpochStateFromGenesis returns a new EpochState given information for the first epoch, fetched from the runtime -func NewEpochStateFromGenesis(db chaindb.Database, genesisConfig *types.BabeConfiguration) (*EpochState, error) { +func NewEpochStateFromGenesis(db chaindb.Database, blockState *BlockState, genesisConfig *types.BabeConfiguration) (*EpochState, error) { baseState := NewBaseState(db) err := baseState.storeFirstSlot(1) // this may change once the first block is imported @@ -83,6 +83,7 @@ func NewEpochStateFromGenesis(db chaindb.Database, genesisConfig *types.BabeConf s := &EpochState{ baseState: NewBaseState(db), + blockState: blockState, db: epochDB, epochLength: genesisConfig.EpochLength, } @@ -121,9 +122,6 @@ func NewEpochStateFromGenesis(db chaindb.Database, genesisConfig *types.BabeConf return nil, err } - s.blockState = &BlockState{ - db: chaindb.NewTable(db, blockPrefix), - } return s, nil } @@ -350,7 +348,7 @@ func (s *EpochState) GetEpochFromTime(t time.Time) (uint64, error) { // SetFirstSlot sets the first slot number of the network func (s *EpochState) SetFirstSlot(slot uint64) error { // check if block 1 was finalised already; if it has, don't set first slot again - header, err := s.blockState.GetFinalisedHeader(0, 0) + header, err := s.blockState.GetHighestFinalisedHeader() if err != nil { return err } diff --git a/dot/state/epoch_test.go b/dot/state/epoch_test.go index a83bc43ca2..5fbfadc66d 100644 --- a/dot/state/epoch_test.go +++ b/dot/state/epoch_test.go @@ -39,7 +39,7 @@ var genesisBABEConfig = &types.BabeConfiguration{ func newEpochStateFromGenesis(t *testing.T) *EpochState { db := NewInMemoryDB(t) - s, err := NewEpochStateFromGenesis(db, genesisBABEConfig) + s, err := NewEpochStateFromGenesis(db, newTestBlockState(t, nil), genesisBABEConfig) require.NoError(t, err) return s } diff --git a/dot/state/initialize.go b/dot/state/initialize.go index 4440e148b0..0ba46b6be2 100644 --- a/dot/state/initialize.go +++ b/dot/state/initialize.go @@ -24,7 +24,6 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/dot/state/pruner" "github.com/ChainSafe/gossamer/dot/types" - "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/genesis" "github.com/ChainSafe/gossamer/lib/runtime" rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" @@ -71,17 +70,10 @@ func (s *Service) Initialise(gen *genesis.Genesis, header *types.Header, t *trie } // write initial genesis values to database - if err = s.storeInitialValues(gen.GenesisData(), header, t); err != nil { + if err = s.storeInitialValues(gen.GenesisData(), t); err != nil { return fmt.Errorf("failed to write genesis values to database: %s", err) } - // create and store blocktree from genesis block - bt := blocktree.NewBlockTreeFromRoot(header, db) - err = bt.Store() - if err != nil { - return fmt.Errorf("failed to write blocktree to database: %s", err) - } - // create block state from genesis block blockState, err := NewBlockStateFromGenesis(db, header) if err != nil { @@ -94,7 +86,7 @@ func (s *Service) Initialise(gen *genesis.Genesis, header *types.Header, t *trie return fmt.Errorf("failed to create storage state from trie: %s", err) } - epochState, err := NewEpochStateFromGenesis(db, babeCfg) + epochState, err := NewEpochStateFromGenesis(db, blockState, babeCfg) if err != nil { return fmt.Errorf("failed to create epoch state: %s", err) } @@ -153,22 +145,12 @@ func loadGrandpaAuthorities(t *trie.Trie) ([]types.GrandpaVoter, error) { } // storeInitialValues writes initial genesis values to the state database -func (s *Service) storeInitialValues(data *genesis.Data, header *types.Header, t *trie.Trie) error { +func (s *Service) storeInitialValues(data *genesis.Data, t *trie.Trie) error { // write genesis trie to database if err := t.Store(chaindb.NewTable(s.db, storagePrefix)); err != nil { return fmt.Errorf("failed to write trie to database: %s", err) } - // write storage hash to database - if err := s.Base.StoreLatestStorageHash(t.MustHash()); err != nil { - return fmt.Errorf("failed to write storage hash to database: %s", err) - } - - // write best block hash to state database - if err := s.Base.StoreBestBlockHash(header.Hash()); err != nil { - return fmt.Errorf("failed to write best block hash to database: %s", err) - } - // write genesis data to state database if err := s.Base.StoreGenesisData(data); err != nil { return fmt.Errorf("failed to write genesis data to database: %s", err) diff --git a/dot/state/offline_pruner.go b/dot/state/offline_pruner.go index 67b3f757b9..3f6c6abcf4 100644 --- a/dot/state/offline_pruner.go +++ b/dot/state/offline_pruner.go @@ -7,7 +7,6 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/dot/state/pruner" - "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" "github.com/ChainSafe/gossamer/lib/utils" @@ -38,22 +37,15 @@ func NewOfflinePruner(inputDBPath, prunedDBPath string, bloomSize uint64, retain return nil, fmt.Errorf("failed to load DB %w", err) } - base := NewBaseState(db) - bestHash, err := base.LoadBestBlockHash() + // create blockState state + blockState, err := NewBlockState(db) if err != nil { - return nil, fmt.Errorf("failed to get best block hash: %w", err) - } - - // load blocktree - bt := blocktree.NewEmptyBlockTree(db) - if err = bt.Load(); err != nil { - return nil, fmt.Errorf("failed to load blocktree: %w", err) + return nil, fmt.Errorf("failed to create block state: %w", err) } - // create blockState state - blockState, err := NewBlockState(db, bt) + bestHash, err := blockState.GetHighestFinalisedHash() if err != nil { - return nil, fmt.Errorf("failed to create block state: %w", err) + return nil, fmt.Errorf("failed to get best finalised hash: %w", err) } // create bloom filter diff --git a/dot/state/service.go b/dot/state/service.go index 5a42565534..1d937d6ea0 100644 --- a/dot/state/service.go +++ b/dot/state/service.go @@ -17,7 +17,6 @@ package state import ( - "bytes" "fmt" "math/big" "os" @@ -121,40 +120,23 @@ func (s *Service) Start() error { s.Base = NewBaseState(db) } - // retrieve latest header - bestHash, err := s.Base.LoadBestBlockHash() - if err != nil { - return fmt.Errorf("failed to get best block hash: %w", err) - } - - logger.Trace("start", "best block hash", bestHash) - - // load blocktree - bt := blocktree.NewEmptyBlockTree(db) - if err = bt.Load(); err != nil { - return fmt.Errorf("failed to load blocktree: %w", err) - } + var err error // create block state - s.Block, err = NewBlockState(db, bt) + s.Block, err = NewBlockState(db) if err != nil { return fmt.Errorf("failed to create block state: %w", err) } - // if blocktree head isn't "best hash", then the node shutdown abnormally. - // restore state from last finalised hash. - btHead := bt.DeepestBlockHash() - if !bytes.Equal(btHead[:], bestHash[:]) { - logger.Info("detected abnormal node shutdown, restoring from last finalised block") - - lastFinalised, err := s.Block.GetHighestFinalisedHeader() //nolint - if err != nil { - return fmt.Errorf("failed to get latest finalised block: %w", err) - } - - s.Block.bt = blocktree.NewBlockTreeFromRoot(lastFinalised, db) + // retrieve latest header + bestHeader, err := s.Block.GetHighestFinalisedHeader() + if err != nil { + return fmt.Errorf("failed to get best block hash: %w", err) } + stateRoot := bestHeader.StateRoot + logger.Debug("start", "latest state root", stateRoot) + pr, err := s.Base.loadPruningData() if err != nil { return err @@ -166,14 +148,7 @@ func (s *Service) Start() error { return fmt.Errorf("failed to create storage state: %w", err) } - stateRoot, err := s.Base.LoadLatestStorageHash() - if err != nil { - return fmt.Errorf("cannot load latest storage root: %w", err) - } - - logger.Debug("start", "latest state root", stateRoot) - - // load current storage state + // load current storage state trie into memory _, err = s.Storage.LoadFromDB(stateRoot) if err != nil { return fmt.Errorf("failed to load storage trie from database: %w", err) @@ -194,7 +169,11 @@ func (s *Service) Start() error { } num, _ := s.Block.BestBlockNumber() - logger.Info("created state service", "head", s.Block.BestBlockHash(), "highest number", num) + logger.Info("created state service", + "head", s.Block.BestBlockHash(), + "highest number", num, + "genesis hash", s.Block.genesisHash, + ) // Start background goroutine to GC pruned keys. go s.Storage.pruneStorage(s.closeCh) @@ -216,11 +195,15 @@ func (s *Service) Rewind(toBlock int64) error { return err } - s.Block.bt = blocktree.NewBlockTreeFromRoot(&root.Header, s.db) - newHead := s.Block.BestBlockHash() + s.Block.bt = blocktree.NewBlockTreeFromRoot(&root.Header) - header, _ := s.Block.BestBlockHeader() - logger.Info("rewinding state...", "new height", header.Number, "best block hash", newHead) + header, err := s.Block.BestBlockHeader() + if err != nil { + return err + } + + s.Block.lastFinalised = header.Hash() + logger.Info("rewinding state...", "new height", header.Number, "best block hash", header.Hash()) epoch, err := s.Epoch.GetEpochForBlock(header) if err != nil { @@ -266,49 +249,20 @@ func (s *Service) Rewind(toBlock int64) error { } } - return s.Base.StoreBestBlockHash(newHead) + //return s.Base.StoreBestBlockHash(newHead) + return nil } // Stop closes each state database func (s *Service) Stop() error { - head, err := s.Block.BestBlockStateRoot() - if err != nil { - return err - } - - st, has := s.Storage.tries.Load(head) - if !has { - return errTrieDoesNotExist(head) - } - - t := st.(*trie.Trie) - - if err = s.Base.StoreLatestStorageHash(head); err != nil { - return err - } - - logger.Debug("storing latest storage trie", "root", head) - - if err = t.Store(s.Storage.db); err != nil { - return err - } - - if err = s.Block.bt.Store(); err != nil { - return err - } - - hash := s.Block.BestBlockHash() - if err = s.Base.StoreBestBlockHash(hash); err != nil { - return err - } + close(s.closeCh) - thash, err := t.Hash() + hash, err := s.Block.GetHighestFinalisedHash() if err != nil { return err } - close(s.closeCh) - logger.Debug("stop", "best block hash", hash, "latest state root", thash) + logger.Debug("stop", "best finalised hash", hash) if err = s.db.Flush(); err != nil { return err @@ -367,30 +321,26 @@ func (s *Service) Import(header *types.Header, t *trie.Trie, firstSlot uint64) e return fmt.Errorf("trie state root does not equal header state root") } - if err := s.Base.StoreLatestStorageHash(root); err != nil { - return err - } - logger.Info("importing storage trie...", "basepath", s.dbPath, "root", root) if err := t.Store(storage.db); err != nil { return err } - bt := blocktree.NewBlockTreeFromRoot(header, s.db) - if err := bt.Store(); err != nil { + hash := header.Hash() + if err := block.SetHeader(header); err != nil { return err } - if err := s.Base.StoreBestBlockHash(header.Hash()); err != nil { + // TODO: this is broken, need to know round and setID for the header as well + if err := block.db.Put(finalisedHashKey(0, 0), hash[:]); err != nil { return err } - - if err := block.SetHeader(header); err != nil { + if err := block.setHighestRoundAndSetID(0, 0); err != nil { return err } - logger.Debug("Import", "best block hash", header.Hash(), "latest state root", root) + logger.Debug("Import", "best block hash", hash, "latest state root", root) if err := s.db.Flush(); err != nil { return err } diff --git a/dot/state/service_test.go b/dot/state/service_test.go index 01f8aa45bb..dcd3a46d98 100644 --- a/dot/state/service_test.go +++ b/dot/state/service_test.go @@ -133,6 +133,10 @@ func TestService_BlockTree(t *testing.T) { // add blocks to state AddBlocksToState(t, stateA.Block, 10, false) + head := stateA.Block.BestBlockHash() + + err = stateA.Block.SetFinalisedHash(head, 1, 1) + require.NoError(t, err) err = stateA.Stop() require.NoError(t, err) @@ -308,6 +312,10 @@ func TestService_Rewind(t *testing.T) { require.NoError(t, err) AddBlocksToState(t, serv.Block, 12, false) + head := serv.Block.BestBlockHash() + err = serv.Block.SetFinalisedHash(head, 0, 0) + require.NoError(t, err) + err = serv.Rewind(6) require.NoError(t, err) diff --git a/dot/state/test_helpers.go b/dot/state/test_helpers.go index 391a517560..db60aeeddf 100644 --- a/dot/state/test_helpers.go +++ b/dot/state/test_helpers.go @@ -207,7 +207,7 @@ func AddBlocksToStateWithFixedBranches(t *testing.T, blockState *BlockState, dep block := &types.Block{ Header: types.Header{ ParentHash: previousHash, - Number: big.NewInt(int64(i)), + Number: big.NewInt(int64(i) + 1), StateRoot: trie.EmptyHash, Digest: digest, }, diff --git a/dot/sync/test_helpers.go b/dot/sync/test_helpers.go index 0b165a568e..8afcfb4716 100644 --- a/dot/sync/test_helpers.go +++ b/dot/sync/test_helpers.go @@ -22,9 +22,7 @@ import ( "time" "github.com/ChainSafe/gossamer/dot/types" - "github.com/ChainSafe/gossamer/lib/babe" "github.com/ChainSafe/gossamer/lib/runtime" - "github.com/ChainSafe/gossamer/lib/transaction" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/require" @@ -54,52 +52,51 @@ func BuildBlock(t *testing.T, instance runtime.Instance, parent *types.Header, e ienc, err := idata.Encode() require.NoError(t, err) - // Call BlockBuilder_inherent_extrinsics which returns the inherents as extrinsics + // Call BlockBuilder_inherent_extrinsics which returns the inherents as encoded extrinsics inherentExts, err := instance.InherentExtrinsics(ienc) require.NoError(t, err) // decode inherent extrinsics - var exts [][]byte - err = scale.Unmarshal(inherentExts, &exts) + cp := make([]byte, len(inherentExts)) + copy(cp, inherentExts) + var inExts [][]byte + err = scale.Unmarshal(cp, &inExts) require.NoError(t, err) - inExt := exts - - var body *types.Body - if ext != nil { - var txn *transaction.Validity - externalExt := types.Extrinsic(append([]byte{byte(types.TxnExternal)}, ext...)) - txn, err = instance.ValidateTransaction(externalExt) + // apply each inherent extrinsic + for _, inherent := range inExts { + in, err := scale.Marshal(inherent) //nolint require.NoError(t, err) - vtx := transaction.NewValidTransaction(ext, txn) - ret, err := instance.ApplyExtrinsic(ext) //nolint + ret, err := instance.ApplyExtrinsic(in) require.NoError(t, err) require.Equal(t, ret, []byte{0, 0}) + } - body, err = babe.ExtrinsicsToBody(inExt, []*transaction.ValidTransaction{vtx}) - require.NoError(t, err) + body := types.Body(types.BytesArrayToExtrinsics(inExts)) - } else { - body = types.NewBody(types.BytesArrayToExtrinsics(exts)) - } + if ext != nil { + // validate and apply extrinsic + var ret []byte - // apply each inherent extrinsic - for _, ext := range inExt { - in, err := scale.Marshal(ext) //nolint + externalExt := types.Extrinsic(append([]byte{byte(types.TxnExternal)}, ext...)) + _, err = instance.ValidateTransaction(externalExt) require.NoError(t, err) - ret, err := instance.ApplyExtrinsic(in) + ret, err = instance.ApplyExtrinsic(ext) require.NoError(t, err) require.Equal(t, ret, []byte{0, 0}) + + body = append(body, ext) } res, err := instance.FinalizeBlock() require.NoError(t, err) res.Number = header.Number + res.Hash() return &types.Block{ Header: *res, - Body: *body, + Body: body, } } diff --git a/dot/types/block.go b/dot/types/block.go index 9b31035575..22d7762499 100644 --- a/dot/types/block.go +++ b/dot/types/block.go @@ -17,6 +17,8 @@ package types import ( + "fmt" + "github.com/ChainSafe/gossamer/pkg/scale" ) @@ -42,6 +44,14 @@ func NewEmptyBlock() Block { } } +// String returns the formatted Block string +func (b *Block) String() string { + return fmt.Sprintf("header: %v\nbody: %v", + &b.Header, + b.Body, + ) +} + // Empty returns a boolean indicating is the Block is empty func (b *Block) Empty() bool { return b.Header.Empty() && len(b.Body) == 0 diff --git a/dot/types/extrinsic.go b/dot/types/extrinsic.go index 89c51acfca..57bb1be0ed 100644 --- a/dot/types/extrinsic.go +++ b/dot/types/extrinsic.go @@ -32,6 +32,10 @@ func NewExtrinsic(e []byte) Extrinsic { return Extrinsic(e) } +func (e Extrinsic) String() string { + return common.BytesToHex(e) +} + // Hash returns the blake2b hash of the extrinsic func (e Extrinsic) Hash() common.Hash { hash, err := common.Blake2bHash(e) diff --git a/lib/babe/build.go b/lib/babe/build.go index 95eb6e9bfe..b39eb95970 100644 --- a/lib/babe/build.go +++ b/lib/babe/build.go @@ -166,14 +166,14 @@ func (b *BlockBuilder) buildBlock(parent *types.Header, slot Slot, rt runtime.In logger.Trace("built block seal") - body, err := ExtrinsicsToBody(inherents, included) + body, err := extrinsicsToBody(inherents, included) if err != nil { return nil, err } block := &types.Block{ Header: *header, - Body: *body, + Body: body, } return block, nil @@ -358,8 +358,7 @@ func hasSlotEnded(slot Slot) bool { return time.Since(slotEnd) >= 0 } -// ExtrinsicsToBody returns scale encoded block body which contains inherent and extrinsic. -func ExtrinsicsToBody(inherents [][]byte, txs []*transaction.ValidTransaction) (*types.Body, error) { +func extrinsicsToBody(inherents [][]byte, txs []*transaction.ValidTransaction) (types.Body, error) { extrinsics := types.BytesArrayToExtrinsics(inherents) for _, tx := range txs { @@ -371,5 +370,5 @@ func ExtrinsicsToBody(inherents [][]byte, txs []*transaction.ValidTransaction) ( extrinsics = append(extrinsics, decExt) } - return types.NewBody(extrinsics), nil + return types.Body(extrinsics), nil } diff --git a/lib/babe/build_test.go b/lib/babe/build_test.go index 2e52b1bc5c..28cbb6b515 100644 --- a/lib/babe/build_test.go +++ b/lib/babe/build_test.go @@ -149,6 +149,7 @@ func createTestBlock(t *testing.T, babeService *Service, parent *types.Header, e // build block block, err := babeService.buildBlock(parent, slot, rt) require.NoError(t, err) + babeService.blockState.StoreRuntime(block.Header.Hash(), rt) return block, slot } @@ -444,10 +445,10 @@ func TestDecodeExtrinsicBody(t *testing.T) { vtx := transaction.NewValidTransaction(ext, &transaction.Validity{}) - body, err := ExtrinsicsToBody(inh, []*transaction.ValidTransaction{vtx}) + body, err := extrinsicsToBody(inh, []*transaction.ValidTransaction{vtx}) require.Nil(t, err) require.NotNil(t, body) - require.Len(t, *body, 3) + require.Len(t, body, 3) contains, err := body.HasExtrinsic(ext) require.Nil(t, err) diff --git a/lib/babe/verify_test.go b/lib/babe/verify_test.go index 6bd6ddd525..4a4c0730e3 100644 --- a/lib/babe/verify_test.go +++ b/lib/babe/verify_test.go @@ -54,7 +54,7 @@ func newTestVerificationManager(t *testing.T, genCfg *types.BabeConfiguration) * err = dbSrv.Start() require.NoError(t, err) - dbSrv.Epoch, err = state.NewEpochStateFromGenesis(dbSrv.DB(), genCfg) + dbSrv.Epoch, err = state.NewEpochStateFromGenesis(dbSrv.DB(), dbSrv.Block, genCfg) require.NoError(t, err) logger = log.New("pkg", "babe") @@ -158,7 +158,6 @@ func TestVerificationManager_VerifyBlock_Ok(t *testing.T) { vm := newTestVerificationManager(t, cfg) block, _ := createTestBlock(t, babeService, genesisHeader, [][]byte{}, 1, testEpochIndex) - err = vm.VerifyBlock(&block.Header) require.NoError(t, err) } diff --git a/lib/blocktree/blocktree.go b/lib/blocktree/blocktree.go index 9ba0aafd80..5f9d6b9b06 100644 --- a/lib/blocktree/blocktree.go +++ b/lib/blocktree/blocktree.go @@ -22,7 +22,6 @@ import ( "sync" "time" - database "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/runtime" @@ -36,18 +35,16 @@ type Hash = common.Hash type BlockTree struct { root *node leaves *leafMap - db database.Database sync.RWMutex nodeCache map[Hash]*node runtime *sync.Map } // NewEmptyBlockTree creates a BlockTree with a nil head -func NewEmptyBlockTree(db database.Database) *BlockTree { +func NewEmptyBlockTree() *BlockTree { return &BlockTree{ root: nil, leaves: newEmptyLeafMap(), - db: db, nodeCache: make(map[Hash]*node), runtime: &sync.Map{}, // map[Hash]runtime.Instance } @@ -55,19 +52,18 @@ func NewEmptyBlockTree(db database.Database) *BlockTree { // NewBlockTreeFromRoot initialises a blocktree with a root block. The root block is always the most recently // finalised block (ie the genesis block if the node is just starting.) -func NewBlockTreeFromRoot(root *types.Header, db database.Database) *BlockTree { +func NewBlockTreeFromRoot(root *types.Header) *BlockTree { n := &node{ hash: root.Hash(), parent: nil, children: []*node{}, - depth: big.NewInt(0), - arrivalTime: uint64(time.Now().Unix()), + number: root.Number, + arrivalTime: time.Now(), } return &BlockTree{ root: n, leaves: newLeafMap(n), - db: db, nodeCache: make(map[Hash]*node), runtime: &sync.Map{}, } @@ -82,7 +78,7 @@ func (bt *BlockTree) GenesisHash() Hash { // AddBlock inserts the block as child of its parent node // Note: Assumes block has no children -func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime uint64) error { +func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime time.Time) error { bt.Lock() defer bt.Unlock() @@ -97,14 +93,18 @@ func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime uint64) error { return ErrBlockExists } - depth := big.NewInt(0) - depth.Add(parent.depth, big.NewInt(1)) + number := big.NewInt(0) + number.Add(parent.number, big.NewInt(1)) + + if number.Cmp(header.Number) != 0 { + return errUnexpectedNumber + } n = &node{ hash: header.Hash(), parent: parent, children: []*node{}, - depth: depth, + number: number, arrivalTime: arrivalTime, } parent.addChild(n) @@ -114,29 +114,9 @@ func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime uint64) error { return nil } -// Rewind rewinds the block tree by the given height. If the blocktree is less than the given height, -// it will only rewind until the blocktree has one node. -func (bt *BlockTree) Rewind(numBlocks int) { - bt.Lock() - defer bt.Unlock() - - for i := 0; i < numBlocks; i++ { - deepest := bt.leaves.deepestLeaf() - - for _, leaf := range bt.leaves.nodes() { - if leaf.parent == nil || leaf.depth.Cmp(deepest.depth) < 0 { - continue - } - - bt.leaves.replace(leaf, leaf.parent) - leaf.parent.deleteChild(leaf) - } - } -} - -// GetAllBlocksAtDepth will return all blocks hashes with the depth of the given hash plus one. -// To find all blocks at a depth matching a certain block, pass in that block's parent hash -func (bt *BlockTree) GetAllBlocksAtDepth(hash common.Hash) []common.Hash { +// GetAllBlocksAtNumber will return all blocks hashes with the number of the given hash plus one. +// To find all blocks at a number matching a certain block, pass in that block's parent hash +func (bt *BlockTree) GetAllBlocksAtNumber(hash common.Hash) []common.Hash { bt.RLock() defer bt.RUnlock() @@ -146,14 +126,14 @@ func (bt *BlockTree) GetAllBlocksAtDepth(hash common.Hash) []common.Hash { return hashes } - depth := big.NewInt(0).Add(bt.getNode(hash).depth, big.NewInt(1)) + number := big.NewInt(0).Add(bt.getNode(hash).number, big.NewInt(1)) - if bt.root.depth.Cmp(depth) == 0 { + if bt.root.number.Cmp(number) == 0 { hashes = append(hashes, bt.root.hash) return hashes } - return bt.root.getNodesWithDepth(depth, hashes) + return bt.root.getNodesWithNumber(number, hashes) } func (bt *BlockTree) setInCache(b *node) { @@ -216,6 +196,8 @@ func (bt *BlockTree) Prune(finalised Hash) (pruned []Hash) { pruned = bt.root.prune(n, nil) bt.root = n + bt.root.parent = nil + leaves := n.getLeaves(nil) bt.leaves = newEmptyLeafMap() for _, leaf := range leaves { @@ -374,13 +356,62 @@ func (bt *BlockTree) GetAllBlocks() []Hash { return bt.root.getAllDescendants(nil) } +// GetHashByNumber returns the block hash with the given number that is on the best chain. +// If the number is lower or higher than the numbers in the blocktree, an error is returned. +func (bt *BlockTree) GetHashByNumber(num *big.Int) (common.Hash, error) { + bt.RLock() + defer bt.RUnlock() + + deepest := bt.leaves.deepestLeaf() + if deepest.number.Cmp(num) == -1 { + return common.Hash{}, ErrNumGreaterThanHighest + } + + if deepest.number.Cmp(num) == 0 { + return deepest.hash, nil + } + + if bt.root.number.Cmp(num) == 1 { + return common.Hash{}, ErrNumLowerThanRoot + } + + if bt.root.number.Cmp(num) == 0 { + return bt.root.hash, nil + } + + curr := deepest.parent + for { + if curr == nil { + return common.Hash{}, ErrNodeNotFound + } + + if curr.number.Cmp(num) == 0 { + return curr.hash, nil + } + + curr = curr.parent + } +} + +// GetArrivalTime returns the arrival time of a block +func (bt *BlockTree) GetArrivalTime(hash common.Hash) (time.Time, error) { + bt.RLock() + defer bt.RUnlock() + + n, has := bt.nodeCache[hash] + if !has { + return time.Time{}, ErrNodeNotFound + } + + return n.arrivalTime, nil +} + // DeepCopy returns a copy of the BlockTree func (bt *BlockTree) DeepCopy() *BlockTree { bt.RLock() defer bt.RUnlock() btCopy := &BlockTree{ - db: bt.db, nodeCache: make(map[Hash]*node), } diff --git a/lib/blocktree/blocktree_test.go b/lib/blocktree/blocktree_test.go index 4dae7c4300..0996ed8d76 100644 --- a/lib/blocktree/blocktree_test.go +++ b/lib/blocktree/blocktree_test.go @@ -18,11 +18,13 @@ package blocktree import ( "bytes" + "fmt" "math/big" + "math/rand" "reflect" "testing" + "time" - database "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/stretchr/testify/require" @@ -35,22 +37,86 @@ var testHeader = &types.Header{ Digest: types.NewDigest(), } -func newBlockTreeFromNode(root *node, db database.Database) *BlockTree { +type testBranch struct { + hash Hash + number *big.Int + arrivalTime int64 +} + +func newBlockTreeFromNode(root *node) *BlockTree { return &BlockTree{ root: root, leaves: newLeafMap(root), - db: db, } } -func createFlatTree(t *testing.T, depth int) (*BlockTree, []common.Hash) { - bt := NewBlockTreeFromRoot(testHeader, nil) +func createTestBlockTree(t *testing.T, header *types.Header, number int) (*BlockTree, []testBranch) { + bt := NewBlockTreeFromRoot(header) + previousHash := header.Hash() + + // branch tree randomly + branches := []testBranch{} + r := *rand.New(rand.NewSource(rand.Int63())) + + at := int64(0) + + // create base tree + for i := 1; i <= number; i++ { + header := &types.Header{ + ParentHash: previousHash, + Number: big.NewInt(int64(i)), + Digest: types.NewDigest(), + } + + hash := header.Hash() + err := bt.AddBlock(header, time.Unix(0, at)) + require.NoError(t, err) + previousHash = hash + + isBranch := r.Intn(2) + if isBranch == 1 { + branches = append(branches, testBranch{ + hash: hash, + number: bt.getNode(hash).number, + arrivalTime: at, + }) + } + + at += int64(r.Intn(8)) + } + + // create tree branches + for _, branch := range branches { + at := branch.arrivalTime + previousHash = branch.hash + + for i := int(branch.number.Uint64()); i <= number; i++ { + header := &types.Header{ + ParentHash: previousHash, + Number: big.NewInt(int64(i) + 1), + StateRoot: common.Hash{0x1}, + Digest: types.NewDigest(), + } + + hash := header.Hash() + err := bt.AddBlock(header, time.Unix(0, at)) + require.NoError(t, err) + previousHash = hash + at += int64(r.Intn(8)) + } + } + + return bt, branches +} + +func createFlatTree(t *testing.T, number int) (*BlockTree, []common.Hash) { + bt := NewBlockTreeFromRoot(testHeader) require.NotNil(t, bt) previousHash := bt.root.hash hashes := []common.Hash{bt.root.hash} - for i := 1; i <= depth; i++ { + for i := 1; i <= number; i++ { header := &types.Header{ ParentHash: previousHash, Number: big.NewInt(int64(i)), @@ -60,7 +126,7 @@ func createFlatTree(t *testing.T, depth int) (*BlockTree, []common.Hash) { hash := header.Hash() hashes = append(hashes, hash) - err := bt.AddBlock(header, 0) + err := bt.AddBlock(header, time.Unix(0, 0)) require.Nil(t, err) previousHash = hash } @@ -73,7 +139,7 @@ func TestNewBlockTreeFromNode(t *testing.T) { var branches []testBranch for { - bt, branches = createTestBlockTree(testHeader, 5, nil) + bt, branches = createTestBlockTree(t, testHeader, 5) if len(branches) > 0 && len(bt.getNode(branches[0].hash).children) > 0 { break } @@ -82,7 +148,7 @@ func TestNewBlockTreeFromNode(t *testing.T) { testNode := bt.getNode(branches[0].hash).children[0] leaves := testNode.getLeaves(nil) - newBt := newBlockTreeFromNode(testNode, nil) + newBt := newBlockTreeFromNode(testNode) require.ElementsMatch(t, leaves, newBt.leaves.nodes()) } @@ -105,11 +171,11 @@ func TestBlockTree_AddBlock(t *testing.T) { header := &types.Header{ ParentHash: hashes[1], - Number: big.NewInt(1), + Number: big.NewInt(2), } hash := header.Hash() - err := bt.AddBlock(header, 0) + err := bt.AddBlock(header, time.Unix(0, 0)) require.Nil(t, err) node := bt.getNode(hash) @@ -126,7 +192,7 @@ func TestBlockTree_AddBlock(t *testing.T) { } func TestNode_isDecendantOf(t *testing.T) { - // Create tree with depth 4 (with 4 nodes) + // Create tree with number 4 (with 4 nodes) bt, hashes := createFlatTree(t, 4) // Check leaf is descendant of root @@ -151,7 +217,7 @@ func TestBlockTree_LongestPath(t *testing.T) { } header.Hash() - err := bt.AddBlock(header, 0) + err := bt.AddBlock(header, time.Unix(0, 0)) require.NotNil(t, err) longestPath := bt.longestPath() @@ -175,7 +241,7 @@ func TestBlockTree_Subchain(t *testing.T) { } extraBlock.Hash() - err := bt.AddBlock(extraBlock, 0) + err := bt.AddBlock(extraBlock, time.Unix(0, 0)) require.NotNil(t, err) subChain, err := bt.subChain(hashes[1], hashes[3]) @@ -191,22 +257,22 @@ func TestBlockTree_Subchain(t *testing.T) { } func TestBlockTree_DeepestLeaf(t *testing.T) { - arrivalTime := uint64(256) + arrivalTime := int64(256) var expected Hash - bt, _ := createTestBlockTree(testHeader, 8, nil) + bt, _ := createTestBlockTree(t, testHeader, 8) deepest := big.NewInt(0) for leaf, node := range bt.leaves.toMap() { - node.arrivalTime = arrivalTime + node.arrivalTime = time.Unix(arrivalTime, 0) arrivalTime-- - if node.depth.Cmp(deepest) >= 0 { - deepest = node.depth + if node.number.Cmp(deepest) >= 0 { + deepest = node.number expected = leaf } - t.Logf("leaf=%s depth=%d arrivalTime=%d", leaf, node.depth, node.arrivalTime) + t.Logf("leaf=%s number=%d arrivalTime=%s", leaf, node.number, node.arrivalTime) } deepestLeaf := bt.deepestLeaf() @@ -216,66 +282,46 @@ func TestBlockTree_DeepestLeaf(t *testing.T) { } func TestBlockTree_GetNode(t *testing.T) { - bt, branches := createTestBlockTree(testHeader, 16, nil) - - for _, branch := range branches { - header := &types.Header{ - ParentHash: branch.hash, - Number: branch.depth, - StateRoot: Hash{0x1}, - } - - err := bt.AddBlock(header, 0) - require.Nil(t, err) - } -} - -func TestBlockTree_GetNodeCache(t *testing.T) { - bt, branches := createTestBlockTree(testHeader, 16, nil) + bt, branches := createTestBlockTree(t, testHeader, 16) for _, branch := range branches { header := &types.Header{ ParentHash: branch.hash, - Number: branch.depth, - StateRoot: Hash{0x1}, + Number: big.NewInt(0).Add(branch.number, big.NewInt(1)), + StateRoot: Hash{0x2}, } - err := bt.AddBlock(header, 0) - require.Nil(t, err) + err := bt.AddBlock(header, time.Unix(0, 0)) + require.NoError(t, err) } block := bt.getNode(branches[0].hash) cachedBlock, ok := bt.nodeCache[block.hash] - require.True(t, len(bt.nodeCache) > 0) require.True(t, ok) require.NotNil(t, cachedBlock) require.Equal(t, cachedBlock, block) - } -func TestBlockTree_GetAllBlocksAtDepth(t *testing.T) { - bt, _ := createTestBlockTree(testHeader, 8, nil) - hashes := bt.root.getNodesWithDepth(big.NewInt(10), []common.Hash{}) +func TestBlockTree_GetAllBlocksAtNumber(t *testing.T) { + bt, _ := createTestBlockTree(t, testHeader, 8) + hashes := bt.root.getNodesWithNumber(big.NewInt(10), []common.Hash{}) expected := []common.Hash{} - - if !reflect.DeepEqual(hashes, expected) { - t.Fatalf("Fail: expected empty array") - } + require.Equal(t, expected, hashes) // create one-path tree - btDepth := 8 - desiredDepth := 6 - bt, btHashes := createFlatTree(t, btDepth) + btNumber := 8 + desiredNumber := 6 + bt, btHashes := createFlatTree(t, btNumber) - expected = []common.Hash{btHashes[desiredDepth]} + expected = []common.Hash{btHashes[desiredNumber]} // add branch previousHash := btHashes[4] - for i := 4; i <= btDepth; i++ { + for i := 4; i <= btNumber; i++ { digest := types.NewDigest() err := digest.Add(types.ConsensusDigest{ ConsensusEngineID: types.BabeEngineID, @@ -284,15 +330,16 @@ func TestBlockTree_GetAllBlocksAtDepth(t *testing.T) { require.NoError(t, err) header := &types.Header{ ParentHash: previousHash, - Number: big.NewInt(int64(i)), + Number: big.NewInt(int64(i) + 1), Digest: digest, } hash := header.Hash() - bt.AddBlock(header, 0) + err = bt.AddBlock(header, time.Unix(0, 0)) + require.NoError(t, err) previousHash = hash - if i == desiredDepth-1 { + if i == desiredNumber-1 { expected = append(expected, hash) } } @@ -300,7 +347,7 @@ func TestBlockTree_GetAllBlocksAtDepth(t *testing.T) { // add another branch previousHash = btHashes[2] - for i := 2; i <= btDepth; i++ { + for i := 2; i <= btNumber; i++ { digest := types.NewDigest() err := digest.Add(types.SealDigest{ ConsensusEngineID: types.BabeEngineID, @@ -309,28 +356,28 @@ func TestBlockTree_GetAllBlocksAtDepth(t *testing.T) { require.NoError(t, err) header := &types.Header{ ParentHash: previousHash, - Number: big.NewInt(int64(i)), + Number: big.NewInt(int64(i) + 1), Digest: digest, } hash := header.Hash() - bt.AddBlock(header, 0) + err = bt.AddBlock(header, time.Unix(0, 0)) + require.NoError(t, err) previousHash = hash - if i == desiredDepth-1 { + if i == desiredNumber-1 { expected = append(expected, hash) } } - hashes = bt.root.getNodesWithDepth(big.NewInt(int64(desiredDepth)), []common.Hash{}) - + hashes = bt.root.getNodesWithNumber(big.NewInt(int64(desiredNumber)), []common.Hash{}) if !reflect.DeepEqual(hashes, expected) { t.Fatalf("Fail: did not get all expected hashes got %v expected %v", hashes, expected) } } func TestBlockTree_IsDecendantOf(t *testing.T) { - // Create tree with depth 4 (with 4 nodes) + // Create tree with number 4 (with 4 nodes) bt, hashes := createFlatTree(t, 4) isDescendant, err := bt.IsDescendantOf(bt.root.hash, hashes[3]) @@ -348,7 +395,7 @@ func TestBlockTree_HighestCommonAncestor(t *testing.T) { var branches []testBranch for { - bt, branches = createTestBlockTree(testHeader, 8, nil) + bt, branches = createTestBlockTree(t, testHeader, 8) leaves = bt.Leaves() if len(leaves) == 2 { break @@ -366,7 +413,7 @@ func TestBlockTree_HighestCommonAncestor(t *testing.T) { } func TestBlockTree_HighestCommonAncestor_SameNode(t *testing.T) { - bt, _ := createTestBlockTree(testHeader, 8, nil) + bt, _ := createTestBlockTree(t, testHeader, 8) leaves := bt.Leaves() a := leaves[0] @@ -377,7 +424,7 @@ func TestBlockTree_HighestCommonAncestor_SameNode(t *testing.T) { } func TestBlockTree_HighestCommonAncestor_SameChain(t *testing.T) { - bt, _ := createTestBlockTree(testHeader, 8, nil) + bt, _ := createTestBlockTree(t, testHeader, 8) leaves := bt.Leaves() a := leaves[0] @@ -394,7 +441,7 @@ func TestBlockTree_Prune(t *testing.T) { var branches []testBranch for { - bt, branches = createTestBlockTree(testHeader, 5, nil) + bt, branches = createTestBlockTree(t, testHeader, 5) if len(branches) > 0 && len(bt.getNode(branches[0].hash).children) > 1 { break } @@ -429,7 +476,7 @@ func TestBlockTree_PruneCache(t *testing.T) { var branches []testBranch for { - bt, branches = createTestBlockTree(testHeader, 5, nil) + bt, branches = createTestBlockTree(t, testHeader, 5) if len(branches) > 0 && len(bt.getNode(branches[0].hash).children) > 1 { break } @@ -445,15 +492,33 @@ func TestBlockTree_PruneCache(t *testing.T) { require.False(t, ok) require.Nil(t, block) } +} +func TestBlockTree_GetHashByNumber(t *testing.T) { + bt, _ := createTestBlockTree(t, testHeader, 8) + best := bt.DeepestBlockHash() + bn := bt.nodeCache[best] + + for i := int64(0); i < bn.number.Int64(); i++ { + hash, err := bt.GetHashByNumber(big.NewInt(i)) + require.NoError(t, err) + require.Equal(t, big.NewInt(i), bt.nodeCache[hash].number) + desc, err := bt.IsDescendantOf(hash, best) + require.NoError(t, err) + require.True(t, desc, fmt.Sprintf("index %d failed, got hash=%s", i, hash)) + } + + _, err := bt.GetHashByNumber(big.NewInt(-1)) + require.Error(t, err) + + _, err = bt.GetHashByNumber(big.NewInt(0).Add(bn.number, big.NewInt(1))) + require.Error(t, err) } func TestBlockTree_DeepCopy(t *testing.T) { bt, _ := createFlatTree(t, 8) btCopy := bt.DeepCopy() - - require.Equal(t, bt.db, btCopy.db) for hash := range bt.nodeCache { b, ok := btCopy.nodeCache[hash] b2 := bt.nodeCache[hash] @@ -464,8 +529,9 @@ func TestBlockTree_DeepCopy(t *testing.T) { require.True(t, equalNodeValue(b, b2)) } + require.True(t, equalNodeValue(bt.root, btCopy.root), "BlockTree heads not equal") - require.True(t, equalLeave(bt.leaves, btCopy.leaves), "BlockTree leaves not equal") + require.True(t, equalLeaves(bt.leaves, btCopy.leaves), "BlockTree leaves not equal") btCopy.root = &node{} require.NotEqual(t, bt.root, btCopy.root) @@ -475,7 +541,7 @@ func equalNodeValue(nd *node, ndCopy *node) bool { if nd.hash != ndCopy.hash { return false } - if nd.depth.Cmp(ndCopy.depth) != 0 { + if nd.number.Cmp(ndCopy.number) != 0 { return false } if nd.arrivalTime != ndCopy.arrivalTime { @@ -490,13 +556,13 @@ func equalNodeValue(nd *node, ndCopy *node) bool { if nd.parent.arrivalTime != ndCopy.parent.arrivalTime { return false } - if nd.parent.depth.Cmp(ndCopy.parent.depth) != 0 { + if nd.parent.number.Cmp(ndCopy.parent.number) != 0 { return false } return true } -func equalLeave(lm *leafMap, lmCopy *leafMap) bool { +func equalLeaves(lm *leafMap, lmCopy *leafMap) bool { lmm := lm.toMap() lmCopyM := lmCopy.toMap() for key, val := range lmm { @@ -505,23 +571,3 @@ func equalLeave(lm *leafMap, lmCopy *leafMap) bool { } return true } - -func TestBlockTree_Rewind(t *testing.T) { - var bt *BlockTree - var branches []testBranch - - rewind := 6 - - for { - bt, branches = createTestBlockTree(testHeader, 12, nil) - if len(branches) > 0 && len(bt.getNode(branches[0].hash).children) > 1 { - break - } - } - - start := bt.leaves.deepestLeaf() - - bt.Rewind(rewind) - deepest := bt.leaves.deepestLeaf() - require.Equal(t, start.depth.Int64()-int64(rewind), deepest.depth.Int64()) -} diff --git a/lib/blocktree/database.go b/lib/blocktree/database.go deleted file mode 100644 index 2a3f8e0279..0000000000 --- a/lib/blocktree/database.go +++ /dev/null @@ -1,139 +0,0 @@ -package blocktree - -import ( - "bytes" - "encoding/binary" - "io" - "math/big" - - "github.com/ChainSafe/gossamer/lib/common" -) - -// Store stores the blocktree in the underlying db -func (bt *BlockTree) Store() error { - if bt.db == nil { - return ErrNilDatabase - } - - enc, err := bt.Encode() - if err != nil { - return err - } - - return bt.db.Put(common.BlockTreeKey, enc) -} - -// Load loads the blocktree from the underlying db -func (bt *BlockTree) Load() error { - if bt.db == nil { - return ErrNilDatabase - } - - enc, err := bt.db.Get(common.BlockTreeKey) - if err != nil { - return err - } - - return bt.Decode(enc) -} - -// Encode recursively encodes the block tree -// enc(node) = [32B block hash + 8B arrival time + 8B num children n] | enc(children[0]) | ... | enc(children[n-1]) -func (bt *BlockTree) Encode() ([]byte, error) { - return encodeRecursive(bt.root, []byte{}) -} - -// encode recursively encodes the blocktree by depth-first traversal -func encodeRecursive(n *node, enc []byte) ([]byte, error) { - if n == nil { - return enc, nil - } - - // encode hash and arrival time - enc = append(enc, n.hash[:]...) - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, n.arrivalTime) - enc = append(enc, buf...) - - binary.LittleEndian.PutUint64(buf, uint64(len(n.children))) - enc = append(enc, buf...) - - var err error - for _, child := range n.children { - enc, err = encodeRecursive(child, enc) - if err != nil { - return nil, err - } - } - - return enc, nil -} - -// Decode recursively decodes an encoded block tree -func (bt *BlockTree) Decode(in []byte) error { - r := &bytes.Buffer{} - _, err := r.Write(in) - if err != nil { - return err - } - - hash, err := common.ReadHash(r) - if err != nil { - return err - } - arrivalTime, err := common.ReadUint64(r) - if err != nil { - return err - } - numChildren, err := common.ReadUint64(r) - if err != nil { - return err - } - - bt.root = &node{ - hash: hash, - parent: nil, - children: make([]*node, numChildren), - depth: big.NewInt(0), - arrivalTime: arrivalTime, - } - - bt.leaves = newLeafMap(bt.root) - - return bt.decodeRecursive(r, bt.root) -} - -// decode recursively decodes the blocktree -func (bt *BlockTree) decodeRecursive(r io.Reader, parent *node) error { - for i := range parent.children { - hash, err := common.ReadHash(r) - if err != nil { - return err - } - arrivalTime, err := common.ReadUint64(r) - if err != nil { - return err - } - numChildren, err := common.ReadUint64(r) - if err != nil { - return err - } - - parent.children[i] = &node{ - hash: hash, - parent: parent, - children: make([]*node, numChildren), - depth: big.NewInt(0).Add(parent.depth, big.NewInt(1)), - arrivalTime: arrivalTime, - } - - bt.leaves.replace(parent, parent.children[i]) - - err = bt.decodeRecursive(r, parent.children[i]) - if err != nil { - return err - } - } - - return nil -} diff --git a/lib/blocktree/database_test.go b/lib/blocktree/database_test.go deleted file mode 100644 index 1237263053..0000000000 --- a/lib/blocktree/database_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package blocktree - -import ( - "io/ioutil" - "math/big" - "math/rand" - "reflect" - "testing" - - "github.com/ChainSafe/chaindb" - "github.com/ChainSafe/gossamer/dot/types" - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/utils" - - "github.com/stretchr/testify/require" -) - -type testBranch struct { - hash Hash - depth *big.Int -} - -func createTestBlockTree(header *types.Header, depth int, db chaindb.Database) (*BlockTree, []testBranch) { - bt := NewBlockTreeFromRoot(header, db) - previousHash := header.Hash() - - // branch tree randomly - branches := []testBranch{} - r := *rand.New(rand.NewSource(rand.Int63())) - - // create base tree - for i := 1; i <= depth; i++ { - header := &types.Header{ - ParentHash: previousHash, - Number: big.NewInt(int64(i)), - Digest: types.NewDigest(), - } - - hash := header.Hash() - bt.AddBlock(header, 0) - previousHash = hash - - isBranch := r.Intn(2) - if isBranch == 1 { - branches = append(branches, testBranch{ - hash: hash, - depth: bt.getNode(hash).depth, - }) - } - } - - // create tree branches - for _, branch := range branches { - previousHash = branch.hash - - for i := int(branch.depth.Uint64()); i <= depth; i++ { - digest := types.NewDigest() - err := digest.Add(types.ConsensusDigest{ - ConsensusEngineID: types.BabeEngineID, - Data: common.MustHexToBytes("0x0118ca239392960473fe1bc65f94ee27d890a49c1b200c006ff5dcc525330ecc16770100000000000000b46f01874ce7abbb5220e8fd89bede0adad14c73039d91e28e881823433e723f0100000000000000d684d9176d6eb69887540c9a89fa6097adea82fc4b0ff26d1062b488f352e179010000000000000068195a71bdde49117a616424bdc60a1733e96acb1da5aeab5d268cf2a572e94101000000000000001a0575ef4ae24bdfd31f4cb5bd61239ae67c12d4e64ae51ac756044aa6ad8200010000000000000018168f2aad0081a25728961ee00627cfe35e39833c805016632bf7c14da5800901000000000000000000000000000000000000000000000000000000000000000000000000000000"), - }) - if err != nil { - return nil, nil - } - header := &types.Header{ - ParentHash: previousHash, - Number: big.NewInt(int64(i)), - Digest: digest, - } - - hash := header.Hash() - bt.AddBlock(header, 0) - previousHash = hash - } - } - - return bt, branches -} - -func TestStoreBlockTree(t *testing.T) { - db := newInMemoryDB(t) - bt, _ := createTestBlockTree(testHeader, 10, db) - - err := bt.Store() - require.NoError(t, err) - - resBt := NewBlockTreeFromRoot(testHeader, db) - err = resBt.Load() - require.NoError(t, err) - - if !reflect.DeepEqual(bt.root, resBt.root) { - t.Fatalf("Fail: got %v expected %v", resBt, bt) - } - - btLeafMap := bt.leaves.toMap() - resLeafMap := bt.leaves.toMap() - if !reflect.DeepEqual(btLeafMap, resLeafMap) { - t.Fatalf("Fail: got %v expected %v", btLeafMap, resLeafMap) - } -} -func newInMemoryDB(t *testing.T) chaindb.Database { - testDatadirPath, err := ioutil.TempDir("/tmp", "test-datadir-*") - require.NoError(t, err) - - db, err := utils.SetupDatabase(testDatadirPath, true) - require.NoError(t, err) - t.Cleanup(func() { - db.Close() - }) - - return db -} diff --git a/lib/blocktree/errors.go b/lib/blocktree/errors.go index 1c67434258..57a3be78d2 100644 --- a/lib/blocktree/errors.go +++ b/lib/blocktree/errors.go @@ -31,4 +31,12 @@ var ( // ErrFailedToGetRuntime is returned when runtime doesn't exist in blockTree for corresponding block. ErrFailedToGetRuntime = errors.New("failed to get runtime instance") + + // ErrNumGreaterThanHighest is returned when attempting to get a hash by number that is higher than any in the blocktree + ErrNumGreaterThanHighest = errors.New("cannot find node with number greater than highest in blocktree") + + // ErrNumLowerThanRoot is returned when attempting to get a hash by number that is lower than the root node + ErrNumLowerThanRoot = errors.New("cannot find node with number lower than root node") + + errUnexpectedNumber = errors.New("block number is not parent number + 1") ) diff --git a/lib/blocktree/leaves.go b/lib/blocktree/leaves.go index a12e9b0c67..e9e5a71822 100644 --- a/lib/blocktree/leaves.go +++ b/lib/blocktree/leaves.go @@ -65,8 +65,8 @@ func (ls *leafMap) replace(oldNode, newNode *node) { ls.store(newNode.hash, newNode) } -// DeepestLeaf searches the stored leaves to the find the one with the greatest depth. -// If there are two leaves with the same depth, choose the one with the earliest arrival time. +// DeepestLeaf searches the stored leaves to the find the one with the greatest number. +// If there are two leaves with the same number, choose the one with the earliest arrival time. func (ls *leafMap) deepestLeaf() *node { max := big.NewInt(-1) @@ -78,10 +78,10 @@ func (ls *leafMap) deepestLeaf() *node { node := n.(*node) - if max.Cmp(node.depth) < 0 { - max = node.depth + if max.Cmp(node.number) < 0 { + max = node.number dLeaf = node - } else if max.Cmp(node.depth) == 0 && node.arrivalTime < dLeaf.arrivalTime { + } else if max.Cmp(node.number) == 0 && node.arrivalTime.Before(dLeaf.arrivalTime) { dLeaf = node } diff --git a/lib/blocktree/node.go b/lib/blocktree/node.go index 4ec9c89192..1222e14a98 100644 --- a/lib/blocktree/node.go +++ b/lib/blocktree/node.go @@ -19,6 +19,7 @@ package blocktree import ( "fmt" "math/big" + "time" "github.com/ChainSafe/gossamer/lib/common" "github.com/disiqueira/gotree" @@ -29,8 +30,8 @@ type node struct { hash common.Hash // Block hash parent *node // Parent Node children []*node // Nodes of children blocks - depth *big.Int // Depth within the tree - arrivalTime uint64 // Arrival time of the block + number *big.Int // block number + arrivalTime time.Time // Arrival time of the block } // addChild appends Node to n's list of children @@ -38,9 +39,9 @@ func (n *node) addChild(node *node) { n.children = append(n.children, node) } -// string returns stringified hash and depth of node +// string returns stringified hash and number of node func (n *node) string() string { - return fmt.Sprintf("{hash: %s, depth: %s, arrivalTime: %d}", n.hash.String(), n.depth, n.arrivalTime) + return fmt.Sprintf("{hash: %s, number: %s, arrivalTime: %s}", n.hash.String(), n.number, n.arrivalTime) } // createTree adds all the nodes children to the existing printable tree. @@ -68,20 +69,20 @@ func (n *node) getNode(h common.Hash) *node { return nil } -// getNodesWithDepth returns all descendent nodes with the desired depth -func (n *node) getNodesWithDepth(depth *big.Int, hashes []common.Hash) []common.Hash { +// getNodesWithNumber returns all descendent nodes with the desired number +func (n *node) getNodesWithNumber(number *big.Int, hashes []common.Hash) []common.Hash { for _, child := range n.children { - // depth matches - if child.depth.Cmp(depth) == 0 { + // number matches + if child.number.Cmp(number) == 0 { hashes = append(hashes, child.hash) } - // are deeper than desired depth, return - if child.depth.Cmp(depth) > 0 { + // are deeper than desired number, return + if child.number.Cmp(number) > 0 { return hashes } - hashes = child.getNodesWithDepth(depth, hashes) + hashes = child.getNodesWithNumber(number, hashes) } return hashes @@ -190,8 +191,8 @@ func (n *node) deepCopy(parent *node) *node { nCopy.hash = n.hash nCopy.arrivalTime = n.arrivalTime - if n.depth != nil { - nCopy.depth = new(big.Int).Set(n.depth) + if n.number != nil { + nCopy.number = new(big.Int).Set(n.number) } nCopy.children = make([]*node, len(n.children)) diff --git a/lib/blocktree/node_test.go b/lib/blocktree/node_test.go index 96572addae..dfba5e3e02 100644 --- a/lib/blocktree/node_test.go +++ b/lib/blocktree/node_test.go @@ -27,7 +27,7 @@ func TestNode_GetLeaves(t *testing.T) { var branches []testBranch for { - bt, branches = createTestBlockTree(testHeader, 5, nil) + bt, branches = createTestBlockTree(t, testHeader, 5) if len(branches) > 0 && len(bt.getNode(branches[0].hash).children) > 0 { break } @@ -51,7 +51,7 @@ func TestNode_Prune(t *testing.T) { var branches []testBranch for { - bt, branches = createTestBlockTree(testHeader, 5, nil) + bt, branches = createTestBlockTree(t, testHeader, 5) if len(branches) > 0 && len(bt.getNode(branches[0].hash).children) > 1 { break } diff --git a/lib/grandpa/message_handler_test.go b/lib/grandpa/message_handler_test.go index 7b8a429c81..81e59936fe 100644 --- a/lib/grandpa/message_handler_test.go +++ b/lib/grandpa/message_handler_test.go @@ -195,7 +195,7 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) { Version: 1, Round: 2, SetID: 3, - Number: 2, + Number: 1, } _, err := h.handleMessage("", msg) require.NoError(t, err) @@ -209,7 +209,7 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) { block := &types.Block{ Header: types.Header{ - Number: big.NewInt(2), + Number: big.NewInt(1), ParentHash: st.Block.GenesisHash(), Digest: digest, }, @@ -351,7 +351,7 @@ func TestMessageHandler_CatchUpRequest_WithResponse(t *testing.T) { block := &types.Block{ Header: types.Header{ ParentHash: testGenesisHeader.Hash(), - Number: big.NewInt(2), + Number: big.NewInt(1), Digest: digest, }, Body: types.Body{},