diff --git a/consensus/bor/bor.go b/consensus/bor/bor.go index 5b32263762..cffe692843 100644 --- a/consensus/bor/bor.go +++ b/consensus/bor/bor.go @@ -298,6 +298,14 @@ func (c *Bor) VerifyHeader(chain consensus.ChainHeaderReader, header *types.Head return c.verifyHeader(chain, header, nil) } +func (c *Bor) GetSpanner() Spanner { + return c.spanner +} + +func (c *Bor) SetSpanner(spanner Spanner) { + c.spanner = spanner +} + // VerifyHeaders is similar to VerifyHeader, but verifies a batch of headers. The // method returns a quit channel to abort the operations and a results channel to // retrieve the async verifications (the order is that of the input slice). @@ -454,6 +462,43 @@ func (c *Bor) verifyCascadingFields(chain consensus.ChainHeaderReader, header *t return err } + // Verify the validator list match the local contract + if IsSprintStart(number+1, c.config.CalculateSprint(number)) { + parentBlockNumber := number - 1 + + var newValidators []*valset.Validator + + var err error + + for newValidators, err = c.spanner.GetCurrentValidators(context.Background(), chain.GetHeaderByNumber(parentBlockNumber).Hash(), number+1); err != nil && parentBlockNumber > 0; { + parentBlockNumber-- + } + + if err != nil { + return errUnknownValidators + } + + sort.Sort(valset.ValidatorsByAddress(newValidators)) + + headerVals, err := valset.ParseValidators(header.Extra[extraVanity : len(header.Extra)-extraSeal]) + + if err != nil { + return err + } + + sort.Sort(valset.ValidatorsByAddress(headerVals)) + + if len(newValidators) != len(headerVals) { + return errInvalidSpanValidators + } + + for i, val := range newValidators { + if !bytes.Equal(val.HeaderBytes(), headerVals[i].HeaderBytes()) { + return errInvalidSpanValidators + } + } + } + // verify the validator list in the last sprint block if IsSprintStart(number, c.config.CalculateSprint(number)) { parentValidatorBytes := parent.Extra[extraVanity : len(parent.Extra)-extraSeal] diff --git a/consensus/bor/heimdall/span/spanner.go b/consensus/bor/heimdall/span/spanner.go index e0f2d66c6b..5001b0f738 100644 --- a/consensus/bor/heimdall/span/spanner.go +++ b/consensus/bor/heimdall/span/spanner.go @@ -116,7 +116,7 @@ func (c *ChainSpanner) GetCurrentValidators(ctx context.Context, headerHash comm Data: &msgData, }, blockNr, nil) if err != nil { - panic(err) + return nil, err } var ( diff --git a/tests/bor/bor_test.go b/tests/bor/bor_test.go index 2dc20a915e..e6e8188ce0 100644 --- a/tests/bor/bor_test.go +++ b/tests/bor/bor_test.go @@ -392,12 +392,18 @@ func TestInsertingSpanSizeBlocks(t *testing.T) { currentValidators := []*valset.Validator{valset.NewValidator(addr, 10)} + spanner := getMockedSpanner(t, currentValidators) + _bor.SetSpanner(spanner) + // Insert sprintSize # of blocks so that span is fetched at the start of a new sprint for i := uint64(1); i <= spanSize; i++ { block = buildNextBlock(t, _bor, chain, block, nil, init.genesis.Config.Bor, nil, currentValidators) insertNewBlock(t, chain, block) } + spanner = getMockedSpanner(t, currentSpan.ValidatorSet.Validators) + _bor.SetSpanner(spanner) + validators, err := _bor.GetCurrentValidators(context.Background(), block.Hash(), spanSize) // check validator set at the first block of new span if err != nil { t.Fatalf("%s", err) @@ -427,6 +433,9 @@ func TestFetchStateSyncEvents(t *testing.T) { currentValidators := []*valset.Validator{valset.NewValidator(addr, 10)} + spanner := getMockedSpanner(t, currentValidators) + _bor.SetSpanner(spanner) + // Insert sprintSize # of blocks so that span is fetched at the start of a new sprint for i := uint64(1); i < sprintSize; i++ { if IsSpanEnd(i) { @@ -528,6 +537,9 @@ func TestFetchStateSyncEvents_2(t *testing.T) { currentValidators = []*valset.Validator{valset.NewValidator(addr, 10)} } + spanner := getMockedSpanner(t, currentValidators) + _bor.SetSpanner(spanner) + block = buildNextBlock(t, _bor, chain, block, nil, init.genesis.Config.Bor, nil, currentValidators) insertNewBlock(t, chain, block) } @@ -554,6 +566,9 @@ func TestFetchStateSyncEvents_2(t *testing.T) { currentValidators = []*valset.Validator{valset.NewValidator(addr, 10)} } + spanner := getMockedSpanner(t, currentValidators) + _bor.SetSpanner(spanner) + block = buildNextBlock(t, _bor, chain, block, nil, init.genesis.Config.Bor, nil, res.Result.ValidatorSet.Validators) insertNewBlock(t, chain, block) } @@ -580,6 +595,8 @@ func TestOutOfTurnSigning(t *testing.T) { h.EXPECT().Close().AnyTimes() + spanner := getMockedSpanner(t, heimdallSpan.ValidatorSet.Validators) + _bor.SetSpanner(spanner) _bor.SetHeimdallClient(h) db := init.ethereum.ChainDb() @@ -1082,6 +1099,9 @@ func TestJaipurFork(t *testing.T) { res, _ := loadSpanFromFile(t) + spanner := getMockedSpanner(t, res.Result.ValidatorSet.Validators) + _bor.SetSpanner(spanner) + for i := uint64(1); i < sprintSize; i++ { block = buildNextBlock(t, _bor, chain, block, nil, init.genesis.Config.Bor, nil, res.Result.ValidatorSet.Validators) insertNewBlock(t, chain, block) diff --git a/tests/bor/helper.go b/tests/bor/helper.go index e28076a3b1..e1b83858b3 100644 --- a/tests/bor/helper.go +++ b/tests/bor/helper.go @@ -352,6 +352,16 @@ func getMockedHeimdallClient(t *testing.T, heimdallSpan *span.HeimdallSpan) (*mo return h, ctrl } +func getMockedSpanner(t *testing.T, validators []*valset.Validator) *bor.MockSpanner { + t.Helper() + + spanner := bor.NewMockSpanner(gomock.NewController(t)) + spanner.EXPECT().GetCurrentValidators(gomock.Any(), gomock.Any(), gomock.Any()).Return(validators, nil).AnyTimes() + spanner.EXPECT().GetCurrentSpan(gomock.Any(), gomock.Any()).Return(&span.Span{0, 0, 0}, nil).AnyTimes() + spanner.EXPECT().CommitSpan(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + return spanner +} + func generateFakeStateSyncEvents(sample *clerk.EventRecordWithTime, count int) []*clerk.EventRecordWithTime { events := make([]*clerk.EventRecordWithTime, count) event := *sample