Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trigger for doing reconciliation based on watch sets #16052

Merged
merged 4 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 97 additions & 6 deletions agent/consul/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/hashicorp/consul/agent/consul/state"
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-hclog"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -39,13 +41,22 @@ type Controller interface {
// Request retry rate limiter. This should only ever be called prior to
// running Run.
WithBackoff(base, max time.Duration) Controller
// WithLogger sets the logger for the controller, it should be called prior to Start
// being invoked.
WithLogger(logger hclog.Logger) Controller
// WithWorkers sets the number of worker goroutines used to process the queue
// this defaults to 1 goroutine.
WithWorkers(i int) Controller
// WithQueueFactory allows a Controller to replace its underlying work queue
// implementation. This is most useful for testing. This should only ever be called
// prior to running Run.
WithQueueFactory(fn func(ctx context.Context, baseBackoff time.Duration, maxBackoff time.Duration) WorkQueue) Controller
// AddTrigger allows for triggering a reconciliation request when a
// triggering function returns, when the passed in context is canceled
// the trigger must return
AddTrigger(request Request, trigger func(ctx context.Context) error)
// RemoveTrigger removes the triggering function associated with the Request object
RemoveTrigger(request Request)
}

var _ Controller = &controller{}
Expand Down Expand Up @@ -78,8 +89,27 @@ type controller struct {
// publisher is the event publisher that should be subscribed to for any updates
publisher state.EventPublisher

// waitOnce ensures we wait until the controller has started
waitOnce sync.Once
// started signals when the controller has started
started chan struct{}

// group is the error group used in our main start up worker routines
group *errgroup.Group
// groupCtx is the context of the error group to use in spinning up our
// worker routines
groupCtx context.Context

// triggers is a map of cancel functions for out-of-band Request triggers
triggers map[Request]func()
// triggerMutex is used for accessing the above map
triggerMutex sync.Mutex

// running ensures that we are only calling Run a single time
running int32

// logger is the logger for the controller
logger hclog.Logger
}

// New returns a new Controller associated with the given state store and reconciler.
Expand All @@ -91,6 +121,9 @@ func New(publisher state.EventPublisher, reconciler Reconciler) Controller {
baseBackoff: 5 * time.Millisecond,
maxBackoff: 1000 * time.Second,
makeQueue: RunWorkQueue,
started: make(chan struct{}),
triggers: make(map[Request]func()),
logger: hclog.NewNullLogger(),
}
}

Expand Down Expand Up @@ -130,6 +163,14 @@ func (c *controller) WithWorkers(i int) Controller {
return c
}

// WithLogger sets the internal logger for the controller.
func (c *controller) WithLogger(logger hclog.Logger) Controller {
c.ensureNotRunning()

c.logger = logger
return c
}

// WithQueueFactory changes the initialization method for the Controller's work
// queue, this is predominantly just used for testing. This should only ever be called
// prior to running Start.
Expand Down Expand Up @@ -157,15 +198,18 @@ func (c *controller) Run(ctx context.Context) error {
panic("Run cannot be called more than once")
}

group, groupCtx := errgroup.WithContext(ctx)
c.group, c.groupCtx = errgroup.WithContext(ctx)

// set up our queue
c.work = c.makeQueue(groupCtx, c.baseBackoff, c.maxBackoff)
c.work = c.makeQueue(c.groupCtx, c.baseBackoff, c.maxBackoff)

// we can now add stuff to the queue from other contexts
close(c.started)

for _, sub := range c.subscriptions {
// store a reference for the closure
sub := sub
group.Go(func() error {
c.group.Go(func() error {
var index uint64

subscription, err := c.publisher.Subscribe(sub.request)
Expand Down Expand Up @@ -201,25 +245,72 @@ func (c *controller) Run(ctx context.Context) error {
}

for i := 0; i < c.workers; i++ {
group.Go(func() error {
c.group.Go(func() error {
for {
request, shutdown := c.work.Get()
if shutdown {
// Stop working
return nil
}
c.reconcileHandler(groupCtx, request)
c.reconcileHandler(c.groupCtx, request)
// Done is called here because it is required to be called
// when we've finished processing each request
c.work.Done(request)
}
})
}

<-groupCtx.Done()
<-c.groupCtx.Done()
return nil
}

// AddTrigger allows for triggering a reconciliation request every time that the
// triggering function returns, when the passed in context is canceled
// the trigger must return
func (c *controller) AddTrigger(request Request, trigger func(ctx context.Context) error) {
c.wait()

ctx, cancel := context.WithCancel(c.groupCtx)

c.triggerMutex.Lock()
oldCancel, ok := c.triggers[request]
if ok {
oldCancel()
}
c.triggers[request] = cancel
c.triggerMutex.Unlock()

c.group.Go(func() error {
if err := trigger(ctx); err != nil {
c.logger.Error("error while running trigger, adding re-reconcilation anyway", "error", err)
}
select {
case <-ctx.Done():
return nil
default:
c.work.Add(request)
return nil
}
})
}

// RemoveTrigger removes the triggering function associated with the Request object
func (c *controller) RemoveTrigger(request Request) {
c.triggerMutex.Lock()
cancel, ok := c.triggers[request]
if ok {
cancel()
delete(c.triggers, request)
}
c.triggerMutex.Unlock()
}

func (c *controller) wait() {
c.waitOnce.Do(func() {
<-c.started
})
}

func (c *controller) processEvent(sub subscription, event stream.Event) error {
switch payload := event.Payload.(type) {
case state.EventPayloadConfigEntry:
Expand Down
146 changes: 146 additions & 0 deletions agent/consul/controller/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -415,3 +416,148 @@ func TestConfigEntrySubscriptions(t *testing.T) {
})
}
}

func TestBasicController_Triggers(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

reconciler := newTestReconciler(true)

publisher := stream.NewEventPublisher(0)
go publisher.Run(ctx)

controller := New(publisher, reconciler)

go func() {
require.NoError(t, controller.Run(ctx))
}()

ensureCalled := func(request chan Request, name string) bool {
select {
case req := <-request:
require.Equal(t, structs.IngressGateway, req.Kind)
require.Equal(t, name, req.Name)
return true
case <-time.After(10 * time.Millisecond):
return false
}
}

request := Request{
Kind: structs.IngressGateway,
Name: "foo-1",
}

triggerOneChan := make(chan struct{}, 3)
triggerOne := func(ctx context.Context) error {
select {
case <-triggerOneChan:
return nil
case <-ctx.Done():
return nil
}
}
controller.AddTrigger(request, triggerOne)
require.False(t, ensureCalled(reconciler.received, "foo-1"))
triggerOneChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.True(t, ensureCalled(reconciler.received, "foo-1"))

// do it again
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

require.False(t, ensureCalled(reconciler.received, "foo-1"))
controller.AddTrigger(request, triggerOne)
triggerOneChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.True(t, ensureCalled(reconciler.received, "foo-1"))

// check with the overwritten trigger
controller.AddTrigger(request, triggerOne)
triggerTwoChan := make(chan struct{}, 2)
triggerTwo := func(ctx context.Context) error {
select {
case <-triggerTwoChan:
return nil
case <-ctx.Done():
return nil
}
}
controller.AddTrigger(request, triggerTwo)
triggerOneChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.False(t, ensureCalled(reconciler.received, "foo-1"))
triggerTwoChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.True(t, ensureCalled(reconciler.received, "foo-1"))

// remove the trigger and make sure we're not called again
controller.RemoveTrigger(request)
triggerTwoChan <- struct{}{}
reconciler.stepFor(10 * time.Millisecond)
require.False(t, ensureCalled(reconciler.received, "foo-1"))
}

func TestDiscoveryChainController(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

reconciler := newTestReconciler(false)

publisher := stream.NewEventPublisher(1 * time.Millisecond)
go publisher.Run(ctx)

// get the store through the FSM since the publisher handlers get registered through it
store := fsm.NewFromDeps(fsm.Deps{
Logger: hclog.New(nil),
NewStateStore: func() *state.Store {
return state.NewStateStoreWithEventPublisher(nil, publisher)
},
Publisher: publisher,
}).State()

controller := New(publisher, reconciler)
go controller.Subscribe(&stream.SubscribeRequest{
Topic: state.EventTopicIngressGateway,
Subject: stream.SubjectWildcard,
}).WithWorkers(10).Run(ctx)

request := Request{
Kind: structs.IngressGateway,
Name: "foo-1",
}

ensureCalled := func(request chan Request, name string) bool {
select {
case req := <-request:
require.Equal(t, structs.IngressGateway, req.Kind)
require.Equal(t, name, req.Name)
return true
case <-time.After(10 * time.Millisecond):
return false
}
}

require.NoError(t, store.EnsureConfigEntry(1, &structs.IngressGatewayConfigEntry{
Kind: structs.IngressGateway,
Name: "foo-1",
}))
require.True(t, ensureCalled(reconciler.received, "foo-1"))

// create the trigger and something that changes in its upstream discovery chain and ensure that we've
// fired the reconciler
ws := memdb.NewWatchSet()
ws.Add(store.AbandonCh())
_, _, err := store.ReadDiscoveryChainConfigEntries(ws, "foo-2", nil)
require.NoError(t, err)
controller.AddTrigger(request, ws.WatchCtx)

require.False(t, ensureCalled(reconciler.received, "foo-1"))
require.NoError(t, store.EnsureConfigEntry(1, &structs.ServiceResolverConfigEntry{
Kind: structs.ServiceResolver,
Name: "foo-2",
}))
require.True(t, ensureCalled(reconciler.received, "foo-1"))
}
7 changes: 7 additions & 0 deletions agent/consul/controller/reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package controller
import (
"context"
"sync"
"time"
)

type testReconciler struct {
Expand Down Expand Up @@ -43,6 +44,12 @@ func (r *testReconciler) setResponse(err error) {
func (r *testReconciler) step() {
r.stepChan <- struct{}{}
}
func (r *testReconciler) stepFor(duration time.Duration) {
select {
case r.stepChan <- struct{}{}:
case <-time.After(duration):
}
}

func (r *testReconciler) stop() {
close(r.stopChan)
Expand Down