diff --git a/changelog/24201.txt b/changelog/24201.txt new file mode 100644 index 000000000000..9253e44ab8c0 --- /dev/null +++ b/changelog/24201.txt @@ -0,0 +1,3 @@ +```release-note:change +events: Source URL is now `vault://{vault node}` +``` diff --git a/vault/core.go b/vault/core.go index 73a29f03ccfe..eff0e9cb1a43 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1261,7 +1261,11 @@ func NewCore(conf *CoreConfig) (*Core, error) { eventsLogger := conf.Logger.Named("events") c.allLoggers = append(c.allLoggers, eventsLogger) // start the event system - events, err := eventbus.NewEventBus(eventsLogger) + nodeID, err := c.LoadNodeID() + if err != nil { + return nil, err + } + events, err := eventbus.NewEventBus(nodeID, eventsLogger) if err != nil { return nil, err } diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index 0185e9fbadf4..fd5102ad8f6d 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -35,9 +35,8 @@ const ( ) var ( - ErrNotStarted = errors.New("event broker has not been started") - cloudEventsFormatterFilter *cloudevents.FormatterFilter - subscriptions atomic.Int64 // keeps track of event subscription count in all event buses + ErrNotStarted = errors.New("event broker has not been started") + subscriptions atomic.Int64 // keeps track of event subscription count in all event buses // these metadata fields will have the plugin mount path prepended to them metadataPrependPathFields = []string{ @@ -49,11 +48,13 @@ var ( // EventBus contains the main logic of running an event broker for Vault. // Start() must be called before the EventBus will accept events for sending. type EventBus struct { - logger hclog.Logger - broker *eventlogger.Broker - started atomic.Bool - formatterNodeID eventlogger.NodeID - timeout time.Duration + logger hclog.Logger + broker *eventlogger.Broker + started atomic.Bool + formatterNodeID eventlogger.NodeID + timeout time.Duration + filters *Filters + cloudEventsFormatterFilter *cloudevents.FormatterFilter } type pluginEventBus struct { @@ -72,6 +73,7 @@ type asyncChanNode struct { closeOnce sync.Once cancelFunc context.CancelFunc pipelineID eventlogger.PipelineID + removeFilter func() removePipeline func(ctx context.Context, t eventlogger.EventType, id eventlogger.PipelineID) (bool, error) } @@ -162,21 +164,7 @@ func (bus *pluginEventBus) SendEvent(ctx context.Context, eventType logical.Even return bus.bus.SendEventInternal(ctx, bus.namespace, bus.pluginInfo, eventType, data) } -func init() { - // TODO: maybe this should relate to the Vault core somehow? - sourceUrl, err := url.Parse("https://vaultproject.io/") - if err != nil { - panic(err) - } - cloudEventsFormatterFilter = &cloudevents.FormatterFilter{ - Source: sourceUrl, - Predicate: func(_ context.Context, e interface{}) (bool, error) { - return true, nil - }, - } -} - -func NewEventBus(logger hclog.Logger) (*EventBus, error) { +func NewEventBus(localNodeID string, logger hclog.Logger) (*EventBus, error) { broker, err := eventlogger.NewBroker() if err != nil { return nil, err @@ -192,11 +180,25 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) { logger = hclog.Default().Named("events") } + sourceUrl, err := url.Parse("vault://" + localNodeID) + if err != nil { + return nil, err + } + + cloudEventsFormatterFilter := &cloudevents.FormatterFilter{ + Source: sourceUrl, + Predicate: func(_ context.Context, e interface{}) (bool, error) { + return true, nil + }, + } + return &EventBus{ - logger: logger, - broker: broker, - formatterNodeID: formatterNodeID, - timeout: defaultTimeout, + logger: logger, + broker: broker, + formatterNodeID: formatterNodeID, + timeout: defaultTimeout, + cloudEventsFormatterFilter: cloudEventsFormatterFilter, + filters: NewFilters(localNodeID), }, nil } @@ -215,7 +217,7 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP return nil, nil, err } - err = bus.broker.RegisterNode(bus.formatterNodeID, cloudEventsFormatterFilter) + err = bus.broker.RegisterNode(bus.formatterNodeID, bus.cloudEventsFormatterFilter) if err != nil { return nil, nil, err } @@ -240,7 +242,12 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP } ctx, cancel := context.WithCancel(ctx) - asyncNode := newAsyncNode(ctx, bus.logger, bus.broker) + + bus.filters.addPattern(bus.filters.self, namespacePathPatterns, pattern) + + asyncNode := newAsyncNode(ctx, bus.logger, bus.broker, func() { + bus.filters.removePattern(bus.filters.self, namespacePathPatterns, pattern) + }) err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode) if err != nil { defer cancel() @@ -301,7 +308,7 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin } } - // Filter for correct event type, including wildcards. + // NodeFilter for correct event type, including wildcards. if !glob.Glob(pattern, eventRecv.EventType) { return false, nil } @@ -315,11 +322,12 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin }, nil } -func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.Broker) *asyncChanNode { +func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.Broker, removeFilter func()) *asyncChanNode { return &asyncChanNode{ ctx: ctx, ch: make(chan *eventlogger.Event), logger: logger, + removeFilter: removeFilter, removePipeline: broker.RemovePipelineAndNodes, } } @@ -328,6 +336,7 @@ func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger. func (node *asyncChanNode) Close(ctx context.Context) { node.closeOnce.Do(func() { defer node.cancelFunc() + node.removeFilter() removed, err := node.removePipeline(ctx, eventTypeAll, node.pipelineID) switch { diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index d6e5c9c9cbaf..0255f7fe8678 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -23,7 +23,7 @@ import ( // TestBusBasics tests that basic event sending and subscribing function. func TestBusBasics(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -78,7 +78,7 @@ func TestBusBasics(t *testing.T) { // TestBusIgnoresSendContext tests that the context is ignored when sending to an event, // so that we do not give up too quickly. func TestBusIgnoresSendContext(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -119,7 +119,7 @@ func TestBusIgnoresSendContext(t *testing.T) { // TestSubscribeNonRootNamespace verifies that events for non-root namespaces // aren't filtered out by the bus. func TestSubscribeNonRootNamespace(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -162,7 +162,7 @@ func TestSubscribeNonRootNamespace(t *testing.T) { // TestNamespaceFiltering verifies that events for other namespaces are filtered out by the bus. func TestNamespaceFiltering(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func TestNamespaceFiltering(t *testing.T) { // TestBus2Subscriptions verifies that events of different types are successfully routed to the correct subscribers. func TestBus2Subscriptions(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -293,7 +293,7 @@ func TestBusSubscriptionsCancel(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("cancel=%v", tc.cancel), func(t *testing.T) { subscriptions.Store(0) - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) { // TestBusWildcardSubscriptions tests that a single subscription can receive // multiple event types using * for glob patterns. func TestBusWildcardSubscriptions(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -471,7 +471,7 @@ func TestBusWildcardSubscriptions(t *testing.T) { // TestDataPathIsPrependedWithMount tests that "data_path", if present in the // metadata, is prepended with the plugin's mount. func TestDataPathIsPrependedWithMount(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -591,7 +591,7 @@ func TestDataPathIsPrependedWithMount(t *testing.T) { // TestBexpr tests go-bexpr filters are evaluated on an event. func TestBexpr(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -671,7 +671,7 @@ func TestBexpr(t *testing.T) { // TestPipelineCleanedUp ensures pipelines are properly cleaned up after // subscriptions are closed. func TestPipelineCleanedUp(t *testing.T) { - bus, err := NewEventBus(nil) + bus, err := NewEventBus("", nil) if err != nil { t.Fatal(err) } @@ -683,6 +683,10 @@ func TestPipelineCleanedUp(t *testing.T) { if err != nil { t.Fatal(err) } + // check that the filters are set + if !bus.filters.anyMatch(namespace.RootNamespace, eventType) { + t.Fatal() + } if !bus.broker.IsAnyPipelineRegistered(eventTypeAll) { cancel() t.Fatal() @@ -693,4 +697,9 @@ func TestPipelineCleanedUp(t *testing.T) { if bus.broker.IsAnyPipelineRegistered(eventTypeAll) { t.Fatal() } + + // and that the filters are cleaned up + if bus.filters.anyMatch(namespace.RootNamespace, eventType) { + t.Fatal() + } } diff --git a/vault/eventbus/filter.go b/vault/eventbus/filter.go new file mode 100644 index 000000000000..7d4268aacfe8 --- /dev/null +++ b/vault/eventbus/filter.go @@ -0,0 +1,120 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package eventbus + +import ( + "slices" + "sort" + "sync" + + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/sdk/logical" + "github.com/ryanuber/go-glob" +) + +// Filters keeps track of all the event patterns that each node is interested in. +type Filters struct { + lock sync.RWMutex + self nodeID + filters map[nodeID]*NodeFilter +} + +// nodeID is used to syntactically indicate that the string is a node's name identifier. +type nodeID string + +// pattern is used to represent one or more combinations of patterns +type pattern struct { + eventTypePattern string + namespacePatterns []string +} + +// NodeFilter keeps track of all patterns that a particular node is interested in. +type NodeFilter struct { + patterns []pattern +} + +func (nf *NodeFilter) match(ns *namespace.Namespace, eventType logical.EventType) bool { + if nf == nil { + return false + } + for _, p := range nf.patterns { + if glob.Glob(p.eventTypePattern, string(eventType)) { + for _, nsp := range p.namespacePatterns { + if glob.Glob(nsp, ns.Path) { + return true + } + } + } + } + return false +} + +// NewFilters creates an empty set of filters to keep track of each node's pattern interests. +func NewFilters(self string) *Filters { + return &Filters{ + self: nodeID(self), + filters: map[nodeID]*NodeFilter{}, + } +} + +// addPattern adds a pattern to a node's list. +func (f *Filters) addPattern(node nodeID, namespacePatterns []string, eventTypePattern string) { + f.lock.Lock() + defer f.lock.Unlock() + if _, ok := f.filters[node]; !ok { + f.filters[node] = &NodeFilter{} + } + nsPatterns := slices.Clone(namespacePatterns) + sort.Strings(nsPatterns) + f.filters[node].patterns = append(f.filters[node].patterns, pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns}) +} + +func (f *Filters) addNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) { + f.addPattern(node, []string{ns.Path}, eventTypePattern) +} + +// removePattern removes a pattern from a node's list. +func (f *Filters) removePattern(node nodeID, namespacePatterns []string, eventTypePattern string) { + nsPatterns := slices.Clone(namespacePatterns) + sort.Strings(nsPatterns) + check := pattern{eventTypePattern: eventTypePattern, namespacePatterns: nsPatterns} + f.lock.Lock() + defer f.lock.Unlock() + filters, ok := f.filters[node] + if !ok { + return + } + filters.patterns = slices.DeleteFunc(filters.patterns, func(m pattern) bool { + return m.eventTypePattern == check.eventTypePattern && + slices.Equal(m.namespacePatterns, check.namespacePatterns) + }) +} + +func (f *Filters) removeNsPattern(node nodeID, ns *namespace.Namespace, eventTypePattern string) { + f.removePattern(node, []string{ns.Path}, eventTypePattern) +} + +// anyMatch returns true if any node's pattern list matches the arguments. +func (f *Filters) anyMatch(ns *namespace.Namespace, eventType logical.EventType) bool { + f.lock.RLock() + defer f.lock.RUnlock() + for _, nf := range f.filters { + if nf.match(ns, eventType) { + return true + } + } + return false +} + +// nodeMatch returns true if the given node's pattern list matches the arguments. +func (f *Filters) nodeMatch(node nodeID, ns *namespace.Namespace, eventType logical.EventType) bool { + f.lock.RLock() + defer f.lock.RUnlock() + return f.filters[node].match(ns, eventType) +} + +// localMatch returns true if the local node's pattern list matches the arguments. +func (f *Filters) localMatch(ns *namespace.Namespace, eventType logical.EventType) bool { + return f.nodeMatch(f.self, ns, eventType) +} diff --git a/vault/eventbus/filter_test.go b/vault/eventbus/filter_test.go new file mode 100644 index 000000000000..034fb74f95a6 --- /dev/null +++ b/vault/eventbus/filter_test.go @@ -0,0 +1,31 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package eventbus + +import ( + "testing" + + "github.com/hashicorp/vault/helper/namespace" + "github.com/stretchr/testify/assert" +) + +// TestFilters_AddRemoveMatchLocal checks that basic matching, adding, and removing of patterns all work. +func TestFilters_AddRemoveMatchLocal(t *testing.T) { + f := NewFilters("self") + ns := &namespace.Namespace{ + ID: "ns1", + Path: "ns1", + } + + assert.False(t, f.localMatch(ns, "abc")) + assert.False(t, f.anyMatch(ns, "abc")) + f.addNsPattern("self", ns, "abc") + assert.True(t, f.localMatch(ns, "abc")) + assert.False(t, f.localMatch(ns, "abcd")) + assert.True(t, f.anyMatch(ns, "abc")) + assert.False(t, f.anyMatch(ns, "abcd")) + f.removeNsPattern("self", ns, "abc") + assert.False(t, f.localMatch(ns, "abc")) + assert.False(t, f.anyMatch(ns, "abc")) +}