From 6a87419ab0ee1bb2346de70f90c93d955edbbe37 Mon Sep 17 00:00:00 2001
From: Theron Voran <tvoran@users.noreply.github.com>
Date: Fri, 31 Jan 2025 11:11:44 -0800
Subject: [PATCH] CE changes for VAULT-33018 (#29470)

---
 changelog/29470.txt        |  3 ++
 http/handler.go            |  1 -
 vault/eventbus/bus.go      | 23 ++++++++---
 vault/eventbus/bus_test.go | 80 +++++++++++++++++++++++++++-----------
 4 files changed, 77 insertions(+), 30 deletions(-)
 create mode 100644 changelog/29470.txt

diff --git a/changelog/29470.txt b/changelog/29470.txt
new file mode 100644
index 000000000000..36f28af7ed94
--- /dev/null
+++ b/changelog/29470.txt
@@ -0,0 +1,3 @@
+```release-note:improvement
+events (enterprise): Send events downstream to performance standby nodes in a cluster, removing the need to redirect client event subscriptions to the active node.
+```
diff --git a/http/handler.go b/http/handler.go
index d8f040251ff5..4be68dbd0875 100644
--- a/http/handler.go
+++ b/http/handler.go
@@ -127,7 +127,6 @@ func init() {
 		"!sys/storage/raft/snapshot-auto/config",
 	})
 	websocketPaths.AddPaths(websocketRawPaths)
-	alwaysRedirectPaths.AddPaths(websocketRawPaths)
 }
 
 type HandlerAnchor struct{}
diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go
index b9a96a41ef16..a55f3d4aedbc 100644
--- a/vault/eventbus/bus.go
+++ b/vault/eventbus/bus.go
@@ -118,7 +118,7 @@ func patchMountPath(data *logical.EventData, pluginInfo *logical.EventPluginInfo
 // the namespace and plugin info automatically.
 // The context passed in is currently ignored to ensure that the event is sent if the context is short-lived,
 // such as with an HTTP request context.
-func (bus *EventBus) SendEventInternal(_ context.Context, ns *namespace.Namespace, pluginInfo *logical.EventPluginInfo, eventType logical.EventType, data *logical.EventData) error {
+func (bus *EventBus) SendEventInternal(_ context.Context, ns *namespace.Namespace, pluginInfo *logical.EventPluginInfo, eventType logical.EventType, forwarded bool, data *logical.EventData) error {
 	if ns == nil {
 		return namespace.ErrNoNamespace
 	}
@@ -126,11 +126,17 @@ func (bus *EventBus) SendEventInternal(_ context.Context, ns *namespace.Namespac
 		return ErrNotStarted
 	}
 	eventReceived := &logical.EventReceived{
-		Event:      patchMountPath(data, pluginInfo),
 		Namespace:  ns.Path,
 		EventType:  string(eventType),
 		PluginInfo: pluginInfo,
 	}
+	// If the event has been forwarded downstream, no need to patch the mount
+	// path again
+	if forwarded {
+		eventReceived.Event = data
+	} else {
+		eventReceived.Event = patchMountPath(data, pluginInfo)
+	}
 
 	// We can't easily know when the SendEvent is complete, so we can't call the cancel function.
 	// But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long.
@@ -161,10 +167,10 @@ func (bus *EventBus) WithPlugin(ns *namespace.Namespace, eventPluginInfo *logica
 // This function does *not* wait for all subscribers to acknowledge before returning.
 // The context passed in is currently ignored.
 func (bus *pluginEventBus) SendEvent(ctx context.Context, eventType logical.EventType, data *logical.EventData) error {
-	return bus.bus.SendEventInternal(ctx, bus.namespace, bus.pluginInfo, eventType, data)
+	return bus.bus.SendEventInternal(ctx, bus.namespace, bus.pluginInfo, eventType, false, data)
 }
 
-func NewEventBus(localClusterID string, logger hclog.Logger) (*EventBus, error) {
+func NewEventBus(localNodeID string, logger hclog.Logger) (*EventBus, error) {
 	broker, err := eventlogger.NewBroker()
 	if err != nil {
 		return nil, err
@@ -180,7 +186,7 @@ func NewEventBus(localClusterID string, logger hclog.Logger) (*EventBus, error)
 		logger = hclog.Default().Named("events")
 	}
 
-	sourceUrl, err := url.Parse("vault://" + localClusterID)
+	sourceUrl, err := url.Parse("vault://" + localNodeID)
 	if err != nil {
 		return nil, err
 	}
@@ -198,7 +204,7 @@ func NewEventBus(localClusterID string, logger hclog.Logger) (*EventBus, error)
 		formatterNodeID:            formatterNodeID,
 		timeout:                    defaultTimeout,
 		cloudEventsFormatterFilter: cloudEventsFormatterFilter,
-		filters:                    NewFilters(localClusterID),
+		filters:                    NewFilters(localNodeID),
 	}, nil
 }
 
@@ -336,6 +342,11 @@ func (bus *EventBus) NotifyOnClusterFilterChanges(ctx context.Context, cluster s
 	return bus.filters.watch(ctx, clusterID(cluster))
 }
 
+// NewAllEventsSubscription creates a new subscription to all events.
+func (bus *EventBus) NewAllEventsSubscription(ctx context.Context) (<-chan *eventlogger.Event, context.CancelFunc, error) {
+	return bus.subscribeInternal(ctx, nil, "*", "", nil)
+}
+
 // NewGlobalSubscription creates a new subscription to all events that match the global filter.
 func (bus *EventBus) NewGlobalSubscription(ctx context.Context) (<-chan *eventlogger.Event, context.CancelFunc, error) {
 	g := globalCluster
diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go
index cb4c12e950fc..7072cb09a1d7 100644
--- a/vault/eventbus/bus_test.go
+++ b/vault/eventbus/bus_test.go
@@ -36,14 +36,14 @@ func TestBusBasics(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, false, event)
 	if !errors.Is(err, ErrNotStarted) {
 		t.Errorf("Expected not started error but got: %v", err)
 	}
 
 	bus.Start()
 
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, false, event)
 	if err != nil {
 		t.Errorf("Expected no error sending: %v", err)
 	}
@@ -59,7 +59,7 @@ func TestBusBasics(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, false, event)
 	if err != nil {
 		t.Error(err)
 	}
@@ -100,7 +100,7 @@ func TestBusIgnoresSendContext(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	cancel() // cancel immediately
 
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, false, event)
 	if err != nil {
 		t.Errorf("Expected no error sending: %v", err)
 	}
@@ -144,7 +144,7 @@ func TestSubscribeNonRootNamespace(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = bus.SendEventInternal(ctx, ns, nil, eventType, event)
+	err = bus.SendEventInternal(ctx, ns, nil, eventType, false, event)
 	if err != nil {
 		t.Error(err)
 	}
@@ -190,7 +190,7 @@ func TestNamespaceFiltering(t *testing.T) {
 	err = bus.SendEventInternal(ctx, &namespace.Namespace{
 		ID:   "abc",
 		Path: "/abc",
-	}, nil, eventType, event)
+	}, nil, eventType, false, event)
 	if err != nil {
 		t.Error(err)
 	}
@@ -203,7 +203,7 @@ func TestNamespaceFiltering(t *testing.T) {
 		// okay
 	}
 
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, false, event)
 	if err != nil {
 		t.Error(err)
 	}
@@ -253,11 +253,11 @@ func TestBus2Subscriptions(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType2, event2)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType2, false, event2)
 	if err != nil {
 		t.Error(err)
 	}
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType1, event1)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType1, false, event1)
 	if err != nil {
 		t.Error(err)
 	}
@@ -345,7 +345,7 @@ func TestBusSubscriptionsCancel(t *testing.T) {
 			if err != nil {
 				t.Fatal(err)
 			}
-			err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+			err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, false, event)
 			if err != nil {
 				t.Error(err)
 			}
@@ -357,7 +357,7 @@ func TestBusSubscriptionsCancel(t *testing.T) {
 			if err != nil {
 				t.Fatal(err)
 			}
-			err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, event)
+			err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, eventType, false, event)
 			if err != nil {
 				t.Error(err)
 			}
@@ -427,11 +427,11 @@ func TestBusWildcardSubscriptions(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, barEventType, event2)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, barEventType, false, event2)
 	if err != nil {
 		t.Error(err)
 	}
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, fooEventType, event1)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, fooEventType, false, event1)
 	if err != nil {
 		t.Error(err)
 	}
@@ -504,7 +504,7 @@ func TestDataPathIsPrependedWithMount(t *testing.T) {
 	}
 
 	// no plugin info means nothing should change
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, fooEventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, nil, fooEventType, false, event)
 	if err != nil {
 		t.Error(err)
 	}
@@ -530,7 +530,7 @@ func TestDataPathIsPrependedWithMount(t *testing.T) {
 		PluginVersion: "v1.13.1+builtin",
 		Version:       "2",
 	}
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, &pluginInfo, fooEventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, &pluginInfo, fooEventType, false, event)
 	if err != nil {
 		t.Error(err)
 	}
@@ -571,7 +571,7 @@ func TestDataPathIsPrependedWithMount(t *testing.T) {
 	if err := event.Metadata.UnmarshalJSON(metadataBytes); err != nil {
 		t.Fatal(err)
 	}
-	err = bus.SendEventInternal(ctx, namespace.RootNamespace, &pluginInfo, fooEventType, event)
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, &pluginInfo, fooEventType, false, event)
 	if err != nil {
 		t.Error(err)
 	}
@@ -587,6 +587,40 @@ func TestDataPathIsPrependedWithMount(t *testing.T) {
 	case <-timeout:
 		t.Error("Timeout waiting for event")
 	}
+
+	// Test that a forwarded event does not have anything prepended
+	event, err = logical.NewEvent()
+	if err != nil {
+		t.Fatal(err)
+	}
+	metadata = map[string]string{
+		logical.EventMetadataDataPath: "your/secret/path",
+		"not_touched":                 "xyz",
+	}
+	metadataBytes, err = json.Marshal(metadata)
+	if err != nil {
+		t.Fatal(err)
+	}
+	event.Metadata = &structpb.Struct{}
+	if err := event.Metadata.UnmarshalJSON(metadataBytes); err != nil {
+		t.Fatal(err)
+	}
+	err = bus.SendEventInternal(ctx, namespace.RootNamespace, &pluginInfo, fooEventType, true, event)
+	if err != nil {
+		t.Error(err)
+	}
+
+	timeout = time.After(1 * time.Second)
+	select {
+	case message := <-ch:
+		metadata := message.Payload.(*logical.EventReceived).Event.Metadata.AsMap()
+		assert.Contains(t, metadata, "not_touched")
+		assert.Equal(t, "xyz", metadata["not_touched"])
+		assert.Contains(t, metadata, "data_path")
+		assert.Equal(t, "your/secret/path", metadata["data_path"])
+	case <-timeout:
+		t.Error("Timeout waiting for event")
+	}
 }
 
 // TestBexpr tests go-bexpr filters are evaluated on an event.
@@ -625,7 +659,7 @@ func TestBexpr(t *testing.T) {
 			PluginVersion: "v1.13.1+builtin",
 			Version:       "2",
 		}
-		return bus.SendEventInternal(ctx, namespace.RootNamespace, &pluginInfo, logical.EventType(eventType), event)
+		return bus.SendEventInternal(ctx, namespace.RootNamespace, &pluginInfo, logical.EventType(eventType), false, event)
 	}
 
 	testCases := []struct {
@@ -725,7 +759,7 @@ func TestSubscribeGlobal(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev)
+	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", false, ev)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -769,7 +803,7 @@ func TestSubscribeGlobal_WithApply(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev)
+	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", false, ev)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -805,7 +839,7 @@ func TestSubscribeCluster(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev)
+	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", false, ev)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -847,7 +881,7 @@ func TestSubscribeCluster_WithApply(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev)
+	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", false, ev)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -890,7 +924,7 @@ func TestClearGlobalFilter(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev)
+	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", false, ev)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -931,7 +965,7 @@ func TestClearClusterFilter(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", ev)
+	err = bus.SendEventInternal(nil, namespace.RootNamespace, nil, "abcd", false, ev)
 	if err != nil {
 		t.Fatal(err)
 	}