From 4c11d090cdfdd2aa6498c099ca0cd88f5e6af488 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Mon, 20 Feb 2023 16:24:27 +0000 Subject: [PATCH] Events API uses consistent error codes (#19246) --- http/events.go | 2 +- http/events_test.go | 82 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 14 deletions(-) diff --git a/http/events.go b/http/events.go index f516b3ccbc33..e5eaa78be972 100644 --- a/http/events.go +++ b/http/events.go @@ -84,7 +84,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler _, _, err := core.CheckToken(ctx, req, false) if err != nil { if errors.Is(err, logical.ErrPermissionDenied) { - respondError(w, http.StatusUnauthorized, logical.ErrPermissionDenied) + respondError(w, http.StatusForbidden, logical.ErrPermissionDenied) return } logger.Debug("Error validating token", "error", err) diff --git a/http/events_test.go b/http/events_test.go index 3fe5d68e96c1..4cba7d1bd1c3 100644 --- a/http/events_test.go +++ b/http/events_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" @@ -34,7 +35,7 @@ func TestEventsSubscribe(t *testing.T) { stop := atomic.Bool{} - eventType := "abc" + const eventType = "abc" // send some events go func() { @@ -43,7 +44,10 @@ func TestEventsSubscribe(t *testing.T) { if err != nil { core.Logger().Info("Error generating UUID, exiting sender", "error", err) } - err = core.Events().SendInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, nil, logical.EventType(eventType), &logical.EventData{ + pluginInfo := &logical.EventPluginInfo{ + MountPath: "secret", + } + err = core.Events().SendInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{ Id: id, Metadata: nil, EntityIds: nil, @@ -60,9 +64,7 @@ func TestEventsSubscribe(t *testing.T) { stop.Store(true) }) - ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second) - t.Cleanup(cancelFunc) - + ctx := context.Background() wsAddr := strings.Replace(addr, "http", "ws", 1) testCases := []struct { @@ -71,14 +73,6 @@ func TestEventsSubscribe(t *testing.T) { for _, testCase := range testCases { url := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?json=%v", wsAddr, eventType, testCase.json) - // check that the connection fails if we don't have a token - _, _, err := websocket.Dial(ctx, url, nil) - if err == nil { - t.Error("Expected websocket error but got none") - } else if !strings.HasSuffix(err.Error(), "401") { - t.Errorf("Expected 401 websocket but got %v", err) - } - conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{ HTTPHeader: http.Header{"x-vault-token": []string{token}}, }) @@ -101,3 +95,65 @@ func TestEventsSubscribe(t *testing.T) { } } } + +// TestEventsSubscribeAuth tests that unauthenticated and unauthorized subscriptions +// fail correctly. +func TestEventsSubscribeAuth(t *testing.T) { + core := vault.TestCore(t) + ln, addr := TestServer(t, core) + defer ln.Close() + + // unseal the core + keys, root := vault.TestCoreInit(t, core) + for _, key := range keys { + _, err := core.Unseal(key) + if err != nil { + t.Fatal(err) + } + } + + var nonPrivilegedToken string + // Fetch a valid non privileged token. + { + config := api.DefaultConfig() + config.Address = addr + + client, err := api.NewClient(config) + if err != nil { + t.Fatal(err) + } + client.SetToken(root) + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{Policies: []string{"default"}}) + if err != nil { + t.Fatal(err) + } + if secret.Auth.ClientToken == "" { + t.Fatal("Failed to fetch a non privileged token") + } + nonPrivilegedToken = secret.Auth.ClientToken + } + + ctx := context.Background() + wsAddr := strings.Replace(addr, "http", "ws", 1) + + // Get a 403 with no token. + _, resp, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", nil) + if err == nil { + t.Error("Expected websocket error but got none") + } + if resp == nil || resp.StatusCode != http.StatusForbidden { + t.Errorf("Expected 403 but got %+v", resp) + } + + // Get a 403 with a non privileged token. + _, resp, err = websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", &websocket.DialOptions{ + HTTPHeader: http.Header{"x-vault-token": []string{nonPrivilegedToken}}, + }) + if err == nil { + t.Error("Expected websocket error but got none") + } + if resp == nil || resp.StatusCode != http.StatusForbidden { + t.Errorf("Expected 403 but got %+v", resp) + } +}