diff --git a/dot/parachain/overseer/overseer.go b/dot/parachain/overseer/overseer.go index 2db2326022..f56c9565c1 100644 --- a/dot/parachain/overseer/overseer.go +++ b/dot/parachain/overseer/overseer.go @@ -232,10 +232,6 @@ func (o *Overseer) Stop() error { // close the errorChan to unblock any listeners on the errChan close(o.errChan) - for _, sub := range o.subsystems { - close(sub) - } - // wait for subsystems to stop // TODO: determine reasonable timeout duration for production, currently this is just for testing timedOut := waitTimeout(&o.wg, 500*time.Millisecond) diff --git a/dot/parachain/overseer/overseer_test.go b/dot/parachain/overseer/overseer_test.go index 67448b199e..35017abdd8 100644 --- a/dot/parachain/overseer/overseer_test.go +++ b/dot/parachain/overseer/overseer_test.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "math/rand" + "sync" "sync/atomic" "testing" "time" @@ -88,39 +89,20 @@ func TestHandleBlockEvents(t *testing.T) { var finalizedCounter atomic.Int32 var importedCounter atomic.Int32 + var wg sync.WaitGroup + wg.Add(4) // number of subsystems * 2 + + // mocked subsystems go func() { for { select { case msg := <-overseerToSubSystem1: - if msg == nil { - continue - } - - _, ok := msg.(parachaintypes.BlockFinalizedSignal) - if ok { - finalizedCounter.Add(1) - } - - _, ok = msg.(parachaintypes.ActiveLeavesUpdateSignal) - if ok { - importedCounter.Add(1) - } + go incrementCounters(t, msg, &finalizedCounter, &importedCounter) + wg.Done() case msg := <-overseerToSubSystem2: - if msg == nil { - continue - } - - _, ok := msg.(parachaintypes.BlockFinalizedSignal) - if ok { - finalizedCounter.Add(1) - } - - _, ok = msg.(parachaintypes.ActiveLeavesUpdateSignal) - if ok { - importedCounter.Add(1) - } + go incrementCounters(t, msg, &finalizedCounter, &importedCounter) + wg.Done() } - } }() @@ -129,7 +111,7 @@ func TestHandleBlockEvents(t *testing.T) { finalizedNotifierChan <- &types.FinalisationInfo{} importedBlockNotiferChan <- &types.Block{} - time.Sleep(1000 * time.Millisecond) + wg.Wait() err = overseer.Stop() require.NoError(t, err) @@ -137,3 +119,18 @@ func TestHandleBlockEvents(t *testing.T) { require.Equal(t, int32(2), finalizedCounter.Load()) require.Equal(t, int32(2), importedCounter.Load()) } + +func incrementCounters(t *testing.T, msg any, finalizedCounter *atomic.Int32, importedCounter *atomic.Int32) { + t.Helper() + + if msg == nil { + return + } + + switch msg.(type) { + case parachaintypes.BlockFinalizedSignal: + finalizedCounter.Add(1) + case parachaintypes.ActiveLeavesUpdateSignal: + importedCounter.Add(1) + } +}