Skip to content

Commit

Permalink
xin-168 node stops because of a deadlock on timeout events (ethereum#73)
Browse files Browse the repository at this point in the history
* fix race condition issue

* add test to prove
  • Loading branch information
liam-lai authored Mar 25, 2022
1 parent ee02538 commit a3d5d82
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 73 deletions.
12 changes: 7 additions & 5 deletions common/countdown/countdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ func (t *CountdownTimer) startTimer(i interface{}) {
return
case <-timer.C:
log.Debug("Countdown time reached!")
err := t.OnTimeoutFn(time.Now(), i)
if err != nil {
log.Error("OnTimeoutFn error", err)
}
log.Debug("Reset timer after timeout reached and OnTimeoutFn processed")
go func() {
err := t.OnTimeoutFn(time.Now(), i)
if err != nil {
log.Error("OnTimeoutFn error", "error", err)
}
log.Debug("OnTimeoutFn processed")
}()
timer.Reset(t.timeoutDuration)
case <-t.resetc:
log.Debug("Reset countdown timer")
Expand Down
15 changes: 15 additions & 0 deletions common/countdown/countdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,18 @@ func TestCountdownShouldBeAbleToStop(t *testing.T) {
countdown.StopTimer()
assert.False(t, countdown.isInitilised())
}

// TestCountdownShouldAvoidDeadlock verifies that OnTimeoutFn may call Reset on
// the very countdown that invoked it without blocking the timer.
//
// If the callback never completes, the test fails with an explicit message
// instead of hanging until the global `go test` timeout kills the binary.
func TestCountdownShouldAvoidDeadlock(t *testing.T) {
	const timeout = 5000 * time.Millisecond

	var fakeI interface{}
	// Buffered and written with a non-blocking send so that a callback fired
	// after the test stopped listening can never be left blocked on the channel.
	called := make(chan struct{}, 1)
	countdown := NewCountDown(timeout)
	countdown.OnTimeoutFn = func(time.Time, interface{}) error {
		// A Reset issued from inside the callback must not block.
		countdown.Reset(fakeI)
		select {
		case called <- struct{}{}:
		default:
		}
		return nil
	}

	countdown.Reset(fakeI)

	select {
	case <-called:
		// Only stop on the success path: if the countdown were deadlocked,
		// stopping it could block as well and hide the failure message.
		countdown.StopTimer()
	case <-time.After(2 * timeout):
		t.Fatal("OnTimeoutFn did not complete after calling Reset; countdown appears deadlocked")
	}
}
35 changes: 5 additions & 30 deletions eth/bft/bft_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ import (
"github.com/XinFinOrg/XDPoSChain/consensus/XDPoS/utils"
"github.com/XinFinOrg/XDPoSChain/core"
"github.com/XinFinOrg/XDPoSChain/log"
lru "github.com/hashicorp/golang-lru"
)

const (
messageLimit = 1024
)

// Define Broadcast Group functions
Expand All @@ -24,11 +19,6 @@ type Bfter struct {
quit chan struct{}
consensus ConsensusFns
broadcast BroadcastFns

// Message Cache
knownVotes *lru.Cache
knownSyncInfos *lru.Cache
knownTimeouts *lru.Cache
}

type ConsensusFns struct {
Expand All @@ -49,16 +39,11 @@ type BroadcastFns struct {
}

// New builds a Bfter wired to the given broadcast functions and blockchain
// reader.
//
// It allocates unbuffered quit and broadcastCh channels and, as shown here,
// three LRU caches (known votes, sync infos and timeouts), each capped at
// messageLimit entries. Errors returned by lru.New are discarded. The
// consensus field is not set here and is left at its zero value.
//
// NOTE(review): this listing comes from a rendered diff, so the cache-related
// lines may belong to the removed side of the change — confirm against the
// actual file before relying on them.
func New(broadcasts BroadcastFns, blockChainReader *core.BlockChain) *Bfter {
knownVotes, _ := lru.New(messageLimit)
knownSyncInfos, _ := lru.New(messageLimit)
knownTimeouts, _ := lru.New(messageLimit)

return &Bfter{
quit: make(chan struct{}),
broadcastCh: make(chan interface{}),
broadcast: broadcasts,
knownVotes: knownVotes,
knownSyncInfos: knownSyncInfos,
knownTimeouts: knownTimeouts,
blockChainReader: blockChainReader,
}
}
Expand All @@ -79,10 +64,6 @@ func (b *Bfter) SetConsensusFuns(engine consensus.Engine) {

func (b *Bfter) Vote(vote *utils.Vote) error {
log.Trace("Receive Vote", "hash", vote.Hash().Hex(), "voted block hash", vote.ProposedBlockInfo.Hash.Hex(), "number", vote.ProposedBlockInfo.Number, "round", vote.ProposedBlockInfo.Round)
if exist, _ := b.knownVotes.ContainsOrAdd(vote.Hash(), true); exist {
log.Debug("Discarded vote, known vote", "vote hash", vote.Hash(), "voted block hash", vote.ProposedBlockInfo.Hash.Hex(), "number", vote.ProposedBlockInfo.Number, "round", vote.ProposedBlockInfo.Round)
return nil
}

verified, err := b.consensus.verifyVote(b.blockChainReader, vote)

Expand All @@ -108,11 +89,8 @@ func (b *Bfter) Vote(vote *utils.Vote) error {
return nil
}
func (b *Bfter) Timeout(timeout *utils.Timeout) error {
log.Trace("Receive Timeout", "timeout", timeout)
if exist, _ := b.knownTimeouts.ContainsOrAdd(timeout.Hash(), true); exist {
log.Trace("Discarded Timeout, known Timeout", "Signature", timeout.Signature, "hash", timeout.Hash(), "round", timeout.Round)
return nil
}
log.Debug("Receive Timeout", "timeout", timeout)

verified, err := b.consensus.verifyTimeout(b.blockChainReader, timeout)
if err != nil {
log.Error("Verify BFT Timeout", "timeoutRound", timeout.Round, "timeoutGapNum", timeout.GapNumber, "error", err)
Expand All @@ -135,11 +113,8 @@ func (b *Bfter) Timeout(timeout *utils.Timeout) error {
return nil
}
func (b *Bfter) SyncInfo(syncInfo *utils.SyncInfo) error {
log.Trace("Receive SyncInfo", "syncInfo", syncInfo)
if exist, _ := b.knownSyncInfos.ContainsOrAdd(syncInfo.Hash(), true); exist {
log.Trace("Discarded SyncInfo, known SyncInfo", "hash", syncInfo.Hash())
return nil
}
log.Debug("Receive SyncInfo", "syncInfo", syncInfo)

verified, err := b.consensus.verifySyncInfo(b.blockChainReader, syncInfo)
if err != nil {
log.Error("Verify BFT SyncInfo", "error", err)
Expand Down
34 changes: 0 additions & 34 deletions eth/bft/bft_hander_test.go → eth/bft/bft_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,40 +83,6 @@ func TestSequentialVotes(t *testing.T) {
}
}

// TestDuplicateVotes checks that submitting the same vote twice results in it
// being verified, handled and broadcast exactly once.
func TestDuplicateVotes(t *testing.T) {
	tester := newTester()
	verifyCounter := uint32(0)
	handlerCounter := uint32(0)
	broadcastCounter := uint32(0)
	targetVotes := 1

	tester.bfter.consensus.verifyVote = func(chain consensus.ChainReader, vote *utils.Vote) (bool, error) {
		atomic.AddUint32(&verifyCounter, 1)
		return true, nil
	}

	tester.bfter.consensus.voteHandler = func(chain consensus.ChainReader, vote *utils.Vote) error {
		atomic.AddUint32(&handlerCounter, 1)
		return nil
	}

	tester.bfter.broadcast.Vote = func(*utils.Vote) {
		atomic.AddUint32(&broadcastCounter, 1)
	}

	vote := utils.Vote{ProposedBlockInfo: &utils.BlockInfo{}}

	// send twice
	tester.bfter.Vote(&vote)
	tester.bfter.Vote(&vote)

	// Presumably gives asynchronous processing time to finish — TODO confirm.
	time.Sleep(50 * time.Millisecond)

	// The counters are incremented atomically from the callbacks, so they must
	// be read atomically too; plain reads are a data race under -race.
	verified := atomic.LoadUint32(&verifyCounter)
	handled := atomic.LoadUint32(&handlerCounter)
	broadcasted := atomic.LoadUint32(&broadcastCounter)
	if int(verified) != targetVotes || int(handled) != targetVotes || int(broadcasted) != targetVotes {
		t.Fatalf("count mismatch: have %v on verify, %v on handler, %v on broadcast, want %v", verified, handled, broadcasted, targetVotes)
	}
}

// Tests that a bad vote is not broadcast
func TestNotBoardcastInvalidVote(t *testing.T) {
tester := newTester()
Expand Down
48 changes: 44 additions & 4 deletions eth/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ type ProtocolManager struct {
knownTxs *lru.Cache
knowOrderTxs *lru.Cache
knowLendingTxs *lru.Cache

// V2 messages
knownVotes *lru.Cache
knownSyncInfos *lru.Cache
knownTimeouts *lru.Cache
}

// NewProtocolManagerEx add order pool to protocol
Expand All @@ -127,6 +132,11 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
knownTxs, _ := lru.New(maxKnownTxs)
knowOrderTxs, _ := lru.New(maxKnownOrderTxs)
knowLendingTxs, _ := lru.New(maxKnownLendingTxs)

knownVotes, _ := lru.New(maxKnownVote)
knownSyncInfos, _ := lru.New(maxKnownSyncInfo)
knownTimeouts, _ := lru.New(maxKnownTimeout)

// Create the protocol manager with the base fields
manager := &ProtocolManager{
networkId: networkID,
Expand All @@ -142,6 +152,9 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
knownTxs: knownTxs,
knowOrderTxs: knowOrderTxs,
knowLendingTxs: knowLendingTxs,
knownVotes: knownVotes,
knownSyncInfos: knownSyncInfos,
knownTimeouts: knownTimeouts,
orderpool: nil,
lendingpool: nil,
orderTxSub: nil,
Expand Down Expand Up @@ -834,7 +847,14 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
// because peer has 2 address sender and receive, so use p.id to find the right address
p = pm.peers.Peer(p.id)
p.MarkVote(vote.Hash())
pm.bft.Vote(&vote)

exist, _ := pm.knownVotes.ContainsOrAdd(vote.Hash(), true)
if !exist {
go pm.bft.Vote(&vote)
} else {
log.Debug("Discarded vote, known vote", "vote hash", vote.Hash(), "voted block hash", vote.ProposedBlockInfo.Hash.Hex(), "number", vote.ProposedBlockInfo.Number, "round", vote.ProposedBlockInfo.Round)
}

case msg.Code == TimeoutMsg:
var timeout utils.Timeout
if err := msg.Decode(&timeout); err != nil {
Expand All @@ -845,7 +865,15 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
// because peer has 2 address sender and receive, so use p.id to find the right address
p = pm.peers.Peer(p.id)
p.MarkTimeout(timeout.Hash())
pm.bft.Timeout(&timeout)

exist, _ := pm.knownTimeouts.ContainsOrAdd(timeout.Hash(), true)

if !exist {
go pm.bft.Timeout(&timeout)
} else {
log.Trace("Discarded Timeout, known Timeout", "Signature", timeout.Signature, "hash", timeout.Hash(), "round", timeout.Round)
}

case msg.Code == SyncInfoMsg:
var syncInfo utils.SyncInfo
if err := msg.Decode(&syncInfo); err != nil {
Expand All @@ -855,7 +883,13 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
// because peer has 2 address sender and receive, so use p.id to find the right address
p = pm.peers.Peer(p.id)
p.MarkSyncInfo(syncInfo.Hash())
pm.bft.SyncInfo(&syncInfo)

exist, _ := pm.knownSyncInfos.ContainsOrAdd(syncInfo.Hash(), true)
if !exist {
go pm.bft.SyncInfo(&syncInfo)
} else {
log.Trace("Discarded SyncInfo, known SyncInfo", "hash", syncInfo.Hash())
}

default:
return errResp(ErrInvalidMsgCode, "%v", msg.Code)
Expand Down Expand Up @@ -917,9 +951,11 @@ func (pm *ProtocolManager) BroadcastVote(vote *utils.Vote) {
err := peer.SendVote(vote)
if err != nil {
log.Error("[BroadcastVote] Fail to broadcast vote message", "NumberOfPeers", len(peers), "peerId", peer.id, "vote", vote, "Error", err)
log.Error("[BroadcastVote] Remove Peer", "id", peer.id, "version", peer.version)
pm.removePeer(peer.id)
}
}
log.Info("Propagated Vote", "vote hash", vote.Hash(), "voted block hash", vote.ProposedBlockInfo.Hash.Hex(), "number", vote.ProposedBlockInfo.Number, "round", vote.ProposedBlockInfo.Round, "recipients", len(peers))
log.Trace("Propagated Vote", "vote hash", vote.Hash(), "voted block hash", vote.ProposedBlockInfo.Hash.Hex(), "number", vote.ProposedBlockInfo.Number, "round", vote.ProposedBlockInfo.Round, "recipients", len(peers))
}
}

Expand All @@ -933,6 +969,8 @@ func (pm *ProtocolManager) BroadcastTimeout(timeout *utils.Timeout) {
err := peer.SendTimeout(timeout)
if err != nil {
log.Error("[BroadcastTimeout] Fail to broadcast timeout message", "NumberOfPeers", len(peers), "peerId", peer.id, "timeout", timeout, "Error", err)
log.Error("[BroadcastTimeout] Remove Peer", "id", peer.id, "version", peer.version)
pm.removePeer(peer.id)
}
}
log.Trace("Propagated Timeout", "hash", hash, "recipients", len(peers))
Expand All @@ -949,6 +987,8 @@ func (pm *ProtocolManager) BroadcastSyncInfo(syncInfo *utils.SyncInfo) {
err := peer.SendSyncInfo(syncInfo)
if err != nil {
log.Error("[BroadcastSyncInfo] Fail to broadcast syncInfo message", "NumberOfPeers", len(peers), "peerId", peer.id, "syncInfo", syncInfo, "Error", err)
log.Error("[BroadcastSyncInfo] Remove Peer", "id", peer.id, "version", peer.version)
pm.removePeer(peer.id)
}
}
log.Trace("Propagated SyncInfo", "hash", hash, "recipients", len(peers))
Expand Down

0 comments on commit a3d5d82

Please sign in to comment.