From 77ceb7dde00e328f11d69d90b62004f0e116b451 Mon Sep 17 00:00:00 2001 From: Mike Jarmy Date: Fri, 11 Oct 2019 18:56:07 -0400 Subject: [PATCH] Vault Agent Cache Auto-Auth SSRF Protection (#7627) * implement SSRF protection header * add test for SSRF protection header * cleanup * refactor * implement SSRF header on a per-listener basis * cleanup * cleanup * creat unit test for agent SSRF * improve unit test for agent SSRF * add VaultRequest SSRF header to CLI * fix unit test * cleanup * improve test suite * simplify check for Vault-Request header * add constant for Vault-Request header * improve test suite * change 'config' to 'agentConfig' * Revert "change 'config' to 'agentConfig'" This reverts commit 14ee72d21fff8027966ee3c89dd3ac41d849206f. * do not remove header from request * change header name to X-Vault-Request * simplify http.Handler logic * cleanup * simplify http.Handler logic * use stdlib errors package --- api/client.go | 8 +- command/agent.go | 48 ++++++- command/agent/config/config.go | 3 + command/agent_test.go | 250 +++++++++++++++++++++++++++++++++ sdk/helper/consts/consts.go | 4 + 5 files changed, 304 insertions(+), 9 deletions(-) diff --git a/api/client.go b/api/client.go index 7abc7e006bbb..216e8d22d938 100644 --- a/api/client.go +++ b/api/client.go @@ -427,10 +427,14 @@ func NewClient(c *Config) (*Client, error) { } client := &Client{ - addr: u, - config: c, + addr: u, + config: c, + headers: make(http.Header), } + // Add the VaultRequest SSRF protection header + client.headers[consts.RequestHeaderName] = []string{"true"} + if token := os.Getenv(EnvVaultToken); token != "" { client.token = token } diff --git a/command/agent.go b/command/agent.go index e338ad86ed1a..abb5e8cebd9f 100644 --- a/command/agent.go +++ b/command/agent.go @@ -2,6 +2,7 @@ package command import ( "context" + "errors" "flag" "fmt" "io" @@ -28,13 +29,14 @@ import ( "github.com/hashicorp/vault/command/agent/auth/jwt" "github.com/hashicorp/vault/command/agent/auth/kubernetes" "github.com/hashicorp/vault/command/agent/cache" - "github.com/hashicorp/vault/command/agent/config" + agentConfig "github.com/hashicorp/vault/command/agent/config" "github.com/hashicorp/vault/command/agent/sink" "github.com/hashicorp/vault/command/agent/sink/file" "github.com/hashicorp/vault/command/agent/sink/inmem" gatedwriter "github.com/hashicorp/vault/helper/gated-writer" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/version" "github.com/kr/pretty" "github.com/mitchellh/cli" @@ -192,7 +194,7 @@ func (c *AgentCommand) Run(args []string) int { } // Load the configuration - config, err := config.LoadConfig(c.flagConfigs[0]) + config, err := agentConfig.LoadConfig(c.flagConfigs[0]) if err != nil { c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err)) return 1 @@ -418,11 +420,8 @@ func (c *AgentCommand) Run(args []string) int { }) } - // Create a muxer and add paths relevant for the lease cache layer - mux := http.NewServeMux() - mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx)) - - mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink)) + // Create the request handler + cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink) var listeners []net.Listener for i, lnConfig := range config.Listeners { @@ -434,6 +433,25 @@ func (c *AgentCommand) Run(args []string) int { listeners = append(listeners, ln) + // Parse 'require_request_header' listener config option, and wrap + // the request handler if necessary + muxHandler := cacheHandler + if v, ok := lnConfig.Config[agentConfig.RequireRequestHeader]; ok { + switch v { + case true: + muxHandler = verifyRequestHeader(muxHandler) + case false /* noop */ : + default: + c.UI.Error(fmt.Sprintf("Invalid value for 'require_request_header': %v", v)) + return 1 + } + } + + // Create a muxer and add paths relevant for the lease cache layer + mux := http.NewServeMux() + mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx)) + mux.Handle("/", muxHandler) + scheme := "https://" if tlsConf == nil { scheme = "http://" @@ -536,6 +554,22 @@ func (c *AgentCommand) Run(args []string) int { return 0 } +// verifyRequestHeader wraps an http.Handler inside a Handler that checks for +// the request header that is used for SSRF protection. +func verifyRequestHeader(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" { + logical.RespondError(w, + http.StatusPreconditionFailed, + errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName))) + return + } + + handler.ServeHTTP(w, r) + }) +} + func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) { var isFlagSet bool f.Visit(func(f *flag.Flag) { diff --git a/command/agent/config/config.go b/command/agent/config/config.go index 9f9fafa1a83a..3d77640c12e5 100644 --- a/command/agent/config/config.go +++ b/command/agent/config/config.go @@ -45,6 +45,9 @@ type Listener struct { Config map[string]interface{} } +// RequireRequestHeader is a listener configuration option +const RequireRequestHeader = "require_request_header" + type AutoAuth struct { Method *Method `hcl:"-"` Sinks []*Sink `hcl:"sinks"` diff --git a/command/agent_test.go b/command/agent_test.go index f1fbf67fdcaa..54f7773e677f 100644 --- a/command/agent_test.go +++ b/command/agent_test.go @@ -1,16 +1,23 @@ package command import ( + "encoding/json" "fmt" "io/ioutil" + "net/http" "os" + "reflect" + "sync" "testing" + "time" hclog "github.com/hashicorp/go-hclog" vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt" "github.com/hashicorp/vault/api" + credAppRole "github.com/hashicorp/vault/builtin/credential/approle" "github.com/hashicorp/vault/command/agent" vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" @@ -370,3 +377,246 @@ auto_auth { t.Fatal("sink 1/2 values don't match") } } + +func TestAgent_RequireRequestHeader(t *testing.T) { + + // request issues HTTP requests. + request := func(client *api.Client, req *api.Request, expectedStatusCode int) map[string]interface{} { + resp, err := client.RawRequest(req) + if err != nil { + t.Fatalf("err: %s", err) + } + if resp.StatusCode != expectedStatusCode { + t.Fatalf("expected status code %d, not %d", expectedStatusCode, resp.StatusCode) + } + + bytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("err: %s", err) + } + if len(bytes) == 0 { + return nil + } + + var body map[string]interface{} + err = json.Unmarshal(bytes, &body) + if err != nil { + t.Fatalf("err: %s", err) + } + return body + } + + // makeTempFile creates a temp file and populates it. + makeTempFile := func(name, contents string) string { + f, err := ioutil.TempFile("", name) + if err != nil { + t.Fatal(err) + } + path := f.Name() + f.WriteString(contents) + f.Close() + return path + } + + // newApiClient creates an *api.Client. + newApiClient := func(addr string, includeVaultRequestHeader bool) *api.Client { + conf := api.DefaultConfig() + conf.Address = addr + cli, err := api.NewClient(conf) + if err != nil { + t.Fatalf("err: %s", err) + } + + h := cli.Headers() + val, ok := h[consts.RequestHeaderName] + if !ok || !reflect.DeepEqual(val, []string{"true"}) { + t.Fatalf("invalid %s header", consts.RequestHeaderName) + } + if !includeVaultRequestHeader { + delete(h, consts.RequestHeaderName) + cli.SetHeaders(h) + } + + return cli + } + + //---------------------------------------------------- + // Start the server and agent + //---------------------------------------------------- + + // Start a vault server + logger := logging.NewVaultLogger(hclog.Trace) + cluster := vault.NewTestCluster(t, + &vault.CoreConfig{ + Logger: logger, + CredentialBackends: map[string]logical.Factory{ + "approle": credAppRole.Factory, + }, + }, + &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + vault.TestWaitActive(t, cluster.Cores[0].Core) + serverClient := cluster.Cores[0].Client + + // Enable the approle auth method + req := serverClient.NewRequest("POST", "/v1/sys/auth/approle") + req.BodyBytes = []byte(`{ + "type": "approle" + }`) + request(serverClient, req, 204) + + // Create a named role + req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role") + req.BodyBytes = []byte(`{ + "secret_id_num_uses": "10", + "secret_id_ttl": "1m", + "token_max_ttl": "1m", + "token_num_uses": "10", + "token_ttl": "1m" + }`) + request(serverClient, req, 204) + + // Fetch the RoleID of the named role + req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id") + body := request(serverClient, req, 200) + data := body["data"].(map[string]interface{}) + roleID := data["role_id"].(string) + + // Get a SecretID issued against the named role + req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id") + body = request(serverClient, req, 200) + data = body["data"].(map[string]interface{}) + secretID := data["secret_id"].(string) + + // Write the RoleID and SecretID to temp files + roleIDPath := makeTempFile("role_id.txt", roleID+"\n") + secretIDPath := makeTempFile("secret_id.txt", secretID+"\n") + defer os.Remove(roleIDPath) + defer os.Remove(secretIDPath) + + // Get a temp file path we can use for the sink + sinkPath := makeTempFile("sink.txt", "") + defer os.Remove(sinkPath) + + // Create a config file + config := ` +auto_auth { + method "approle" { + mount_path = "auth/approle" + config = { + role_id_file_path = "%s" + secret_id_file_path = "%s" + } + } + + sink "file" { + config = { + path = "%s" + } + } +} + +cache { + use_auto_auth_token = true +} + +listener "tcp" { + address = "127.0.0.1:8101" + tls_disable = true +} +listener "tcp" { + address = "127.0.0.1:8102" + tls_disable = true + require_request_header = false +} +listener "tcp" { + address = "127.0.0.1:8103" + tls_disable = true + require_request_header = true +} +` + config = fmt.Sprintf(config, roleIDPath, secretIDPath, sinkPath) + configPath := makeTempFile("config.hcl", config) + defer os.Remove(configPath) + + // Start the agent + ui, cmd := testAgentCommand(t, logger) + cmd.client = serverClient + cmd.startedCh = make(chan struct{}) + + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + code := cmd.Run([]string{"-config", configPath}) + if code != 0 { + t.Errorf("non-zero return code when running agent: %d", code) + t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String()) + t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String()) + } + wg.Done() + }() + + select { + case <-cmd.startedCh: + case <-time.After(5 * time.Second): + t.Errorf("timeout") + } + + // defer agent shutdown + defer func() { + cmd.ShutdownCh <- struct{}{} + wg.Wait() + }() + + //---------------------------------------------------- + // Perform the tests + //---------------------------------------------------- + + // Test against a listener configuration that omits + // 'require_request_header', with the header missing from the request. + agentClient := newApiClient("http://127.0.0.1:8101", false) + req = agentClient.NewRequest("GET", "/v1/sys/health") + request(agentClient, req, 200) + + // Test against a listener configuration that sets 'require_request_header' + // to 'false', with the header missing from the request. + agentClient = newApiClient("http://127.0.0.1:8102", false) + req = agentClient.NewRequest("GET", "/v1/sys/health") + request(agentClient, req, 200) + + // Test against a listener configuration that sets 'require_request_header' + // to 'true', with the header missing from the request. + agentClient = newApiClient("http://127.0.0.1:8103", false) + req = agentClient.NewRequest("GET", "/v1/sys/health") + resp, err := agentClient.RawRequest(req) + if err == nil { + t.Fatalf("expected error") + } + if resp.StatusCode != http.StatusPreconditionFailed { + t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode) + } + + // Test against a listener configuration that sets 'require_request_header' + // to 'true', with an invalid header present in the request. + agentClient = newApiClient("http://127.0.0.1:8103", false) + h := agentClient.Headers() + h[consts.RequestHeaderName] = []string{"bogus"} + agentClient.SetHeaders(h) + req = agentClient.NewRequest("GET", "/v1/sys/health") + resp, err = agentClient.RawRequest(req) + if err == nil { + t.Fatalf("expected error") + } + if resp.StatusCode != http.StatusPreconditionFailed { + t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode) + } + + // Test against a listener configuration that sets 'require_request_header' + // to 'true', with the proper header present in the request. + agentClient = newApiClient("http://127.0.0.1:8103", true) + req = agentClient.NewRequest("GET", "/v1/sys/health") + request(agentClient, req, 200) +} diff --git a/sdk/helper/consts/consts.go b/sdk/helper/consts/consts.go index 769a7858369d..92b570cad053 100644 --- a/sdk/helper/consts/consts.go +++ b/sdk/helper/consts/consts.go @@ -12,6 +12,10 @@ const ( // AuthHeaderName is the name of the header containing the token. AuthHeaderName = "X-Vault-Token" + // RequestHeaderName is the name of the header used by the Agent for + // SSRF protection. + RequestHeaderName = "X-Vault-Request" + // PerformanceReplicationALPN is the negotiated protocol used for // performance replication. PerformanceReplicationALPN = "replication_v1"