# Remove ctx field from Generator and Certifier (#5180)
## Motivation
Storing a `Context` in a struct is considered an anti-pattern: it hides the context's lifetime inside the object and leads to code that is difficult to reason about.

## Changes
Removed the `ctx` field from the `Generator` and `Certifier` structs. Instead, their `Start()` methods take a context, and `Stop()` works by cancelling a context derived from the one passed to `Start()`.
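For illustration, here is a minimal, self-contained sketch of the resulting lifecycle pattern. The `Worker` type is a hypothetical stand-in for `Generator`/`Certifier`, not code from this repository:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

// Worker is an illustrative stand-in for a component like Generator or
// Certifier: it owns a background goroutine but stores no context.
type Worker struct {
	once sync.Once
	eg   errgroup.Group
	stop func() // cancels the context derived in Start
}

// Start launches the background loop at most once. The goroutine's
// lifetime is bounded by a child context of the caller's ctx.
func (w *Worker) Start(ctx context.Context) {
	w.once.Do(func() {
		ctx, w.stop = context.WithCancel(ctx)
		w.eg.Go(func() error {
			<-ctx.Done() // stand-in for the real work loop
			return fmt.Errorf("context done: %w", ctx.Err())
		})
	})
}

// Stop cancels the derived context and waits for the goroutine to exit.
func (w *Worker) Stop() {
	if w.stop == nil {
		return // never started
	}
	w.stop()
	if err := w.eg.Wait(); err != nil && !errors.Is(err, context.Canceled) {
		fmt.Println("task failure:", err)
	}
}

func main() {
	w := &Worker{}
	w.Start(context.Background())
	w.Stop() // shutdown is plain context cancellation, not a stored field
}
```

Because the cancellable context is created inside `Start()`, `Stop()` needs no context of its own: shutdown is just cancellation plus waiting on the errgroup.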

## Test Plan
Existing tests should pass.
poszu committed Oct 20, 2023
1 parent 21c602e commit d13d18f
Showing 5 changed files with 43 additions and 64 deletions.
45 changes: 17 additions & 28 deletions blocks/certifier.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"sync"
+	"sync/atomic"
 
 	"golang.org/x/sync/errgroup"
 
@@ -50,13 +51,6 @@ func defaultCertConfig() CertConfig {
 // CertifierOpt for configuring Certifier.
 type CertifierOpt func(*Certifier)
 
-// WithCertContext modifies parent context for Certifier.
-func WithCertContext(ctx context.Context) CertifierOpt {
-	return func(c *Certifier) {
-		c.ctx = ctx
-	}
-}
-
 // WithCertConfig defines cfg for Certifier.
 func WithCertConfig(cfg CertConfig) CertifierOpt {
 	return func(c *Certifier) {
@@ -83,8 +77,9 @@ type Certifier struct {
 	cfg    CertConfig
 	once   sync.Once
 	eg     errgroup.Group
-	ctx    context.Context
-	cancel func()
+
+	stop    func()
+	stopped atomic.Bool
 
 	db     *datastore.CachedDB
 	oracle hare.Rolacle
@@ -119,7 +114,6 @@ func NewCertifier(
 	c := &Certifier{
 		logger: log.NewNop(),
 		cfg:    defaultCertConfig(),
-		ctx:    context.Background(),
 		db:     db,
 		oracle: o,
 		nodeID: n,
@@ -137,44 +131,39 @@
 	}
 	c.collector = newCollector(c)
 
-	c.ctx, c.cancel = context.WithCancel(c.ctx)
 	return c
 }
 
 // Start starts the background goroutine for periodic pruning.
-func (c *Certifier) Start() {
+func (c *Certifier) Start(ctx context.Context) {
 	c.once.Do(func() {
+		ctx, c.stop = context.WithCancel(ctx)
 		c.eg.Go(func() error {
-			return c.run()
+			return c.run(ctx)
 		})
 	})
 }
 
 // Stop stops the outstanding goroutines.
 func (c *Certifier) Stop() {
-	c.cancel()
+	c.stopped.Store(true)
+	if c.stop == nil {
+		return // not started
+	}
+	c.stop()
 	err := c.eg.Wait()
 	if err != nil && !errors.Is(err, context.Canceled) {
-		c.logger.With().Error("blockGen task failure", log.Err(err))
-	}
-}
-
-func (c *Certifier) isShuttingDown() bool {
-	select {
-	case <-c.ctx.Done():
-		return true
-	default:
-		return false
+		c.logger.With().Error("certifier task failure", log.Err(err))
 	}
 }
 
-func (c *Certifier) run() error {
+func (c *Certifier) run(ctx context.Context) error {
 	for layer := c.layerClock.CurrentLayer(); ; layer = layer.Add(1) {
 		select {
 		case <-c.layerClock.AwaitLayer(layer):
 			c.prune()
-		case <-c.ctx.Done():
-			return fmt.Errorf("context done: %w", c.ctx.Err())
+		case <-ctx.Done():
+			return fmt.Errorf("context done: %w", ctx.Err())
 		}
 	}
 }
@@ -335,7 +324,7 @@ func (c *Certifier) HandleCertifyMessage(ctx context.Context, peer p2p.Peer, dat
 
 // HandleCertifyMessage is the gossip receiver for certify message.
 func (c *Certifier) handleCertifyMessage(ctx context.Context, _ p2p.Peer, data []byte) error {
-	if c.isShuttingDown() {
+	if c.stopped.Load() {
 		return errors.New("certifier shutting down")
 	}
 
6 changes: 3 additions & 3 deletions blocks/certifier_test.go
@@ -143,9 +143,9 @@ func TestStartStop(t *testing.T) {
 		func(_ types.LayerID) <-chan struct{} {
 			return ch
 		}).AnyTimes()
-	tc.Start()
+	tc.Start(context.Background())
 	ch <- struct{}{}
-	tc.Start() // calling Start() for the second time has no effect
+	tc.Start(context.Background()) // calling Start() for the second time has no effect
 	tc.Stop()
 }

@@ -591,7 +591,7 @@ func Test_OldLayersPruned(t *testing.T) {
 		}
 		return ch
 	}).AnyTimes()
-	tc.Start()
+	tc.Start(context.Background())
 	ch <- struct{}{} // for current
 	ch <- struct{}{} // for current+1
 	<-pruned
26 changes: 9 additions & 17 deletions blocks/generator.go
@@ -28,8 +28,7 @@ type Generator struct {
 	cfg    Config
 	once   sync.Once
 	eg     errgroup.Group
-	ctx    context.Context
-	cancel func()
+	stop   func()
 
 	cdb *datastore.CachedDB
 	msh meshProvider
@@ -60,13 +59,6 @@ func defaultConfig() Config {
 // GeneratorOpt for configuring Generator.
 type GeneratorOpt func(*Generator)
 
-// WithContext modifies default context.
-func WithContext(ctx context.Context) GeneratorOpt {
-	return func(g *Generator) {
-		g.ctx = ctx
-	}
-}
-
 // WithConfig defines cfg for Generator.
 func WithConfig(cfg Config) GeneratorOpt {
 	return func(g *Generator) {
@@ -101,7 +93,6 @@ func NewGenerator(
 	g := &Generator{
 		logger:   log.NewNop(),
 		cfg:      defaultConfig(),
-		ctx:      context.Background(),
 		cdb:      cdb,
 		msh:      m,
 		executor: exec,
@@ -113,34 +104,35 @@
 	for _, opt := range opts {
 		opt(g)
 	}
-	g.ctx, g.cancel = context.WithCancel(g.ctx)
+
 	return g
 }
 
 // Start starts listening to hare output.
-func (g *Generator) Start() {
+func (g *Generator) Start(ctx context.Context) {
 	g.once.Do(func() {
+		ctx, g.stop = context.WithCancel(ctx)
 		g.eg.Go(func() error {
-			return g.run()
+			return g.run(ctx)
 		})
 	})
 }
 
 // Stop stops listening to hare output.
 func (g *Generator) Stop() {
-	g.cancel()
+	g.stop()
 	err := g.eg.Wait()
 	if err != nil && !errors.Is(err, context.Canceled) {
 		g.logger.With().Error("blockGen task failure", log.Err(err))
 	}
 }
 
-func (g *Generator) run() error {
+func (g *Generator) run(ctx context.Context) error {
 	var maxLayer types.LayerID
 	for {
 		select {
-		case <-g.ctx.Done():
-			return fmt.Errorf("context done: %w", g.ctx.Err())
+		case <-ctx.Done():
+			return fmt.Errorf("context done: %w", ctx.Err())
 		case out := <-g.hareCh:
 			g.logger.With().Debug("received hare output",
 				log.Context(out.Ctx),
24 changes: 12 additions & 12 deletions blocks/generator_test.go
@@ -226,8 +226,8 @@ func checkRewards(t *testing.T, atxs []*types.ActivationTx, expWeightPer *big.Ra
 
 func Test_StartStop(t *testing.T) {
 	tg := createTestGenerator(t)
-	tg.Start()
-	tg.Start() // start for the second time is ok.
+	tg.Start(context.Background())
+	tg.Start(context.Background()) // start for the second time is ok.
 	tg.Stop()
 }

@@ -252,7 +252,7 @@ func genData(t *testing.T, cdb *datastore.CachedDB, lid types.LayerID, optimisti
 
 func Test_SerialExecution(t *testing.T) {
 	tg := createTestGenerator(t)
-	tg.Start()
+	tg.Start(context.Background())
 	tg.mockFetch.EXPECT().GetProposals(gomock.Any(), gomock.Any()).AnyTimes()
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
@@ -399,7 +399,7 @@ func Test_run(t *testing.T) {
 			return nil
 		})
 	tg.mockPatrol.EXPECT().CompleteHare(layerID)
-	tg.Start()
+	tg.Start(context.Background())
 	tg.hareCh <- hare.LayerOutput{Ctx: context.Background(), Layer: layerID, Proposals: pids}
 	require.Eventually(t, func() bool { return len(tg.hareCh) == 0 }, time.Second, 100*time.Millisecond)
 	tg.Stop()
@@ -411,7 +411,7 @@ func Test_processHareOutput_EmptyOutput(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 	tg.mockCert.EXPECT().RegisterForCert(gomock.Any(), layerID, types.EmptyBlockID)
 	tg.mockCert.EXPECT().CertifyIfEligible(gomock.Any(), gomock.Any(), layerID, types.EmptyBlockID)
 	tg.mockMesh.EXPECT().ProcessLayerPerHareOutput(gomock.Any(), layerID, types.EmptyBlockID, false)
@@ -425,7 +425,7 @@ func Test_run_FetchFailed(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 	pids := []types.ProposalID{{1}, {2}, {3}}
 	tg.mockFetch.EXPECT().GetProposals(gomock.Any(), pids).DoAndReturn(
 		func(_ context.Context, _ []types.ProposalID) error {
@@ -441,7 +441,7 @@ func Test_run_DiffHasFromConsensus(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 
 	// create multiple proposals with overlapping TXs
 	txIDs := createAndSaveTxs(t, 100, tg.cdb)
@@ -463,7 +463,7 @@ func Test_run_ExecuteFailed(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 	txIDs := createAndSaveTxs(t, 100, tg.cdb)
 	signers, atxes := createATXs(t, tg.cdb, (layerID.GetEpoch() - 1).FirstLayer(), 10)
 	activeSet := types.ToATXIDs(atxes)
@@ -488,7 +488,7 @@ func Test_run_AddBlockFailed(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 	txIDs := createAndSaveTxs(t, 100, tg.cdb)
 	signers, atxes := createATXs(t, tg.cdb, (layerID.GetEpoch() - 1).FirstLayer(), 10)
 	activeSet := types.ToATXIDs(atxes)
@@ -511,7 +511,7 @@ func Test_run_RegisterCertFailureIgnored(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 	txIDs := createAndSaveTxs(t, 100, tg.cdb)
 	signers, atxes := createATXs(t, tg.cdb, (layerID.GetEpoch() - 1).FirstLayer(), 10)
 	activeSet := types.ToATXIDs(atxes)
@@ -537,7 +537,7 @@ func Test_run_CertifyFailureIgnored(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 	txIDs := createAndSaveTxs(t, 100, tg.cdb)
 	signers, atxes := createATXs(t, tg.cdb, (layerID.GetEpoch() - 1).FirstLayer(), 10)
 	activeSet := types.ToATXIDs(atxes)
@@ -563,7 +563,7 @@ func Test_run_ProcessLayerFailed(t *testing.T) {
 	tg := createTestGenerator(t)
 	layerID := types.GetEffectiveGenesis().Add(100)
 	require.NoError(t, layers.SetApplied(tg.cdb, layerID-1, types.EmptyBlockID))
-	tg.Start()
+	tg.Start(context.Background())
 	txIDs := createAndSaveTxs(t, 100, tg.cdb)
 	signers, atxes := createATXs(t, tg.cdb, (layerID.GetEpoch() - 1).FirstLayer(), 10)
 	activeSet := types.ToATXIDs(atxes)
6 changes: 2 additions & 4 deletions node/node.go
@@ -786,7 +786,6 @@ func (app *App) initServices(ctx context.Context) error {
 		app.clock,
 		beaconProtocol,
 		trtl,
-		blocks.WithCertContext(ctx),
 		blocks.WithCertConfig(blocks.CertConfig{
 			CommitteeSize:    app.Config.HARE.N,
 			CertifyThreshold: app.Config.HARE.N/2 + 1,
@@ -836,7 +835,6 @@ func (app *App) initServices(ctx context.Context) error {
 		fetcherWrapped,
 		app.certifier,
 		patrol,
-		blocks.WithContext(ctx),
 		blocks.WithConfig(blocks.Config{
 			BlockGasLimit:      app.Config.BlockGasLimit,
 			OptFilterThreshold: app.Config.OptFilterThreshold,
@@ -1214,8 +1212,8 @@ func (app *App) startServices(ctx context.Context) error {
 	app.syncer.Start()
 	app.beaconProtocol.Start(ctx)
 
-	app.blockGen.Start()
-	app.certifier.Start()
+	app.blockGen.Start(ctx)
+	app.certifier.Start(ctx)
 	if err := app.hare.Start(ctx); err != nil {
 		return fmt.Errorf("cannot start hare: %w", err)
 	}
