From a9f78bea46ebb8ac56405aee3b991af9edc219ab Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Mon, 20 Nov 2023 10:55:15 -0800 Subject: [PATCH 1/5] events: Add filters to keep track of local and other subscriptions This adds a very basic implementation of a list of namespace+eventType combinations that each node is interested in by just running the glob operations in for-loops. Some parallelization is possible, but not enabled by default. It only wires up keeping track of what the local event bus is interested in for now (but doesn't use it yet to filter messages). --- vault/core.go | 6 +- vault/eventbus/bus.go | 66 ++++++++++++++---- vault/eventbus/bus_test.go | 29 +++++--- vault/eventbus/filter.go | 124 ++++++++++++++++++++++++++++++++++ vault/eventbus/filter_test.go | 34 ++++++++++ 5 files changed, 233 insertions(+), 26 deletions(-) create mode 100644 vault/eventbus/filter.go create mode 100644 vault/eventbus/filter_test.go diff --git a/vault/core.go b/vault/core.go index 73a29f03ccfe..e407860f1eb6 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1261,7 +1261,11 @@ func NewCore(conf *CoreConfig) (*Core, error) { eventsLogger := conf.Logger.Named("events") c.allLoggers = append(c.allLoggers, eventsLogger) // start the event system - events, err := eventbus.NewEventBus(eventsLogger) + nodeID, err := c.LoadNodeID() + if err != nil { + return nil, err + } + events, err := eventbus.NewEventBus(c.clusterID.Load, nodeID, eventsLogger) if err != nil { return nil, err } diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index 0185e9fbadf4..7b3a8c853618 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -44,8 +44,19 @@ var ( "path", logical.EventMetadataDataPath, } + initCloudEventsFormatterFilterOnce sync.Once ) +func init() { + // Initialize with a blank source URL until an event bus is created. + cloudEventsFormatterFilter = &cloudevents.FormatterFilter{ + Source: &url.URL{}, + Predicate: func(_ context.Context, e interface{}) (bool, error) { + return true, nil + }, + } +} + // EventBus contains the main logic of running an event broker for Vault. // Start() must be called before the EventBus will accept events for sending. type EventBus struct { @@ -54,6 +65,7 @@ type EventBus struct { started atomic.Bool formatterNodeID eventlogger.NodeID timeout time.Duration + filters *Filters } type pluginEventBus struct { @@ -72,6 +84,7 @@ type asyncChanNode struct { closeOnce sync.Once cancelFunc context.CancelFunc pipelineID eventlogger.PipelineID + removeFilter func() removePipeline func(ctx context.Context, t eventlogger.EventType, id eventlogger.PipelineID) (bool, error) } @@ -162,21 +175,36 @@ func (bus *pluginEventBus) SendEvent(ctx context.Context, eventType logical.Even return bus.bus.SendEventInternal(ctx, bus.namespace, bus.pluginInfo, eventType, data) } -func init() { - // TODO: maybe this should relate to the Vault core somehow? - sourceUrl, err := url.Parse("https://vaultproject.io/") - if err != nil { - panic(err) - } - cloudEventsFormatterFilter = &cloudevents.FormatterFilter{ - Source: sourceUrl, - Predicate: func(_ context.Context, e interface{}) (bool, error) { - return true, nil - }, +func setClusterID(clusterIDFunc func() string, localNodeID string) { + // Use the local node ID, in case we aren't running in a cluster. + if cloudEventsFormatterFilter.Source.Scheme == "" { + cloudEventsFormatterFilter.Source, _ = url.Parse("vault://" + localNodeID) } + // The cluster ID is not available until after the cluster is unsealed. + // Poll for the cluster ID with exponential backoff + // TODO: refactor the core.clusterID to support condition variable maybe? + go func() { + clusterID := clusterIDFunc() + backoff := 1 * time.Millisecond + for clusterID == "" { + backoff = backoff * 2 + if backoff > time.Hour { + backoff = time.Hour + } + time.Sleep(backoff) + } + initCloudEventsFormatterFilterOnce.Do(func() { + sourceUrl, err := url.Parse("vault://" + clusterID) + if err != nil { + panic(err) + } + cloudEventsFormatterFilter.Source = sourceUrl + }) + }() } -func NewEventBus(logger hclog.Logger) (*EventBus, error) { +func NewEventBus(clusterIDFunc func() string, localNodeID string, logger hclog.Logger) (*EventBus, error) { + setClusterID(clusterIDFunc, localNodeID) broker, err := eventlogger.NewBroker() if err != nil { return nil, err @@ -197,6 +225,7 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) { broker: broker, formatterNodeID: formatterNodeID, timeout: defaultTimeout, + filters: NewFilters(localNodeID), }, nil } @@ -240,7 +269,12 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP } ctx, cancel := context.WithCancel(ctx) - asyncNode := newAsyncNode(ctx, bus.logger, bus.broker) + + bus.filters.addPattern(bus.filters.self, namespacePathPatterns, pattern) + + asyncNode := newAsyncNode(ctx, bus.logger, bus.broker, func() { + bus.filters.removePattern(bus.filters.self, namespacePathPatterns, pattern) + }) err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode) if err != nil { defer cancel() @@ -301,7 +335,7 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin } } - // Filter for correct event type, including wildcards. + // NodeFilter for correct event type, including wildcards. if !glob.Glob(pattern, eventRecv.EventType) { return false, nil } @@ -315,11 +349,12 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin }, nil } -func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.Broker) *asyncChanNode { +func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.Broker, removeFilter func()) *asyncChanNode { return &asyncChanNode{ ctx: ctx, ch: make(chan *eventlogger.Event), logger: logger, + removeFilter: removeFilter, removePipeline: broker.RemovePipelineAndNodes, } } @@ -328,6 +363,7 @@ func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger. func (node *asyncChanNode) Close(ctx context.Context) { node.closeOnce.Do(func() { defer node.cancelFunc() + node.removeFilter() removed, err := node.removePipeline(ctx, eventTypeAll, node.pipelineID) switch { diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index d6e5c9c9cbaf..31dc78b7e214 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -23,7 +23,7 @@ import ( // TestBusBasics tests that basic event sending and subscribing function. func TestBusBasics(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -78,7 +78,7 @@ func TestBusBasics(t *testing.T) { // TestBusIgnoresSendContext tests that the context is ignored when sending to an event, // so that we do not give up too quickly. func TestBusIgnoresSendContext(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -119,7 +119,7 @@ func TestBusIgnoresSendContext(t *testing.T) { // TestSubscribeNonRootNamespace verifies that events for non-root namespaces // aren't filtered out by the bus. func TestSubscribeNonRootNamespace(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -162,7 +162,7 @@ func TestSubscribeNonRootNamespace(t *testing.T) { // TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus. func TestNamespaceFiltering(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func TestNamespaceFiltering(t *testing.T) { // TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers. func TestBus2Subscriptions(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -293,7 +293,7 @@ func TestBusSubscriptionsCancel(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) { subscriptions.Store(0) - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) { // TestBusWildcardSubscriptions tests that a single subscription can receive // multiple event types using * for glob patterns. func TestBusWildcardSubscriptions(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -471,7 +471,7 @@ func TestBusWildcardSubscriptions(t *testing.T) { // TestDataPathIsPrependedWithMount tests that "data_path", if present in the // metadata, is prepended with the plugin's mount. func TestDataPathIsPrependedWithMount(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -591,7 +591,7 @@ func TestDataPathIsPrependedWithMount(t *testing.T) { // TestBexpr tests go-bexpr filters are evaluated on an event. func TestBexpr(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -671,7 +671,7 @@ func TestBexpr(t *testing.T) { // TestPipelineCleanedUp ensures pipelines are properly cleaned up after // subscriptions are closed. func TestPipelineCleanedUp(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", "", nil) if err != nil { t.Fatal(err) } @@ -683,6 +683,10 @@ func TestPipelineCleanedUp(t *testing.T) { if err != nil { t.Fatal(err) } + // check that the filters are set + if !bus.filters.anyMatch(namespace.RootNamespace.Path, eventType) { + t.Fatal() + } if !bus.broker.IsAnyPipelineRegistered(eventTypeAll) { cancel() t.Fatal() @@ -693,4 +697,9 @@ func TestPipelineCleanedUp(t *testing.T) { if bus.broker.IsAnyPipelineRegistered(eventTypeAll) { t.Fatal() } + + // and that the filters are cleaned up + if bus.filters.anyMatch(namespace.RootNamespace.Path, eventType) { + t.Fatal() + } } diff --git a/vault/eventbus/filter.go b/vault/eventbus/filter.go new file mode 100644 index 000000000000..5bac0b446c65 --- /dev/null +++ b/vault/eventbus/filter.go @@ -0,0 +1,124 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package eventbus + +import ( + "slices" + "sync" + "sync/atomic" + + "github.com/hashicorp/vault/sdk/logical" + "github.com/ryanuber/go-glob" +) + +// Filters keeps track of all the event patterns that each node is interested in. +type Filters struct { + lock sync.RWMutex + parallel bool + self nodeID + filters map[nodeID]*NodeFilter +} + +// nodeID is used to syntactically indicate that the string is a node's name identifier. +type nodeID string + +// pattern is used to represent one or more combinations of patterns +type pattern struct { + eventTypePattern string + namespacePatterns []string +} + +// NodeFilter keeps track of all patterns that a particular node is interested in. +type NodeFilter struct { + patterns []pattern +} + +func (nf *NodeFilter) match(nsPath string, eventType logical.EventType) bool { + if nf == nil { + return false + } + for _, p := range nf.patterns { + if glob.Glob(p.eventTypePattern, string(eventType)) { + for _, nsp := range p.namespacePatterns { + if glob.Glob(nsp, nsPath) { + return true + } + } + } + } + return false +} + +// NewFilters creates an empty set of filters to keep track of each node's pattern interests. +func NewFilters(self string) *Filters { + return &Filters{ + self: nodeID(self), + filters: map[nodeID]*NodeFilter{}, + } +} + +// addPattern adds a pattern to a node's list. +func (f *Filters) addPattern(node nodeID, namespacePatterns []string, eventTypePattern string) { + f.lock.Lock() + defer f.lock.Unlock() + if _, ok := f.filters[node]; !ok { + f.filters[node] = &NodeFilter{} + } + f.filters[node].patterns = append(f.filters[node].patterns, pattern{eventTypePattern: eventTypePattern, namespacePatterns: namespacePatterns}) +} + +// removePattern removes a pattern from a node's list. +func (f *Filters) removePattern(node nodeID, namespacePatterns []string, eventTypePattern string) { + check := pattern{eventTypePattern: eventTypePattern, namespacePatterns: namespacePatterns} + f.lock.Lock() + defer f.lock.Unlock() + filters, ok := f.filters[node] + if !ok { + return + } + filters.patterns = slices.DeleteFunc(filters.patterns, func(m pattern) bool { + return m.eventTypePattern == check.eventTypePattern && + slices.Equal(m.namespacePatterns, check.namespacePatterns) + }) +} + +// anyMatch returns true if any node's pattern list matches the arguments. +func (f *Filters) anyMatch(nsPath string, eventType logical.EventType) bool { + f.lock.RLock() + defer f.lock.RUnlock() + if f.parallel { + wg := sync.WaitGroup{} + anyMatched := atomic.Bool{} + for _, nf := range f.filters { + wg.Add(1) + go func(nf *NodeFilter) { + if nf.match(nsPath, eventType) { + anyMatched.Store(true) + } + wg.Done() + }(nf) + } + wg.Wait() + return anyMatched.Load() + } else { + for _, nf := range f.filters { + if nf.match(nsPath, eventType) { + return true + } + } + return false + } +} + +// nodeMatch returns true if the given node's pattern list matches the arguments. +func (f *Filters) nodeMatch(node nodeID, nsPath string, eventType logical.EventType) bool { + f.lock.RLock() + defer f.lock.RUnlock() + return f.filters[node].match(nsPath, eventType) +} + +// localMatch returns true if the local node's pattern list matches the arguments. +func (f *Filters) localMatch(nsPath string, eventType logical.EventType) bool { + return f.nodeMatch(f.self, nsPath, eventType) +} diff --git a/vault/eventbus/filter_test.go b/vault/eventbus/filter_test.go new file mode 100644 index 000000000000..21d54b5db936 --- /dev/null +++ b/vault/eventbus/filter_test.go @@ -0,0 +1,34 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package eventbus + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilters_AddRemoveMatchLocal(t *testing.T) { + f := NewFilters("self") + + assert.False(t, f.localMatch("ns1", "abc")) + assert.False(t, f.anyMatch("ns1", "abc")) + f.addPattern("self", []string{"ns1"}, "abc") + assert.True(t, f.localMatch("ns1", "abc")) + assert.False(t, f.localMatch("ns1", "abcd")) + assert.True(t, f.anyMatch("ns1", "abc")) + assert.False(t, f.anyMatch("ns1", "abcd")) + f.removePattern("self", []string{"ns1"}, "abc") + assert.False(t, f.localMatch("ns1", "abc")) + assert.False(t, f.anyMatch("ns1", "abc")) +} + +func TestFilters_ParallelAnyMatch(t *testing.T) { + f := NewFilters("self") + f.parallel = true + + f.addPattern("self", []string{"ns1"}, "abc") + assert.True(t, f.anyMatch("ns1", "abc")) + assert.False(t, f.anyMatch("ns1", "abcd")) +} From 14ebbcce0e7faeab7f6ec99e15f76653db016d9f Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Mon, 20 Nov 2023 11:09:33 -0800 Subject: [PATCH 2/5] Add test docs --- vault/eventbus/filter_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vault/eventbus/filter_test.go b/vault/eventbus/filter_test.go index 21d54b5db936..b84cc2a41841 100644 --- a/vault/eventbus/filter_test.go +++ b/vault/eventbus/filter_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) +// TestFilters_AddRemoveMatchLocal checks that basic matching, adding, and removing of patterns all work. func TestFilters_AddRemoveMatchLocal(t *testing.T) { f := NewFilters("self") @@ -24,6 +25,7 @@ func TestFilters_AddRemoveMatchLocal(t *testing.T) { assert.False(t, f.anyMatch("ns1", "abc")) } +// TestFilters_ParallelAnyMatch checks that anyMatch works with parallel set to true. func TestFilters_ParallelAnyMatch(t *testing.T) { f := NewFilters("self") f.parallel = true From e6c7a93dd10f87dd81feb182df638bc9a5b85917 Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Mon, 20 Nov 2023 12:52:39 -0800 Subject: [PATCH 3/5] Don't use cluster ID for cloudevents URL as it introduces race conditions; local node is probably better anyway --- vault/core.go | 2 +- vault/eventbus/bus.go | 84 +++++++++++++------------------------- vault/eventbus/bus_test.go | 20 ++++----- 3 files changed, 40 insertions(+), 66 deletions(-) diff --git a/vault/core.go b/vault/core.go index e407860f1eb6..eff0e9cb1a43 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1265,7 +1265,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { if err != nil { return nil, err } - events, err := eventbus.NewEventBus(c.clusterID.Load, nodeID, eventsLogger) + events, err := eventbus.NewEventBus(nodeID, eventsLogger) if err != nil { return nil, err } diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index 7b3a8c853618..6dd100ff13e7 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -35,9 +35,8 @@ const ( ) var ( - ErrNotStarted = errors.New("event broker has not been started") - cloudEventsFormatterFilter *cloudevents.FormatterFilter - subscriptions atomic.Int64 // keeps track of event subscription count in all event buses + ErrNotStarted = errors.New("event broker has not been started") + subscriptions atomic.Int64 // keeps track of event subscription count in all event buses // these metadata fields will have the plugin mount path prepended to them metadataPrependPathFields = []string{ @@ -47,25 +46,16 @@ var ( initCloudEventsFormatterFilterOnce sync.Once ) -func init() { - // Initialize with a blank source URL until an event bus is created. - cloudEventsFormatterFilter = &cloudevents.FormatterFilter{ - Source: &url.URL{}, - Predicate: func(_ context.Context, e interface{}) (bool, error) { - return true, nil - }, - } -} - // EventBus contains the main logic of running an event broker for Vault. // Start() must be called before the EventBus will accept events for sending. type EventBus struct { - logger hclog.Logger - broker *eventlogger.Broker - started atomic.Bool - formatterNodeID eventlogger.NodeID - timeout time.Duration - filters *Filters + logger hclog.Logger + broker *eventlogger.Broker + started atomic.Bool + formatterNodeID eventlogger.NodeID + timeout time.Duration + filters *Filters + cloudEventsFormatterFilter *cloudevents.FormatterFilter } type pluginEventBus struct { @@ -175,36 +165,7 @@ func (bus *pluginEventBus) SendEvent(ctx context.Context, eventType logical.Even return bus.bus.SendEventInternal(ctx, bus.namespace, bus.pluginInfo, eventType, data) } -func setClusterID(clusterIDFunc func() string, localNodeID string) { - // Use the local node ID, in case we aren't running in a cluster. - if cloudEventsFormatterFilter.Source.Scheme == "" { - cloudEventsFormatterFilter.Source, _ = url.Parse("vault://" + localNodeID) - } - // The cluster ID is not available until after the cluster is unsealed. - // Poll for the cluster ID with exponential backoff - // TODO: refactor the core.clusterID to support condition variable maybe? - go func() { - clusterID := clusterIDFunc() - backoff := 1 * time.Millisecond - for clusterID == "" { - backoff = backoff * 2 - if backoff > time.Hour { - backoff = time.Hour - } - time.Sleep(backoff) - } - initCloudEventsFormatterFilterOnce.Do(func() { - sourceUrl, err := url.Parse("vault://" + clusterID) - if err != nil { - panic(err) - } - cloudEventsFormatterFilter.Source = sourceUrl - }) - }() -} - -func NewEventBus(clusterIDFunc func() string, localNodeID string, logger hclog.Logger) (*EventBus, error) { - setClusterID(clusterIDFunc, localNodeID) +func NewEventBus(localNodeID string, logger hclog.Logger) (*EventBus, error) { broker, err := eventlogger.NewBroker() if err != nil { return nil, err @@ -220,12 +181,25 @@ func NewEventBus(clusterIDFunc func() string, localNodeID string, logger hclog.L logger = hclog.Default().Named("events") } + sourceUrl, err := url.Parse("vault://" + localNodeID) + if err != nil { + return nil, err + } + + cloudEventsFormatterFilter := &cloudevents.FormatterFilter{ + Source: sourceUrl, + Predicate: func(_ context.Context, e interface{}) (bool, error) { + return true, nil + }, + } + return &EventBus{ - logger: logger, - broker: broker, - formatterNodeID: formatterNodeID, - timeout: defaultTimeout, - filters: NewFilters(localNodeID), + logger: logger, + broker: broker, + formatterNodeID: formatterNodeID, + timeout: defaultTimeout, + cloudEventsFormatterFilter: cloudEventsFormatterFilter, + filters: NewFilters(localNodeID), }, nil } @@ -244,7 +218,7 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP return nil, nil, err } - err = bus.broker.RegisterNode(bus.formatterNodeID, cloudEventsFormatterFilter) + err = bus.broker.RegisterNode(bus.formatterNodeID, bus.cloudEventsFormatterFilter) if err != nil { return nil, nil, err } diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index 31dc78b7e214..57dc1c93db60 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -23,7 +23,7 @@ import ( // TestBusBasics tests that basic event sending and subscribing function. func TestBusBasics(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -78,7 +78,7 @@ func TestBusBasics(t *testing.T) { // TestBusIgnoresSendContext tests that the context is ignored when sending to an event, // so that we do not give up too quickly. func TestBusIgnoresSendContext(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -119,7 +119,7 @@ func TestBusIgnoresSendContext(t *testing.T) { // TestSubscribeNonRootNamespace verifies that events for non-root namespaces // aren't filtered out by the bus. func TestSubscribeNonRootNamespace(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -162,7 +162,7 @@ func TestSubscribeNonRootNamespace(t *testing.T) { // TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus. func TestNamespaceFiltering(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func TestNamespaceFiltering(t *testing.T) { // TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers. func TestBus2Subscriptions(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -293,7 +293,7 @@ func TestBusSubscriptionsCancel(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) { subscriptions.Store(0) - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) { // TestBusWildcardSubscriptions tests that a single subscription can receive // multiple event types using * for glob patterns. func TestBusWildcardSubscriptions(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -471,7 +471,7 @@ func TestBusWildcardSubscriptions(t *testing.T) { // TestDataPathIsPrependedWithMount tests that "data_path", if present in the // metadata, is prepended with the plugin's mount. func TestDataPathIsPrependedWithMount(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -591,7 +591,7 @@ func TestDataPathIsPrependedWithMount(t *testing.T) { // TestBexpr tests go-bexpr filters are evaluated on an event. func TestBexpr(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -671,7 +671,7 @@ func TestBexpr(t *testing.T) { // TestPipelineCleanedUp ensures pipelines are properly cleaned up after // subscriptions are closed. func TestPipelineCleanedUp(t *testing.T) { - bus, err := NewEventBus("", "", nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } From ba0c32e5c3f7fe389c6dee727087b62bda9a968f Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Mon, 27 Nov 2023 14:54:37 -0800 Subject: [PATCH 4/5] Update vault/eventbus/bus.go Co-authored-by: Tom Proctor --- vault/eventbus/bus.go | 1 - 1 file changed, 1 deletion(-) diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index 6dd100ff13e7..fd5102ad8f6d 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -43,7 +43,6 @@ var ( "path", logical.EventMetadataDataPath, } - initCloudEventsFormatterFilterOnce sync.Once ) // EventBus contains the main logic of running an event broker for Vault. From d4fa975d5322a7806875b83eb55ed4571b187000 Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Mon, 27 Nov 2023 15:06:27 -0800 Subject: [PATCH 5/5] Address PR feedback --- changelog/24201.txt | 3 ++ vault/eventbus/bus_test.go | 4 +-- vault/eventbus/filter.go | 64 ++++++++++++++++------------------- vault/eventbus/filter_test.go | 35 ++++++++----------- 4 files changed, 50 insertions(+), 56 deletions(-) create mode 100644 changelog/24201.txt diff --git a/changelog/24201.txt b/changelog/24201.txt new file mode 100644 index 000000000000..9253e44ab8c0 --- /dev/null +++ b/changelog/24201.txt @@ -0,0 +1,3 @@ +```release-note:change +events: Source URL is now `vault://{vault node}` +``` diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index 57dc1c93db60..0255f7fe8678 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -684,7 +684,7 @@ func TestPipelineCleanedUp(t *testing.T) { t.Fatal(err) } // check that the filters are set - if !bus.filters.anyMatch(namespace.RootNamespace.Path, eventType) { + if !bus.filters.anyMatch(namespace.RootNamespace, eventType) { t.Fatal() } if !bus.broker.IsAnyPipelineRegistered(eventTypeAll) { @@ -699,7 +699,7 @@ func TestPipelineCleanedUp(t *testing.T) { } // and that the filters are cleaned up - if bus.filters.anyMatch(namespace.RootNamespace.Path, eventType) { + if bus.filters.anyMatch(namespace.RootNamespace, eventType) { t.Fatal() } } diff --git a/vault/eventbus/filter.go b/vault/eventbus/filter.go index 5bac0b446c65..7d4268aacfe8 100644 --- a/vault/eventbus/filter.go +++ b/vault/eventbus/filter.go @@ -5,19 +5,19 @@ package eventbus import ( "slices" + "sort" "sync" - "sync/atomic" + "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" "github.com/ryanuber/go-glob" ) // Filters keeps track of all the event patterns that each node is interested in. type Filters struct { - lock sync.RWMutex - parallel bool - self nodeID - filters map[nodeID]*NodeFilter + lock sync.RWMutex + self nodeID + filters map[nodeID]*NodeFilter } // nodeID is used to syntactically indicate that the string is a node's name identifier. @@ -34,14 +34,14 @@ type NodeFilter struct { patterns []pattern } -func (nf *NodeFilter) match(nsPath string, eventType logical.EventType) bool { +func (nf *NodeFilter) match(ns *namespace.Namespace, eventType logical.EventType) bool { if nf == nil { return false } for _, p := range nf.patterns { if glob.Glob(p.eventTypePattern, string(eventType)) { for _, nsp := range p.namespacePatterns { - if glob.Glob(nsp, nsPath) { + if glob.Glob(nsp, ns.Path) { return true } } @@ -65,12 +65,20 @@ func (f *Filters) addPattern(node nodeID, namespacePatterns []string, eventTypeP if _, ok := f.filters[node]; !ok { f.filters[node] = &NodeFilter{} } - f.filters[node].patterns = append(f.filters[node].patterns, pattern{eventTypePattern: eventTypePattern, namespacePatterns: namespacePatterns}) + nsPatterns := slices.Clone(namespacePatterns) + sort.Strings(nsPatterns) + f.filters[node].patterns = append(f.filters[node].patterns, pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns}) +} + +func (f *Filters) addNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) { + f.addPattern(node, []string{ns.Path}, eventTypePattern) } // removePattern removes a pattern from a node's list. func (f *Filters) removePattern(node nodeID, namespacePatterns []string, eventTypePattern string) { - check := pattern{eventTypePattern: eventTypePattern, namespacePatterns: namespacePatterns} + nsPatterns := slices.Clone(namespacePatterns) + sort.Strings(nsPatterns) + check := pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns} f.lock.Lock() defer f.lock.Unlock() filters, ok := f.filters[node] @@ -83,42 +91,30 @@ func (f *Filters) removePattern(node nodeID, namespacePatterns []string, eventTy }) } +func (f *Filters) removeNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) { + f.removePattern(node, []string{ns.Path}, eventTypePattern) +} + // anyMatch returns true if any node's pattern list matches the arguments. -func (f *Filters) anyMatch(nsPath string, eventType logical.EventType) bool { +func (f *Filters) anyMatch(ns *namespace.Namespace, eventType logical.EventType) bool { f.lock.RLock() defer f.lock.RUnlock() - if f.parallel { - wg := sync.WaitGroup{} - anyMatched := atomic.Bool{} - for _, nf := range f.filters { - wg.Add(1) - go func(nf *NodeFilter) { - if nf.match(nsPath, eventType) { - anyMatched.Store(true) - } - wg.Done() - }(nf) + for _, nf := range f.filters { + if nf.match(ns, eventType) { + return true } - wg.Wait() - return anyMatched.Load() - } else { - for _, nf := range f.filters { - if nf.match(nsPath, eventType) { - return true - } - } - return false } + return false } // nodeMatch returns true if the given node's pattern list matches the arguments. -func (f *Filters) nodeMatch(node nodeID, nsPath string, eventType logical.EventType) bool { +func (f *Filters) nodeMatch(node nodeID, ns *namespace.Namespace, eventType logical.EventType) bool { f.lock.RLock() defer f.lock.RUnlock() - return f.filters[node].match(nsPath, eventType) + return f.filters[node].match(ns, eventType) } // localMatch returns true if the local node's pattern list matches the arguments. -func (f *Filters) localMatch(nsPath string, eventType logical.EventType) bool { - return f.nodeMatch(f.self, nsPath, eventType) +func (f *Filters) localMatch(ns *namespace.Namespace, eventType logical.EventType) bool { + return f.nodeMatch(f.self, ns, eventType) } diff --git a/vault/eventbus/filter_test.go b/vault/eventbus/filter_test.go index b84cc2a41841..034fb74f95a6 100644 --- a/vault/eventbus/filter_test.go +++ b/vault/eventbus/filter_test.go @@ -6,31 +6,26 @@ package eventbus import ( "testing" + "github.com/hashicorp/vault/helper/namespace" "github.com/stretchr/testify/assert" ) // TestFilters_AddRemoveMatchLocal checks that basic matching, adding, and removing of patterns all work. func TestFilters_AddRemoveMatchLocal(t *testing.T) { f := NewFilters("self") + ns := &namespace.Namespace{ + ID: "ns1", + Path: "ns1", + } - assert.False(t, f.localMatch("ns1", "abc")) - assert.False(t, f.anyMatch("ns1", "abc")) - f.addPattern("self", []string{"ns1"}, "abc") - assert.True(t, f.localMatch("ns1", "abc")) - assert.False(t, f.localMatch("ns1", "abcd")) - assert.True(t, f.anyMatch("ns1", "abc")) - assert.False(t, f.anyMatch("ns1", "abcd")) - f.removePattern("self", []string{"ns1"}, "abc") - assert.False(t, f.localMatch("ns1", "abc")) - assert.False(t, f.anyMatch("ns1", "abc")) -} - -// TestFilters_ParallelAnyMatch checks that anyMatch works with parallel set to true. -func TestFilters_ParallelAnyMatch(t *testing.T) { - f := NewFilters("self") - f.parallel = true - - f.addPattern("self", []string{"ns1"}, "abc") - assert.True(t, f.anyMatch("ns1", "abc")) - assert.False(t, f.anyMatch("ns1", "abcd")) + assert.False(t, f.localMatch(ns, "abc")) + assert.False(t, f.anyMatch(ns, "abc")) + f.addNsPattern("self", ns, "abc") + assert.True(t, f.localMatch(ns, "abc")) + assert.False(t, f.localMatch(ns, "abcd")) + assert.True(t, f.anyMatch(ns, "abc")) + assert.False(t, f.anyMatch(ns, "abcd")) + f.removeNsPattern("self", ns, "abc") + assert.False(t, f.localMatch(ns, "abc")) + assert.False(t, f.anyMatch(ns, "abc")) }