diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go index 35881b05cc4a..274b9cf3603e 100644 --- a/consensus/beacon/consensus.go +++ b/consensus/beacon/consensus.go @@ -95,68 +95,71 @@ func (beacon *Beacon) VerifyHeader(chain consensus.ChainHeaderReader, header *ty return beacon.verifyHeader(chain, header, parent) } -// VerifyHeaders is similar to VerifyHeader, but verifies a batch of headers -// concurrently. The method returns a quit channel to abort the operations and -// a results channel to retrieve the async verifications. -// VerifyHeaders expect the headers to be ordered and continuous. -func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers []*types.Header, seals []bool) (chan<- struct{}, <-chan error) { +// errOut constructs an error channel with prefilled errors inside. +func errOut(n int, err error) chan error { + errs := make(chan error, n) + for i := 0; i < n; i++ { + errs <- err + } + return errs +} + +// splitHeaders splits the provided header batch into two parts according to +// the configured ttd. It requires the parent of header batch along with its +// td are stored correctly in chain. If ttd is not configured yet, all headers +// will be treated legacy PoW headers. +// Note, this function will not verify the header validity but just split them. +func (beacon *Beacon) splitHeaders(chain consensus.ChainHeaderReader, headers []*types.Header) ([]*types.Header, []*types.Header, error) { + // TTD is not defined yet, all headers should be in legacy format. ttd := chain.Config().TerminalTotalDifficulty if ttd == nil { - return beacon.ethone.VerifyHeaders(chain, headers, seals) + return headers, nil, nil } - td := chain.GetTd(headers[0].ParentHash, headers[0].Number.Uint64()-1) - if td == nil { - results := make(chan error, len(headers)) - for i := 0; i < len(headers); i++ { - results <- consensus.ErrUnknownAncestor - } - return make(chan struct{}), results + ptd := chain.GetTd(headers[0].ParentHash, headers[0].Number.Uint64()-1) + if ptd == nil { + return nil, nil, consensus.ErrUnknownAncestor + } + // The entire header batch already crosses the transition. + if ptd.Cmp(ttd) >= 0 { + return nil, headers, nil } - td = new(big.Int).Set(td) var ( - preHeaders []*types.Header + preHeaders = headers postHeaders []*types.Header - preSeals []bool + td = new(big.Int).Set(ptd) + tdPassed bool ) - if td.Cmp(ttd) >= 0 { - postHeaders = headers - } else { - tdPassed := false - preHeaders = headers - for index, header := range headers { - if beacon.IsPoSHeader(header) || tdPassed { - preHeaders = headers[:index] - postHeaders = headers[index:] - preSeals = seals[:index] - break - } - td = td.Add(td, header.Difficulty) - if td.Cmp(ttd) >= 0 { - // This is the last PoW header, it still belongs to - // the preHeaders, so we cannot split+break yet. - tdPassed = true - } + for i, header := range headers { + if tdPassed { + preHeaders = headers[:i] + postHeaders = headers[i:] + break + } + td = td.Add(td, header.Difficulty) + if td.Cmp(ttd) >= 0 { + // This is the last PoW header, it still belongs to + // the preHeaders, so we cannot split+break yet. + tdPassed = true } } + return preHeaders, postHeaders, nil +} + +// VerifyHeaders is similar to VerifyHeader, but verifies a batch of headers +// concurrently. The method returns a quit channel to abort the operations and +// a results channel to retrieve the async verifications. +// VerifyHeaders expect the headers to be ordered and continuous. +func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers []*types.Header, seals []bool) (chan<- struct{}, <-chan error) { + preHeaders, postHeaders, err := beacon.splitHeaders(chain, headers) + if err != nil { + return make(chan struct{}), errOut(len(headers), err) + } if len(postHeaders) == 0 { return beacon.ethone.VerifyHeaders(chain, headers, seals) } if len(preHeaders) == 0 { - // All the headers are pos headers. Verify that the parent block reached total terminal difficulty. - if reached, err := IsTTDReached(chain, headers[0].ParentHash, headers[0].Number.Uint64()-1); !reached { - // TTD not reached for the first block, mark subsequent with invalid terminal block - if err == nil { - err = consensus.ErrInvalidTerminalBlock - } - results := make(chan error, len(headers)) - for i := 0; i < len(headers); i++ { - results <- err - } - return make(chan struct{}), results - } return beacon.verifyHeaders(chain, headers, nil) } - // The transition point exists in the middle, separate the headers // into two batches and apply different verification rules for them. var ( @@ -168,16 +171,9 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [ old, new, out = 0, len(preHeaders), 0 errors = make([]error, len(headers)) done = make([]bool, len(headers)) - oldDone, oldResult = beacon.ethone.VerifyHeaders(chain, preHeaders, preSeals) + oldDone, oldResult = beacon.ethone.VerifyHeaders(chain, preHeaders, seals[:len(preHeaders)]) newDone, newResult = beacon.verifyHeaders(chain, postHeaders, preHeaders[len(preHeaders)-1]) ) - // Verify that pre-merge headers don't overflow the TTD - if index, err := verifyTerminalPoWBlock(chain, preHeaders); err != nil { - // Mark all subsequent pow headers with the error. - for i := index; i < len(preHeaders); i++ { - errors[i], done[i] = err, true - } - } // Collect the results for { for ; done[out]; out++ { @@ -205,33 +201,6 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [ return abort, results } -// verifyTerminalPoWBlock verifies that the preHeaders conform to the specification -// wrt. their total difficulty. -// It expects: -// - preHeaders to be at least 1 element -// - the parent of the header element to be stored in the chain correctly -// - the preHeaders to have a set difficulty -// - the last element to be the terminal block -func verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) (int, error) { - td := chain.GetTd(preHeaders[0].ParentHash, preHeaders[0].Number.Uint64()-1) - if td == nil { - return 0, consensus.ErrUnknownAncestor - } - td = new(big.Int).Set(td) - // Check that all blocks before the last one are below the TTD - for i, head := range preHeaders { - if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 { - return i, consensus.ErrInvalidTerminalBlock - } - td.Add(td, head.Difficulty) - } - // Check that the last block is the terminal block - if td.Cmp(chain.Config().TerminalTotalDifficulty) < 0 { - return len(preHeaders) - 1, consensus.ErrInvalidTerminalBlock - } - return 0, nil -} - // VerifyUncles verifies that the given block's uncles conform to the consensus // rules of the Ethereum consensus engine. func (beacon *Beacon) VerifyUncles(chain consensus.ChainReader, block *types.Block) error { @@ -443,11 +412,11 @@ func (beacon *Beacon) SetThreads(threads int) { // IsTTDReached checks if the TotalTerminalDifficulty has been surpassed on the `parentHash` block. // It depends on the parentHash already being stored in the database. // If the parentHash is not stored in the database a UnknownAncestor error is returned. -func IsTTDReached(chain consensus.ChainHeaderReader, parentHash common.Hash, number uint64) (bool, error) { +func IsTTDReached(chain consensus.ChainHeaderReader, parentHash common.Hash, parentNumber uint64) (bool, error) { if chain.Config().TerminalTotalDifficulty == nil { return false, nil } - td := chain.GetTd(parentHash, number) + td := chain.GetTd(parentHash, parentNumber) if td == nil { return false, consensus.ErrUnknownAncestor } diff --git a/consensus/beacon/consensus_test.go b/consensus/beacon/consensus_test.go deleted file mode 100644 index 09c0b27c4256..000000000000 --- a/consensus/beacon/consensus_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package beacon - -import ( - "fmt" - "math/big" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/consensus" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/params" -) - -type mockChain struct { - config *params.ChainConfig - tds map[uint64]*big.Int -} - -func newMockChain() *mockChain { - return &mockChain{ - config: new(params.ChainConfig), - tds: make(map[uint64]*big.Int), - } -} - -func (m *mockChain) Config() *params.ChainConfig { - return m.config -} - -func (m *mockChain) CurrentHeader() *types.Header { panic("not implemented") } - -func (m *mockChain) GetHeader(hash common.Hash, number uint64) *types.Header { - panic("not implemented") -} - -func (m *mockChain) GetHeaderByNumber(number uint64) *types.Header { panic("not implemented") } - -func (m *mockChain) GetHeaderByHash(hash common.Hash) *types.Header { panic("not implemented") } - -func (m *mockChain) GetTd(hash common.Hash, number uint64) *big.Int { - num, ok := m.tds[number] - if ok { - return new(big.Int).Set(num) - } - return nil -} - -func TestVerifyTerminalBlock(t *testing.T) { - chain := newMockChain() - chain.tds[0] = big.NewInt(10) - chain.config.TerminalTotalDifficulty = big.NewInt(50) - - tests := []struct { - preHeaders []*types.Header - ttd *big.Int - err error - index int - }{ - // valid ttd - { - preHeaders: []*types.Header{ - {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(3), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, - }, - ttd: big.NewInt(50), - }, - // last block doesn't reach ttd - { - preHeaders: []*types.Header{ - {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(3), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(4), Difficulty: big.NewInt(9)}, - }, - ttd: big.NewInt(50), - err: consensus.ErrInvalidTerminalBlock, - index: 3, - }, - // two blocks reach ttd - { - preHeaders: []*types.Header{ - {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(3), Difficulty: big.NewInt(20)}, - {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, - }, - ttd: big.NewInt(50), - err: consensus.ErrInvalidTerminalBlock, - index: 3, - }, - // three blocks reach ttd - { - preHeaders: []*types.Header{ - {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(2), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(3), Difficulty: big.NewInt(20)}, - {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, - {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, - }, - ttd: big.NewInt(50), - err: consensus.ErrInvalidTerminalBlock, - index: 3, - }, - // parent reached ttd - { - preHeaders: []*types.Header{ - {Number: big.NewInt(1), Difficulty: big.NewInt(10)}, - }, - ttd: big.NewInt(9), - err: consensus.ErrInvalidTerminalBlock, - index: 0, - }, - // unknown parent - { - preHeaders: []*types.Header{ - {Number: big.NewInt(4), Difficulty: big.NewInt(10)}, - }, - ttd: big.NewInt(9), - err: consensus.ErrUnknownAncestor, - index: 0, - }, - } - - for i, test := range tests { - fmt.Printf("Test: %v\n", i) - chain.config.TerminalTotalDifficulty = test.ttd - index, err := verifyTerminalPoWBlock(chain, test.preHeaders) - if err != test.err { - t.Fatalf("Invalid error encountered, expected %v got %v", test.err, err) - } - if index != test.index { - t.Fatalf("Invalid index, expected %v got %v", test.index, index) - } - } -}