Skip to content

Commit

Permalink
backport of commit 4c11d09 (#19262)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom Proctor <[email protected]>
  • Loading branch information
hc-github-team-secure-vault-core and tomhjp authored Feb 20, 2023
1 parent 6ae50fe commit 6e323b6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 14 deletions.
2 changes: 1 addition & 1 deletion http/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
_, _, err := core.CheckToken(ctx, req, false)
if err != nil {
if errors.Is(err, logical.ErrPermissionDenied) {
respondError(w, http.StatusUnauthorized, logical.ErrPermissionDenied)
respondError(w, http.StatusForbidden, logical.ErrPermissionDenied)
return
}
logger.Debug("Error validating token", "error", err)
Expand Down
82 changes: 69 additions & 13 deletions http/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
Expand All @@ -34,7 +35,7 @@ func TestEventsSubscribe(t *testing.T) {

stop := atomic.Bool{}

eventType := "abc"
const eventType = "abc"

// send some events
go func() {
Expand All @@ -43,7 +44,10 @@ func TestEventsSubscribe(t *testing.T) {
if err != nil {
core.Logger().Info("Error generating UUID, exiting sender", "error", err)
}
err = core.Events().SendInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, nil, logical.EventType(eventType), &logical.EventData{
pluginInfo := &logical.EventPluginInfo{
MountPath: "secret",
}
err = core.Events().SendInternal(namespace.RootContext(context.Background()), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{
Id: id,
Metadata: nil,
EntityIds: nil,
Expand All @@ -60,9 +64,7 @@ func TestEventsSubscribe(t *testing.T) {
stop.Store(true)
})

ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancelFunc)

ctx := context.Background()
wsAddr := strings.Replace(addr, "http", "ws", 1)

testCases := []struct {
Expand All @@ -71,14 +73,6 @@ func TestEventsSubscribe(t *testing.T) {

for _, testCase := range testCases {
url := fmt.Sprintf("%s/v1/sys/events/subscribe/%s?json=%v", wsAddr, eventType, testCase.json)
// check that the connection fails if we don't have a token
_, _, err := websocket.Dial(ctx, url, nil)
if err == nil {
t.Error("Expected websocket error but got none")
} else if !strings.HasSuffix(err.Error(), "401") {
t.Errorf("Expected 401 websocket but got %v", err)
}

conn, _, err := websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPHeader: http.Header{"x-vault-token": []string{token}},
})
Expand All @@ -101,3 +95,65 @@ func TestEventsSubscribe(t *testing.T) {
}
}
}

// TestEventsSubscribeAuth tests that unauthenticated and unauthorized subscriptions
// fail correctly.
func TestEventsSubscribeAuth(t *testing.T) {
core := vault.TestCore(t)
ln, addr := TestServer(t, core)
defer ln.Close()

// unseal the core
keys, root := vault.TestCoreInit(t, core)
for _, key := range keys {
_, err := core.Unseal(key)
if err != nil {
t.Fatal(err)
}
}

var nonPrivilegedToken string
// Fetch a valid non privileged token.
{
config := api.DefaultConfig()
config.Address = addr

client, err := api.NewClient(config)
if err != nil {
t.Fatal(err)
}
client.SetToken(root)

secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{Policies: []string{"default"}})
if err != nil {
t.Fatal(err)
}
if secret.Auth.ClientToken == "" {
t.Fatal("Failed to fetch a non privileged token")
}
nonPrivilegedToken = secret.Auth.ClientToken
}

ctx := context.Background()
wsAddr := strings.Replace(addr, "http", "ws", 1)

// Get a 403 with no token.
_, resp, err := websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", nil)
if err == nil {
t.Error("Expected websocket error but got none")
}
if resp == nil || resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected 403 but got %+v", resp)
}

// Get a 403 with a non privileged token.
_, resp, err = websocket.Dial(ctx, wsAddr+"/v1/sys/events/subscribe/abc", &websocket.DialOptions{
HTTPHeader: http.Header{"x-vault-token": []string{nonPrivilegedToken}},
})
if err == nil {
t.Error("Expected websocket error but got none")
}
if resp == nil || resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected 403 but got %+v", resp)
}
}

0 comments on commit 6e323b6

Please sign in to comment.