diff --git a/.gitignore b/.gitignore index 62db80cd9b4d..101ddbbf2631 100644 --- a/.gitignore +++ b/.gitignore @@ -48,7 +48,9 @@ Vagrantfile # Configs *.hcl !command/agent/config/test-fixtures/config.hcl +!command/agent/config/test-fixtures/config-cache.hcl !command/agent/config/test-fixtures/config-embedded-type.hcl +!command/agent/config/test-fixtures/config-cache-embedded-type.hcl .DS_Store .idea diff --git a/api/client.go b/api/client.go index 80ccd7d50290..cfb49b8fda3a 100644 --- a/api/client.go +++ b/api/client.go @@ -25,6 +25,7 @@ import ( "golang.org/x/time/rate" ) +const EnvVaultAgentAddress = "VAULT_AGENT_ADDR" const EnvVaultAddress = "VAULT_ADDR" const EnvVaultCACert = "VAULT_CACERT" const EnvVaultCAPath = "VAULT_CAPATH" @@ -237,6 +238,10 @@ func (c *Config) ReadEnvironment() error { if v := os.Getenv(EnvVaultAddress); v != "" { envAddress = v } + // Agent's address will take precedence over Vault's address + if v := os.Getenv(EnvVaultAgentAddress); v != "" { + envAddress = v + } if v := os.Getenv(EnvVaultMaxRetries); v != "" { maxRetries, err := strconv.ParseUint(v, 10, 32) if err != nil { @@ -366,6 +371,21 @@ func NewClient(c *Config) (*Client, error) { c.modifyLock.Lock() defer c.modifyLock.Unlock() + // If address begins with a `/`, treat it as a socket file path and set + // the HttpClient's transport to the corresponding socket dialer. + if strings.HasPrefix(c.Address, "/") { + socketFilePath := c.Address + c.HttpClient = &http.Client{ + Transport: &http.Transport{ + DialContext: func(context.Context, string, string) (net.Conn, error) { + return net.Dial("unix", socketFilePath) + }, + }, + } + // Set the unix address for URL parsing below + c.Address = "http://unix" + } + u, err := url.Parse(c.Address) if err != nil { return nil, err @@ -707,7 +727,7 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon redirectCount := 0 START: - req, err := r.toRetryableHTTP() + req, err := r.ToRetryableHTTP() if err != nil { return nil, err } diff --git a/api/request.go b/api/request.go index 4efa2aa84177..41d45720fea7 100644 --- a/api/request.go +++ b/api/request.go @@ -62,7 +62,7 @@ func (r *Request) ResetJSONBody() error { // DEPRECATED: ToHTTP turns this request into a valid *http.Request for use // with the net/http package. 
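An orientation note on the api/client.go hunk above: between the new `VAULT_AGENT_ADDR` precedence and the unix-socket handling in `NewClient`, an SDK consumer can point at the agent without code changes. A minimal sketch, with a hypothetical socket path (not part of this diff):

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/hashicorp/vault/api"
)

func main() {
	// VAULT_AGENT_ADDR wins over VAULT_ADDR in ReadEnvironment (see the hunk above).
	// The socket path here is a hypothetical agent listener.
	os.Setenv(api.EnvVaultAgentAddress, "/tmp/vault-agent.sock")

	cfg := api.DefaultConfig()
	// Called explicitly here for clarity, in case DefaultConfig hasn't already.
	if err := cfg.ReadEnvironment(); err != nil {
		log.Fatal(err)
	}

	// An address beginning with "/" is treated as a unix socket path: NewClient
	// installs a unix-dialing transport and rewrites the address to "http://unix".
	client, err := api.NewClient(cfg)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(client.Address()) // expected: http://unix
}
```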
func (r *Request) ToHTTP() (*http.Request, error) { - req, err := r.toRetryableHTTP() + req, err := r.ToRetryableHTTP() if err != nil { return nil, err } @@ -85,7 +85,7 @@ func (r *Request) ToHTTP() (*http.Request, error) { return req.Request, nil } -func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) { +func (r *Request) ToRetryableHTTP() (*retryablehttp.Request, error) { // Encode the query parameters r.URL.RawQuery = r.Params.Encode() diff --git a/api/secret.go b/api/secret.go index e25962604b4e..c8a0ba3d9d2c 100644 --- a/api/secret.go +++ b/api/secret.go @@ -292,6 +292,7 @@ type SecretAuth struct { TokenPolicies []string `json:"token_policies"` IdentityPolicies []string `json:"identity_policies"` Metadata map[string]string `json:"metadata"` + Orphan bool `json:"orphan"` LeaseDuration int `json:"lease_duration"` Renewable bool `json:"renewable"` diff --git a/command/agent.go b/command/agent.go index 92c93c70c2e1..0db43be541ee 100644 --- a/command/agent.go +++ b/command/agent.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "io" + "net" + "os" "sort" "strings" @@ -23,6 +25,7 @@ import ( "github.com/hashicorp/vault/command/agent/auth/gcp" "github.com/hashicorp/vault/command/agent/auth/jwt" "github.com/hashicorp/vault/command/agent/auth/kubernetes" + "github.com/hashicorp/vault/command/agent/cache" "github.com/hashicorp/vault/command/agent/config" "github.com/hashicorp/vault/command/agent/sink" "github.com/hashicorp/vault/command/agent/sink/file" @@ -332,10 +335,40 @@ func (c *AgentCommand) Run(args []string) int { EnableReauthOnNewCredentials: config.AutoAuth.EnableReauthOnNewCredentials, }) - // Start things running + // Start auto-auth and sink servers go ah.Run(ctx, method) go ss.Run(ctx, ah.OutputCh, sinks) + // Parse agent listener configurations + var listeners []net.Listener + if len(config.Cache.Listeners) != 0 { + listeners, err = cache.ServerListeners(config.Cache.Listeners, c.logWriter, c.UI) + if err != nil { + c.UI.Error(fmt.Sprintf("Error running listeners: %v", err)) + return 1 + } + } + + // Start listening to requests + err = cache.Run(ctx, &cache.Config{ + Token: c.client.Token(), + UseAutoAuthToken: config.Cache.UseAutoAuthToken, + Listeners: listeners, + Logger: c.logger.Named("cache"), + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error starting the cache listeners: %v", err)) + return 1 + } + + // Ensure that listeners are closed at all the exits + listenerCloseFunc := func() { + for _, ln := range listeners { + ln.Close() + } + } + defer c.cleanupGuard.Do(listenerCloseFunc) + // Release the log gate. 
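Before the agent releases its log gate below, the cache wiring just added to `Run` is easier to see in isolation. A condensed sketch of standing the caching proxy up outside the agent, mirroring what the test helper later in this diff does; the listener address and token are placeholders:

```go
package main

import (
	"context"
	"log"
	"net"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/command/agent/cache"
)

func main() {
	// The agent builds listeners from its config stanzas via cache.ServerListeners;
	// a bare TCP listener works just as well for experimentation.
	ln, err := net.Listen("tcp", "127.0.0.1:8300")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()

	// cache.Run wires APIProxy -> LeaseCache -> HTTP mux and returns once the
	// server goroutines are started; their lifetime is tied to ctx.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	err = cache.Run(ctx, &cache.Config{
		Token:            "placeholder-token", // the agent passes its auto-auth token here
		UseAutoAuthToken: true,
		Listeners:        []net.Listener{ln},
		Logger:           hclog.Default().Named("cache"),
	})
	if err != nil {
		log.Fatal(err)
	}

	select {} // block forever; the agent instead waits on its own shutdown machinery
}
```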
c.logGate.Flush() diff --git a/command/agent/auth/auth.go b/command/agent/auth/auth.go index 73ccf6ea8e57..8dbed70e6490 100644 --- a/command/agent/auth/auth.go +++ b/command/agent/auth/auth.go @@ -7,6 +7,7 @@ import ( hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/contextutil" "github.com/hashicorp/vault/helper/jsonutil" ) @@ -59,13 +60,6 @@ func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler { return ah } -func backoffOrQuit(ctx context.Context, backoff time.Duration) { - select { - case <-time.After(backoff): - case <-ctx.Done(): - } -} - func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { if am == nil { panic("nil auth method") @@ -116,7 +110,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { path, data, err := am.Authenticate(ctx, ah.client) if err != nil { ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } @@ -125,7 +119,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { wrapClient, err := ah.client.Clone() if err != nil { ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } wrapClient.SetWrappingLookupFunc(func(string, string) string { @@ -138,7 +132,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { // Check errors/sanity if err != nil { ah.logger.Error("error authenticating", "error", err, "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } @@ -146,18 +140,18 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { case ah.wrapTTL > 0: if secret.WrapInfo == nil { ah.logger.Error("authentication returned nil wrap info", "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } if secret.WrapInfo.Token == "" { ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } wrappedResp, err := jsonutil.EncodeJSON(secret.WrapInfo) if err != nil { ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing") @@ -178,12 +172,12 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { default: if secret == nil || secret.Auth == nil { ah.logger.Error("authentication returned nil auth info", "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } if secret.Auth.ClientToken == "" { ah.logger.Error("authentication returned empty client token", "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } ah.logger.Info("authentication successful, sending token to sinks") @@ -201,7 +195,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) { }) if err != nil { ah.logger.Error("error creating renewer, backing off and retrying", "error", err, "backoff", backoff.Seconds()) - backoffOrQuit(ctx, backoff) + contextutil.BackoffOrQuit(ctx, backoff) continue } diff --git a/command/agent/cache/api_proxy.go 
b/command/agent/cache/api_proxy.go new file mode 100644 index 000000000000..044ebf014ff6 --- /dev/null +++ b/command/agent/cache/api_proxy.go @@ -0,0 +1,58 @@
+package cache
+
+import (
+	"bytes"
+	"context"
+	"io/ioutil"
+
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/api"
+)
+
+// APIProxy is an implementation of the Proxier interface that is used to
+// forward the request to Vault and get the response.
+type APIProxy struct {
+	logger hclog.Logger
+}
+
+type APIProxyConfig struct {
+	Logger hclog.Logger
+}
+
+func NewAPIProxy(config *APIProxyConfig) Proxier {
+	return &APIProxy{
+		logger: config.Logger,
+	}
+}
+
+func (ap *APIProxy) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
+	client, err := api.NewClient(api.DefaultConfig())
+	if err != nil {
+		return nil, err
+	}
+	client.SetToken(req.Token)
+	client.SetHeaders(req.Request.Header)
+
+	fwReq := client.NewRequest(req.Request.Method, req.Request.URL.Path)
+	fwReq.BodyBytes = req.RequestBody
+
+	// Make the request to Vault and get the response
+	ap.logger.Info("forwarding request", "path", req.Request.URL.Path)
+	resp, err := client.RawRequestWithContext(ctx, fwReq)
+	if err != nil {
+		return nil, err
+	}
+
+	// Parse and reset response body
+	respBody, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		ap.logger.Error("failed to read response body", "error", err)
+		return nil, err
+	}
+	resp.Body = ioutil.NopCloser(bytes.NewBuffer(respBody))
+
+	return &SendResponse{
+		Response:     resp,
+		ResponseBody: respBody,
+	}, nil
+}
diff --git a/command/agent/cache/api_proxy_test.go b/command/agent/cache/api_proxy_test.go
new file mode 100644
index 000000000000..94383a723bdc
--- /dev/null
+++ b/command/agent/cache/api_proxy_test.go
@@ -0,0 +1,43 @@
+package cache
+
+import (
+	"testing"
+
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/api"
+	"github.com/hashicorp/vault/helper/jsonutil"
+	"github.com/hashicorp/vault/helper/logging"
+	"github.com/hashicorp/vault/helper/namespace"
+)
+
+func TestCache_APIProxy(t *testing.T) {
+	cleanup, client, _ := setupClusterAndAgent(t, nil)
+	defer cleanup()
+
+	proxier := NewAPIProxy(&APIProxyConfig{
+		Logger: logging.NewVaultLogger(hclog.Trace),
+	})
+
+	r := client.NewRequest("GET", "/v1/sys/health")
+	req, err := r.ToRetryableHTTP()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	resp, err := proxier.Send(namespace.RootContext(nil), &SendRequest{
+		Request: req.Request,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var result api.HealthResponse
+	err = jsonutil.DecodeJSONFromReader(resp.Response.Body, &result)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !result.Initialized || result.Sealed || result.Standby {
+		t.Fatalf("bad sys/health response")
+	}
+}
diff --git a/command/agent/cache/cache_namespaces_test.go b/command/agent/cache/cache_namespaces_test.go
new file mode 100644
index 000000000000..01f8b8e6513a
--- /dev/null
+++ b/command/agent/cache/cache_namespaces_test.go
@@ -0,0 +1,314 @@
+package cache
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/go-test/deep"
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/api"
+	"github.com/hashicorp/vault/logical"
+	"github.com/hashicorp/vault/vault"
+)
+
+func TestCache_Namespaces(t *testing.T) {
+	t.Run("send", testSendNamespaces)
+
+	t.Run("full_path", func(t *testing.T) {
+		t.Run("handle_cacheclear", func(t *testing.T) {
+			testHandleCacheClearNamespaces(t, true)
+		})
+
+		t.Run("eviction_on_revocation", func(t *testing.T) {
testEvictionOnRevocationNamespaces(t, true) + }) + }) + + t.Run("namespace_header", func(t *testing.T) { + t.Run("handle_cacheclear", func(t *testing.T) { + testHandleCacheClearNamespaces(t, false) + }) + + t.Run("eviction_on_revocation", func(t *testing.T) { + testEvictionOnRevocationNamespaces(t, false) + }) + }) +} + +func testSendNamespaces(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + cleanup, clusterClient, testClient := setupClusterAndAgent(t, coreConfig) + defer cleanup() + + // Create a namespace + _, err := clusterClient.Logical().Write("sys/namespaces/ns1", nil) + if err != nil { + t.Fatal(err) + } + + // Mount the leased KV into ns1 + clusterClient.SetNamespace("ns1/") + err = clusterClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + clusterClient.SetNamespace("") + + // Try request using full path + { + // Write some random value + _, err = clusterClient.Logical().Write("/ns1/kv/foo", map[string]interface{}{ + "value": "test", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + proxiedResp, err := testClient.Logical().Read("/ns1/kv/foo") + if err != nil { + t.Fatal(err) + } + + cachedResp, err := testClient.Logical().Read("/ns1/kv/foo") + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(proxiedResp, cachedResp); diff != nil { + t.Fatal(diff) + } + } + + // Try request using the namespace header + { + // Write some random value + _, err = clusterClient.Logical().Write("/ns1/kv/bar", map[string]interface{}{ + "value": "test", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + testClient.SetNamespace("ns1/") + proxiedResp, err := testClient.Logical().Read("/kv/bar") + if err != nil { + t.Fatal(err) + } + + cachedResp, err := testClient.Logical().Read("/kv/bar") + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(proxiedResp, cachedResp); diff != nil { + t.Fatal(diff) + } + testClient.SetNamespace("") + } + + // Try the same request using different namespace input methods (header vs + // full path), they should not be the same cache entry (i.e. should produce + // different lease ID's). 
+ { + _, err := clusterClient.Logical().Write("/ns1/kv/baz", map[string]interface{}{ + "value": "test", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + proxiedResp, err := testClient.Logical().Read("/ns1/kv/baz") + if err != nil { + t.Fatal(err) + } + + testClient.SetNamespace("ns1/") + cachedResp, err := testClient.Logical().Read("/kv/baz") + if err != nil { + t.Fatal(err) + } + testClient.SetNamespace("") + + if diff := deep.Equal(proxiedResp, cachedResp); diff == nil { + t.Logf("response #1: %#v", proxiedResp) + t.Logf("response #2: %#v", cachedResp) + t.Fatal("expected requests to be not cached") + } + } +} + +func testHandleCacheClearNamespaces(t *testing.T, fullPath bool) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + cleanup, clusterClient, testClient := setupClusterAndAgent(t, coreConfig) + defer cleanup() + + // Create a namespace + _, err := clusterClient.Logical().Write("sys/namespaces/ns1", nil) + if err != nil { + t.Fatal(err) + } + + // Mount the leased KV into ns1 + clusterClient.SetNamespace("ns1/") + err = clusterClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + clusterClient.SetNamespace("") + + // Write some random value + _, err = clusterClient.Logical().Write("/ns1/kv/foo", map[string]interface{}{ + "value": "test", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + requestPath := "/kv/foo" + testClient.SetNamespace("ns1/") + if fullPath { + requestPath = "/ns1" + requestPath + testClient.SetNamespace("") + } + + // Request the secret + firstResp, err := testClient.Logical().Read(requestPath) + if err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + + // Clear by request_path and namespace + requestPathValue := "/v1" + requestPath + data := &cacheClearRequest{ + Type: "request_path", + Value: requestPathValue, + } + + r := testClient.NewRequest("PUT", "/v1/agent/cache-clear") + if err := r.SetJSONBody(data); err != nil { + t.Fatal(err) + } + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + _, err = clusterClient.RawRequestWithContext(ctx, r) + if err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + + secondResp, err := testClient.Logical().Read(requestPath) + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(firstResp, secondResp); diff == nil { + t.Logf("response #1: %#v", firstResp) + t.Logf("response #2: %#v", secondResp) + t.Fatal("expected requests to be not cached") + } +} + +func testEvictionOnRevocationNamespaces(t *testing.T, fullPath bool) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + cleanup, clusterClient, testClient := setupClusterAndAgent(t, coreConfig) + defer cleanup() + + // Create a namespace + _, err := clusterClient.Logical().Write("sys/namespaces/ns1", nil) + if err != nil { + t.Fatal(err) + } + + // Mount the leased KV into ns1 + clusterClient.SetNamespace("ns1/") + err = clusterClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + clusterClient.SetNamespace("") + + // Write some random value + _, err = clusterClient.Logical().Write("/ns1/kv/foo", map[string]interface{}{ + "value": "test", + "ttl": 
"1h", + }) + if err != nil { + t.Fatal(err) + } + + requestPath := "/kv/foo" + testClient.SetNamespace("ns1/") + if fullPath { + requestPath = "/ns1" + requestPath + testClient.SetNamespace("") + } + + // Request the secret + firstResp, err := testClient.Logical().Read(requestPath) + if err != nil { + t.Fatal(err) + } + leaseID := firstResp.LeaseID + + time.Sleep(200 * time.Millisecond) + + revocationPath := "/sys/leases/revoke" + if fullPath { + revocationPath = "/ns1/" + revocationPath + testClient.SetNamespace("") + } + + _, err = testClient.Logical().Write(revocationPath, map[string]interface{}{ + "lease_id": leaseID, + }) + + secondResp, err := testClient.Logical().Read(requestPath) + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(firstResp, secondResp); diff == nil { + t.Logf("response #1: %#v", firstResp) + t.Logf("response #2: %#v", secondResp) + t.Fatal("expected requests to be not cached") + } +} diff --git a/command/agent/cache/cache_test.go b/command/agent/cache/cache_test.go new file mode 100644 index 000000000000..35f099f7bea9 --- /dev/null +++ b/command/agent/cache/cache_test.go @@ -0,0 +1,321 @@ +package cache + +import ( + "context" + "fmt" + "net" + "os" + "testing" + + "github.com/hashicorp/vault/logical" + + "github.com/go-test/deep" + hclog "github.com/hashicorp/go-hclog" + kv "github.com/hashicorp/vault-plugin-secrets-kv" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/credential/userpass" + "github.com/hashicorp/vault/helper/logging" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/vault" +) + +const policyAdmin = ` +path "*" { + capabilities = ["sudo", "create", "read", "update", "delete", "list"] +} +` + +// testSetupClusterAndAgent is a helper func used to set up a test cluster and +// caching agent. It returns a cleanup func that should be deferred immediately +// along with two clients, one for direct cluster communication and another to +// talk to the caching agent. +func setupClusterAndAgent(t *testing.T, coreConfig *vault.CoreConfig) (func(), *api.Client, *api.Client) { + t.Helper() + + // Handle sane defaults + if coreConfig == nil { + coreConfig = &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + } + } + + if coreConfig.CredentialBackends == nil { + coreConfig.CredentialBackends = map[string]logical.Factory{ + "userpass": userpass.Factory, + } + } + + // Init new test cluster + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + + cores := cluster.Cores + vault.TestWaitActive(t, cores[0].Core) + + // clusterClient is the client that is used to talk directly to the cluster. + clusterClient := cores[0].Client + + // Add an admin policy + if err := clusterClient.Sys().PutPolicy("admin", policyAdmin); err != nil { + t.Fatal(err) + } + + // Set up the userpass auth backend and an admin user. Used for getting a token + // for the agent later down in this func. 
+ clusterClient.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + + _, err := clusterClient.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ + "password": "bar", + "policies": []string{"admin"}, + }) + if err != nil { + t.Fatal(err) + } + + // Set up env vars for agent consumption + origEnvVaultAddress := os.Getenv(api.EnvVaultAddress) + os.Setenv(api.EnvVaultAddress, clusterClient.Address()) + + origEnvVaultCACert := os.Getenv(api.EnvVaultCACert) + os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir)) + + cacheLogger := logging.NewVaultLogger(hclog.Trace) + ctx := context.Background() + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + // Start listening to requests + err = Run(ctx, &Config{ + Token: clusterClient.Token(), + UseAutoAuthToken: false, + Listeners: []net.Listener{listener}, + Logger: cacheLogger.Named("cache"), + }) + if err != nil { + t.Fatal(err) + } + + // testClient is the client that is used to talk to the agent for proxying/caching behavior. + testClient, err := clusterClient.Clone() + if err != nil { + t.Fatal(err) + } + + if err := testClient.SetAddress("http://" + listener.Addr().String()); err != nil { + t.Fatal(err) + } + + // Login via userpass method to derive a managed token. Set that token as the + // testClient's token + resp, err := testClient.Logical().Write("auth/userpass/login/foo", map[string]interface{}{ + "password": "bar", + }) + if err != nil { + t.Fatal(err) + } + testClient.SetToken(resp.Auth.ClientToken) + + cleanup := func() { + cluster.Cleanup() + os.Setenv(api.EnvVaultAddress, origEnvVaultAddress) + os.Setenv(api.EnvVaultCACert, origEnvVaultCACert) + listener.Close() + } + + return cleanup, clusterClient, testClient +} + +func TestCache_NonCacheable(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": kv.Factory, + }, + } + + cleanup, _, testClient := setupClusterAndAgent(t, coreConfig) + defer cleanup() + + // Query mounts first + origMounts, err := testClient.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + // Mount a kv backend + if err := testClient.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + Options: map[string]string{ + "version": "2", + }, + }); err != nil { + t.Fatal(err) + } + + // Query mounts again + newMounts, err := testClient.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(origMounts, newMounts); diff == nil { + t.Logf("response #1: %#v", origMounts) + t.Logf("response #2: %#v", newMounts) + t.Fatal("expected requests to be not cached") + } +} + +func TestCache_AuthResponse(t *testing.T) { + cleanup, _, testClient := setupClusterAndAgent(t, nil) + defer cleanup() + + resp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + token := resp.Auth.ClientToken + testClient.SetToken(token) + + // Test on auth response by creating a child token + { + proxiedResp, err := testClient.Logical().Write("auth/token/create", map[string]interface{}{ + "policies": "default", + }) + if err != nil { + t.Fatal(err) + } + if proxiedResp.Auth == nil || proxiedResp.Auth.ClientToken == "" { + t.Fatalf("expected a valid client token in the response, got = %#v", proxiedResp) + } + + cachedResp, err := testClient.Logical().Write("auth/token/create", map[string]interface{}{ + "policies": "default", + }) + if err 
!= nil { + t.Fatal(err) + } + if cachedResp.Auth == nil || cachedResp.Auth.ClientToken == "" { + t.Fatalf("expected a valid client token in the response, got = %#v", cachedResp) + } + + if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil { + t.Fatal(diff) + } + } + + // Test on *non-renewable* auth response by creating a child root token + { + proxiedResp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + if proxiedResp.Auth == nil || proxiedResp.Auth.ClientToken == "" { + t.Fatalf("expected a valid client token in the response, got = %#v", proxiedResp) + } + + cachedResp, err := testClient.Logical().Write("auth/token/create", nil) + if err != nil { + t.Fatal(err) + } + if cachedResp.Auth == nil || cachedResp.Auth.ClientToken == "" { + t.Fatalf("expected a valid client token in the response, got = %#v", cachedResp) + } + + if diff := deep.Equal(proxiedResp.Auth.ClientToken, cachedResp.Auth.ClientToken); diff != nil { + t.Fatal(diff) + } + } +} + +func TestCache_LeaseResponse(t *testing.T) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: hclog.NewNullLogger(), + LogicalBackends: map[string]logical.Factory{ + "kv": vault.LeasedPassthroughBackendFactory, + }, + } + + cleanup, client, testClient := setupClusterAndAgent(t, coreConfig) + defer cleanup() + + err := client.Sys().Mount("kv", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + // Test proxy by issuing two different requests + { + // Write data to the lease-kv backend + _, err := testClient.Logical().Write("kv/foo", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + _, err = testClient.Logical().Write("kv/foobar", map[string]interface{}{ + "value": "bar", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + firstResp, err := testClient.Logical().Read("kv/foo") + if err != nil { + t.Fatal(err) + } + + secondResp, err := testClient.Logical().Read("kv/foobar") + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(firstResp, secondResp); diff == nil { + t.Logf("response: %#v", firstResp) + t.Fatal("expected proxied responses, got cached response on second request") + } + } + + // Test caching behavior by issue the same request twice + { + _, err := testClient.Logical().Write("kv/baz", map[string]interface{}{ + "value": "foo", + "ttl": "1h", + }) + if err != nil { + t.Fatal(err) + } + + proxiedResp, err := testClient.Logical().Read("kv/baz") + if err != nil { + t.Fatal(err) + } + + cachedResp, err := testClient.Logical().Read("kv/baz") + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(proxiedResp, cachedResp); diff != nil { + t.Fatal(diff) + } + } +} diff --git a/command/agent/cache/cachememdb/cache_memdb.go b/command/agent/cache/cachememdb/cache_memdb.go new file mode 100644 index 000000000000..7fb5e887f1f1 --- /dev/null +++ b/command/agent/cache/cachememdb/cache_memdb.go @@ -0,0 +1,239 @@ +package cachememdb + +import ( + "errors" + "fmt" + + memdb "github.com/hashicorp/go-memdb" +) + +const ( + tableNameIndexer = "indexer" +) + +// CacheMemDB is the underlying cache database for storing indexes. +type CacheMemDB struct { + db *memdb.MemDB +} + +// New creates a new instance of CacheMemDB. 
+func New() (*CacheMemDB, error) { + db, err := newDB() + if err != nil { + return nil, err + } + + return &CacheMemDB{ + db: db, + }, nil +} + +func newDB() (*memdb.MemDB, error) { + cacheSchema := &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + tableNameIndexer: &memdb.TableSchema{ + Name: tableNameIndexer, + Indexes: map[string]*memdb.IndexSchema{ + IndexNameID.String(): &memdb.IndexSchema{ + Name: IndexNameID.String(), + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "ID", + }, + }, + IndexNameRequestPath.String(): &memdb.IndexSchema{ + Name: IndexNameRequestPath.String(), + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{ + Field: "Namespace", + }, + &memdb.StringFieldIndex{ + Field: "RequestPath", + }, + }, + }, + }, + IndexNameToken.String(): &memdb.IndexSchema{ + Name: IndexNameToken.String(), + Unique: false, + Indexer: &memdb.StringFieldIndex{ + Field: "Token", + }, + }, + IndexNameTokenParent.String(): &memdb.IndexSchema{ + Name: IndexNameTokenParent.String(), + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "TokenParent", + }, + }, + IndexNameTokenAccessor.String(): &memdb.IndexSchema{ + Name: IndexNameTokenAccessor.String(), + Unique: false, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "TokenAccessor", + }, + }, + IndexNameLease.String(): &memdb.IndexSchema{ + Name: IndexNameLease.String(), + Unique: true, + AllowMissing: true, + Indexer: &memdb.StringFieldIndex{ + Field: "Lease", + }, + }, + }, + }, + }, + } + + db, err := memdb.NewMemDB(cacheSchema) + if err != nil { + return nil, err + } + return db, nil +} + +// GetByPrefix returns all the cached indexes based on the index name and the +// value prefix. +func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) { + if indexNameFromString(indexName) == IndexNameInvalid { + return nil, fmt.Errorf("invalid index name %q", indexName) + } + + indexName = indexName + "_prefix" + + // Get all the objects + iter, err := c.db.Txn(false).Get(tableNameIndexer, indexName, indexValues...) + if err != nil { + return nil, err + } + + var indexes []*Index + for { + obj := iter.Next() + if obj == nil { + break + } + index, ok := obj.(*Index) + if !ok { + return nil, fmt.Errorf("failed to cast cached index") + } + + indexes = append(indexes, index) + } + + return indexes, nil +} + +// Get returns the index based on the indexer and the index values provided. +func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, error) { + if indexNameFromString(indexName) == IndexNameInvalid { + return nil, fmt.Errorf("invalid index name %q", indexName) + } + + raw, err := c.db.Txn(false).First(tableNameIndexer, indexName, indexValues...) + if err != nil { + return nil, err + } + + if raw == nil { + return nil, nil + } + + index, ok := raw.(*Index) + if !ok { + return nil, errors.New("unable to parse index value from the cache") + } + + return index, nil +} + +// Set stores the index into the cache. +func (c *CacheMemDB) Set(index *Index) error { + if index == nil { + return errors.New("nil index provided") + } + + txn := c.db.Txn(true) + defer txn.Abort() + + if err := txn.Insert(tableNameIndexer, index); err != nil { + return fmt.Errorf("unable to insert index into cache: %v", err) + } + + txn.Commit() + + return nil +} + +// Evict removes an index from the cache based on index name and value. 
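Before the eviction helpers that follow, a quick usage sketch of the store defined above may help reviewers; all values below are illustrative:

```go
package main

import (
	"fmt"
	"log"

	cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
)

func main() {
	db, err := cachememdb.New()
	if err != nil {
		log.Fatal(err)
	}

	// Illustrative index; the ID, token, and lease values are placeholders.
	// ID, Namespace, RequestPath, and Token back non-AllowMissing indexes in
	// the schema, so they must be non-empty for Set to succeed.
	err = db.Set(&cachememdb.Index{
		ID:          "abc123",
		Namespace:   "root/",
		RequestPath: "/v1/kv/foo",
		Token:       "some-token",
		Lease:       "kv/foo/lease-id",
	})
	if err != nil {
		log.Fatal(err)
	}

	// Point lookup on the unique ID index.
	index, err := db.Get(cachememdb.IndexNameID.String(), "abc123")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(index.RequestPath)

	// Prefix lookup on the compound (Namespace, RequestPath) index; this is
	// the shape cache-clearing by request_path uses.
	indexes, err := db.GetByPrefix(cachememdb.IndexNameRequestPath.String(), "root/", "/v1/kv")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(indexes))
}
```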
+func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error { + index, err := c.Get(indexName, indexValues...) + if err != nil { + return fmt.Errorf("unable to fetch index on cache deletion: %v", err) + } + + if index == nil { + return nil + } + + txn := c.db.Txn(true) + defer txn.Abort() + + if err := txn.Delete(tableNameIndexer, index); err != nil { + return fmt.Errorf("unable to delete index from cache: %v", err) + } + + txn.Commit() + + return nil +} + +// EvictAll removes all matching indexes from the cache based on index name and value. +func (c *CacheMemDB) EvictAll(indexName, indexValue string) error { + return c.batchEvict(false, indexName, indexValue) +} + +// EvictByPrefix removes all matching prefix indexes from the cache based on index name and prefix. +func (c *CacheMemDB) EvictByPrefix(indexName, indexPrefix string) error { + return c.batchEvict(true, indexName, indexPrefix) +} + +func (c *CacheMemDB) batchEvict(isPrefix bool, indexName string, indexValues ...interface{}) error { + if indexNameFromString(indexName) == IndexNameInvalid { + return fmt.Errorf("invalid index name %q", indexName) + } + + if isPrefix { + indexName = indexName + "_prefix" + } + + txn := c.db.Txn(true) + defer txn.Abort() + + _, err := txn.DeleteAll(tableNameIndexer, indexName, indexValues...) + if err != nil { + return err + } + + txn.Commit() + + return nil +} + +// Flush resets the underlying cache object. +func (c *CacheMemDB) Flush() error { + newDB, err := newDB() + if err != nil { + return err + } + + c.db = newDB + + return nil +} diff --git a/command/agent/cache/cachememdb/cache_memdb_test.go b/command/agent/cache/cachememdb/cache_memdb_test.go new file mode 100644 index 000000000000..b4cc08a25106 --- /dev/null +++ b/command/agent/cache/cachememdb/cache_memdb_test.go @@ -0,0 +1,286 @@ +package cachememdb + +import ( + "context" + "testing" + + "github.com/go-test/deep" +) + +func testContextInfo() *ContextInfo { + ctx, cancelFunc := context.WithCancel(context.Background()) + + return &ContextInfo{ + Ctx: ctx, + CancelFunc: cancelFunc, + } +} + +func TestNew(t *testing.T) { + _, err := New() + if err != nil { + t.Fatal(err) + } +} + +func TestCacheMemDB_Get(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + // Test invalid index name + _, err = cache.Get("foo", "bar") + if err == nil { + t.Fatal("expected error") + } + + // Test on empty cache + index, err := cache.Get(IndexNameID.String(), "foo") + if err != nil { + t.Fatal(err) + } + if index != nil { + t.Fatalf("expected nil index, got: %v", index) + } + + // Populate cache + in := &Index{ + ID: "test_id", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Token: "test_token", + TokenAccessor: "test_accessor", + Lease: "test_lease", + Response: []byte("hello world"), + } + + if err := cache.Set(in); err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + indexName string + indexValues []interface{} + }{ + { + "by_index_id", + "id", + []interface{}{in.ID}, + }, + { + "by_request_path", + "request_path", + []interface{}{in.Namespace, in.RequestPath}, + }, + { + "by_lease", + "lease", + []interface{}{in.Lease}, + }, + { + "by_token", + "token", + []interface{}{in.Token}, + }, + { + "by_token_accessor", + "token_accessor", + []interface{}{in.TokenAccessor}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + out, err := cache.Get(tc.indexName, tc.indexValues...) 
+ if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(in, out); diff != nil { + t.Fatal(diff) + } + }) + } +} + +func TestCacheMemDB_Set(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + index *Index + wantErr bool + }{ + { + "nil", + nil, + true, + }, + { + "empty_fields", + &Index{}, + true, + }, + { + "missing_required_fields", + &Index{ + Lease: "foo", + }, + true, + }, + { + "all_fields", + &Index{ + ID: "test_id", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Token: "test_token", + TokenAccessor: "test_accessor", + Lease: "test_lease", + RenewCtxInfo: testContextInfo(), + }, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := cache.Set(tc.index); (err != nil) != tc.wantErr { + t.Fatalf("CacheMemDB.Set() error = %v, wantErr = %v", err, tc.wantErr) + } + }) + } +} + +func TestCacheMemDB_Evict(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + // Test on empty cache + if err := cache.Evict(IndexNameID.String(), "foo"); err != nil { + t.Fatal(err) + } + + testIndex := &Index{ + ID: "test_id", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Token: "test_token", + TokenAccessor: "test_token_accessor", + Lease: "test_lease", + RenewCtxInfo: testContextInfo(), + } + + testCases := []struct { + name string + indexName string + indexValues []interface{} + insertIndex *Index + wantErr bool + }{ + { + "empty_params", + "", + []interface{}{""}, + nil, + true, + }, + { + "invalid_params", + "foo", + []interface{}{"bar"}, + nil, + true, + }, + { + "by_id", + "id", + []interface{}{"test_id"}, + testIndex, + false, + }, + { + "by_request_path", + "request_path", + []interface{}{"test_ns/", "/v1/request/path"}, + testIndex, + false, + }, + { + "by_token", + "token", + []interface{}{"test_token"}, + testIndex, + false, + }, + { + "by_token_accessor", + "token_accessor", + []interface{}{"test_accessor"}, + testIndex, + false, + }, + { + "by_lease", + "lease", + []interface{}{"test_lease"}, + testIndex, + false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.insertIndex != nil { + if err := cache.Set(tc.insertIndex); err != nil { + t.Fatal(err) + } + } + + if err := cache.Evict(tc.indexName, tc.indexValues...); (err != nil) != tc.wantErr { + t.Fatal(err) + } + }) + } +} + +func TestCacheMemDB_Flush(t *testing.T) { + cache, err := New() + if err != nil { + t.Fatal(err) + } + + // Populate cache + in := &Index{ + ID: "test_id", + Token: "test_token", + Lease: "test_lease", + Namespace: "test_ns/", + RequestPath: "/v1/request/path", + Response: []byte("hello world"), + } + + if err := cache.Set(in); err != nil { + t.Fatal(err) + } + + // Reset the cache + if err := cache.Flush(); err != nil { + t.Fatal(err) + } + + // Check the cache doesn't contain inserted index + out, err := cache.Get(IndexNameID.String(), "test_id") + if err != nil { + t.Fatal(err) + } + if out != nil { + t.Fatalf("expected cache to be empty, got = %v", out) + } +} diff --git a/command/agent/cache/cachememdb/index.go b/command/agent/cache/cachememdb/index.go new file mode 100644 index 000000000000..e5dd1c7fed9a --- /dev/null +++ b/command/agent/cache/cachememdb/index.go @@ -0,0 +1,101 @@ +package cachememdb + +import "context" + +type ContextInfo struct { + Ctx context.Context + CancelFunc context.CancelFunc + DoneCh chan struct{} +} + +// Index holds the response to be cached along with multiple other values 
that +// serve as pointers to refer back to this index. +type Index struct { + // ID is a value that uniquely represents the request held by this + // index. This is computed by serializing and hashing the response object. + // Required: true, Unique: true + ID string + + // Token is the token that fetched the response held by this index + // Required: true, Unique: false + Token string + + // TokenParent is the parent token of the token held by this index + // Required: false, Unique: false + TokenParent string + + // TokenAccessor is the accessor of the token being cached in this index + // Required: true, Unique: false + TokenAccessor string + + // Namespace is the namespace that was provided in the request path as the + // Vault namespace to query + Namespace string + + // RequestPath is the path of the request that resulted in the response + // held by this index. + // Required: true, Unique: false + RequestPath string + + // Lease is the identifier of the lease in Vault, that belongs to the + // response held by this index. + // Required: false, Unique: true + Lease string + + // Response is the serialized response object that the agent is caching. + Response []byte + + // RenewCtxInfo holds the context and the corresponding cancel func for the + // goroutine that manages the renewal of the secret belonging to the + // response in this index. + RenewCtxInfo *ContextInfo +} + +type IndexName uint32 + +const ( + IndexNameInvalid IndexName = iota + IndexNameID + IndexNameLease + IndexNameRequestPath + IndexNameToken + IndexNameTokenAccessor + IndexNameTokenParent +) + +func (indexName IndexName) String() string { + switch indexName { + case IndexNameID: + return "id" + case IndexNameLease: + return "lease" + case IndexNameRequestPath: + return "request_path" + case IndexNameToken: + return "token" + case IndexNameTokenAccessor: + return "token_accessor" + case IndexNameTokenParent: + return "token_parent" + } + return "" +} + +func indexNameFromString(indexName string) IndexName { + switch indexName { + case "id": + return IndexNameID + case "lease": + return IndexNameLease + case "request_path": + return IndexNameRequestPath + case "token": + return IndexNameToken + case "token_accessor": + return IndexNameTokenAccessor + case "token_parent": + return IndexNameTokenParent + default: + return IndexNameInvalid + } +} diff --git a/command/agent/cache/handler.go b/command/agent/cache/handler.go new file mode 100644 index 000000000000..e6c43614b779 --- /dev/null +++ b/command/agent/cache/handler.go @@ -0,0 +1,123 @@ +package cache + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "time" + + "github.com/hashicorp/errwrap" + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/consts" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" +) + +type Config struct { + Token string + Proxier Proxier + UseAutoAuthToken bool + Listeners []net.Listener + Logger hclog.Logger +} + +func Run(ctx context.Context, config *Config) error { + // Create the API proxier + apiProxy := NewAPIProxy(&APIProxyConfig{ + Logger: config.Logger.Named("apiproxy"), + }) + + // Create the lease cache proxier and set its underlying proxier to + // the API proxier. 
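One orientation note before the lease cache is constructed below: `Proxier`, `SendRequest`, and `SendResponse` are used throughout this package but defined in a part of the PR not shown in this excerpt. Reconstructed from the call sites, they presumably look roughly like this; the field comments are inferred, not authoritative:

```go
package cache

import (
	"context"
	"net/http"

	"github.com/hashicorp/vault/api"
)

// Presumed shapes only, reconstructed from how this diff uses them.
type Proxier interface {
	Send(ctx context.Context, req *SendRequest) (*SendResponse, error)
}

type SendRequest struct {
	Token       string        // client token to attach to the proxied call
	Request     *http.Request // the original request received on the listener
	RequestBody []byte        // body, pre-read so it can be hashed and replayed
}

type SendResponse struct {
	Response     *api.Response // response from Vault, or rebuilt from the cache
	ResponseBody []byte        // body, pre-read for serialization into the cache
}
```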
+	leaseCache, err := NewLeaseCache(&LeaseCacheConfig{
+		BaseContext: ctx,
+		Proxier:     apiProxy,
+		Logger:      config.Logger.Named("leasecache"),
+	})
+	if err != nil {
+		return fmt.Errorf("failed to create lease cache: %v", err)
+	}
+
+	config.Proxier = leaseCache
+
+	// Create a muxer and add paths relevant for the lease cache layer
+	mux := http.NewServeMux()
+	mux.Handle("/v1/agent/cache-clear", leaseCache.HandleCacheClear(ctx))
+
+	mux.Handle("/", handler(ctx, config))
+	for _, ln := range config.Listeners {
+		server := &http.Server{
+			Handler:           mux,
+			ReadHeaderTimeout: 10 * time.Second,
+			ReadTimeout:       30 * time.Second,
+			IdleTimeout:       5 * time.Minute,
+			ErrorLog:          config.Logger.StandardLogger(nil),
+		}
+		go server.Serve(ln)
+	}
+
+	return nil
+}
+
+func handler(ctx context.Context, config *Config) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		config.Logger.Info("received request", "path", r.URL.Path)
+
+		token := r.Header.Get(consts.AuthHeaderName)
+		if token == "" && config.UseAutoAuthToken {
+			token = config.Token
+		}
+
+		// Parse and reset body.
+		reqBody, err := ioutil.ReadAll(r.Body)
+		if err != nil {
+			config.Logger.Error("failed to read request body")
+			respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
+			return
+		}
+		r.Body = ioutil.NopCloser(bytes.NewBuffer(reqBody))
+
+		resp, err := config.Proxier.Send(ctx, &SendRequest{
+			Token:       token,
+			Request:     r,
+			RequestBody: reqBody,
+		})
+		if err != nil {
+			respondError(w, http.StatusInternalServerError, errwrap.Wrapf("failed to get the response: {{err}}", err))
+			return
+		}
+
+		copyHeader(w.Header(), resp.Response.Header)
+		w.WriteHeader(resp.Response.StatusCode)
+		io.Copy(w, resp.Response.Body)
+		return
+	})
+}
+
+func copyHeader(dst, src http.Header) {
+	for k, vv := range src {
+		for _, v := range vv {
+			dst.Add(k, v)
+		}
+	}
+}
+
+func respondError(w http.ResponseWriter, status int, err error) {
+	logical.AdjustErrorStatusCode(&status, err)
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(status)
+
+	resp := &vaulthttp.ErrorResponse{Errors: make([]string, 0, 1)}
+	if err != nil {
+		resp.Errors = append(resp.Errors, err.Error())
+	}
+
+	enc := json.NewEncoder(w)
+	enc.Encode(resp)
+}
diff --git a/command/agent/cache/lease_cache.go b/command/agent/cache/lease_cache.go
new file mode 100644
index 000000000000..4ca4d57a128d
--- /dev/null
+++ b/command/agent/cache/lease_cache.go
@@ -0,0 +1,743 @@
+package cache
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"math/rand"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/hashicorp/errwrap"
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/api"
+	cachememdb "github.com/hashicorp/vault/command/agent/cache/cachememdb"
+	"github.com/hashicorp/vault/helper/consts"
+	"github.com/hashicorp/vault/helper/jsonutil"
+	nshelper "github.com/hashicorp/vault/helper/namespace"
+)
+
+const (
+	vaultPathTokenCreate         = "/v1/auth/token/create"
+	vaultPathTokenRevoke         = "/v1/auth/token/revoke"
+	vaultPathTokenRevokeSelf     = "/v1/auth/token/revoke-self"
+	vaultPathTokenRevokeAccessor = "/v1/auth/token/revoke-accessor"
+	vaultPathTokenRevokeOrphan   = "/v1/auth/token/revoke-orphan"
+	vaultPathLeaseRevoke         = "/v1/sys/leases/revoke"
+	vaultPathLeaseRevokeForce    = "/v1/sys/leases/revoke-force"
+	vaultPathLeaseRevokePrefix   = "/v1/sys/leases/revoke-prefix"
+)
+
+var (
+	contextIndexID = contextIndex{}
errInvalidType = errors.New("invalid type provided") + revocationPaths = []string{ + strings.TrimPrefix(vaultPathTokenRevoke, "/v1"), + strings.TrimPrefix(vaultPathTokenRevokeSelf, "/v1"), + strings.TrimPrefix(vaultPathTokenRevokeAccessor, "/v1"), + strings.TrimPrefix(vaultPathTokenRevokeOrphan, "/v1"), + strings.TrimPrefix(vaultPathLeaseRevoke, "/v1"), + strings.TrimPrefix(vaultPathLeaseRevokeForce, "/v1"), + strings.TrimPrefix(vaultPathLeaseRevokePrefix, "/v1"), + } +) + +type contextIndex struct{} + +type cacheClearRequest struct { + Type string `json:"type"` + Value string `json:"value"` + Namespace string `json:"namespace"` +} + +// LeaseCache is an implementation of Proxier that handles +// the caching of responses. It passes the incoming request +// to an underlying Proxier implementation. +type LeaseCache struct { + proxier Proxier + logger hclog.Logger + db *cachememdb.CacheMemDB + rand *rand.Rand + tokenContexts map[string]*ContextInfo + baseCtxInfo *ContextInfo +} + +// LeaseCacheConfig is the configuration for initializing a new +// Lease. +type LeaseCacheConfig struct { + BaseContext context.Context + Proxier Proxier + Logger hclog.Logger +} + +// ContextInfo holds a derived context and cancelFunc pair. +type ContextInfo struct { + Ctx context.Context + CancelFunc context.CancelFunc + DoneCh chan struct{} +} + +// NewLeaseCache creates a new instance of a LeaseCache. +func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) { + if conf == nil { + return nil, errors.New("nil configuration provided") + } + + if conf.Proxier == nil || conf.Logger == nil { + return nil, fmt.Errorf("missing configuration required params: %v", conf) + } + + db, err := cachememdb.New() + if err != nil { + return nil, err + } + + // Create a base context for the lease cache layer + baseCtx, baseCancelFunc := context.WithCancel(conf.BaseContext) + baseCtxInfo := &ContextInfo{ + Ctx: baseCtx, + CancelFunc: baseCancelFunc, + } + + return &LeaseCache{ + proxier: conf.Proxier, + logger: conf.Logger, + db: db, + rand: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), + tokenContexts: make(map[string]*ContextInfo), + baseCtxInfo: baseCtxInfo, + }, nil +} + +// Send performs a cache lookup on the incoming request. If it's a cache hit, +// it will return the cached response, otherwise it will delegate to the +// underlying Proxier and cache the received response. 
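Before reading `Send` below, it helps to see the serialization trick it relies on: a cache entry stores the entire HTTP response in wire format via `(*http.Response).Write` and rebuilds it on a hit with `http.ReadResponse`. A self-contained stdlib illustration of that round trip:

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strings"
)

func main() {
	// Build a response as if it came back from Vault.
	orig := &http.Response{
		StatusCode:    200,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        http.Header{"Content-Type": []string{"application/json"}},
		Body:          ioutil.NopCloser(strings.NewReader(`{"ok":true}`)),
		ContentLength: 11,
	}

	// Serialize to wire format; this is what gets stored in Index.Response.
	var buf bytes.Buffer
	if err := orig.Write(&buf); err != nil {
		log.Fatal(err)
	}

	// Deserialize on a cache hit.
	resp, err := http.ReadResponse(bufio.NewReader(&buf), nil)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	body, _ := ioutil.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(body)) // 200 {"ok":true}
}
```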
+func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) { + // Compute the index ID + id, err := computeIndexID(req) + if err != nil { + c.logger.Error("failed to compute cache key", "error", err) + return nil, err + } + + // Check if the response for this request is already in the cache + index, err := c.db.Get(cachememdb.IndexNameID.String(), id) + if err != nil { + return nil, err + } + + // Cached request is found, deserialize the response and return early + if index != nil { + c.logger.Debug("returning cached response", "path", req.Request.URL.Path) + + reader := bufio.NewReader(bytes.NewReader(index.Response)) + resp, err := http.ReadResponse(reader, nil) + if err != nil { + c.logger.Error("failed to deserialize response", "error", err) + return nil, err + } + + return &SendResponse{ + Response: &api.Response{ + Response: resp, + }, + }, nil + } + + c.logger.Debug("forwarding the request and caching the response", "path", req.Request.URL.Path) + + // Pass the request down and get a response + resp, err := c.proxier.Send(ctx, req) + if err != nil { + return nil, err + } + + // Get the namespace from the request header + namespace := req.Request.Header.Get(consts.NamespaceHeaderName) + // We need to populate an empty value since go-memdb will skip over indexes + // that contain empty values. + if namespace == "" { + namespace = "root/" + } + + // Build the index to cache based on the response received + index = &cachememdb.Index{ + ID: id, + Namespace: namespace, + RequestPath: req.Request.URL.Path, + } + + secret, err := api.ParseSecret(bytes.NewBuffer(resp.ResponseBody)) + if err != nil { + c.logger.Error("failed to parse response as secret", "error", err) + return nil, err + } + + isRevocation, err := c.handleRevocationRequest(ctx, req, resp) + if err != nil { + c.logger.Error("failed to process the response", "error", err) + return nil, err + } + + // If this is a revocation request, do not go through cache logic. + if isRevocation { + return resp, nil + } + + var renewCtxInfo *ContextInfo + switch { + case secret == nil: + // Fast path for non-cacheable responses + return resp, nil + case secret.LeaseID != "": + renewCtxInfo = c.tokenContexts[req.Token] + // If the lease belongs to a token that is not managed by the agent, + // return the response without caching it. + if renewCtxInfo == nil { + return resp, nil + } + + // Derive a context for renewal using the token's context + newCtxInfo := new(ContextInfo) + newCtxInfo.Ctx, newCtxInfo.CancelFunc = context.WithCancel(renewCtxInfo.Ctx) + newCtxInfo.DoneCh = make(chan struct{}) + renewCtxInfo = newCtxInfo + + index.Lease = secret.LeaseID + index.Token = req.Token + + case secret.Auth != nil: + isNonOrphanNewToken := strings.HasPrefix(req.Request.URL.Path, vaultPathTokenCreate) && resp.Response.StatusCode == http.StatusOK && !secret.Auth.Orphan + + // If the new token is a result of token creation endpoints (not from + // login endpoints), and if its a non-orphan, then the new token's + // context should be derived from the context of the parent token. + var parentCtx context.Context + if isNonOrphanNewToken { + parentCtxInfo := c.tokenContexts[req.Token] + // If parent token is not managed by the agent, child shouldn't be + // either. 
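An aside before the rest of this switch: the token and lease handling leans entirely on context derivation. Each token context is derived from its parent token's (or the base) context, and each lease context from its token's, so a single `CancelFunc` tears down a whole subtree and fires the eviction goroutines. A stdlib-only illustration of that property:

```go
package main

import (
	"context"
	"fmt"
)

func main() {
	base, cancelBase := context.WithCancel(context.Background())
	defer cancelBase()

	// Token context derived from base; lease context derived from the token.
	tokenCtx, cancelToken := context.WithCancel(base)
	leaseCtx, cancelLease := context.WithCancel(tokenCtx)
	defer cancelLease()

	// "Revoking" the token cancels every context derived from it, which is
	// what lets the cache evict a token's leases and child tokens in one shot.
	cancelToken()

	<-leaseCtx.Done()
	fmt.Println("lease context cancelled:", leaseCtx.Err()) // context.Canceled
}
```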
+ if parentCtxInfo == nil { + return resp, nil + } + parentCtx = parentCtxInfo.Ctx + index.TokenParent = req.Token + } + + renewCtxInfo = c.createCtxInfo(parentCtx, secret.Auth.ClientToken) + index.Token = secret.Auth.ClientToken + index.TokenAccessor = secret.Auth.Accessor + + default: + // We shouldn't be hitting this, but will err on the side of caution and + // simply proxy. + return resp, nil + } + + // Serialize the response to store it in the cached index + var respBytes bytes.Buffer + err = resp.Response.Write(&respBytes) + if err != nil { + c.logger.Error("failed to serialize response", "error", err) + return nil, err + } + + // Reset the response body for upper layers to read + resp.Response.Body = ioutil.NopCloser(bytes.NewBuffer(resp.ResponseBody)) + + // Set the index's Response + index.Response = respBytes.Bytes() + + // Store the index ID in the renewer context + renewCtx := context.WithValue(renewCtxInfo.Ctx, contextIndexID, index.ID) + + // Store the renewer context in the index + index.RenewCtxInfo = &cachememdb.ContextInfo{ + Ctx: renewCtx, + CancelFunc: renewCtxInfo.CancelFunc, + DoneCh: renewCtxInfo.DoneCh, + } + + // Short-circuit if the secret is not renewable + tokenRenewable, err := secret.TokenIsRenewable() + if err != nil { + c.logger.Error("failed to parse renewable param", "error", err) + return nil, err + } + if !secret.Renewable && !tokenRenewable { + c.logger.Debug("secret not renewable, skipping addtion to the renewer") + return resp, nil + } + + c.logger.Debug("storing response into the cache and starting the secret renewal") + + // Store the index in the cache + err = c.db.Set(index) + if err != nil { + c.logger.Error("failed to cache the proxied response", "error", err) + return nil, err + } + + // Start renewing the secret in the response + go c.startRenewing(renewCtx, index, req, secret) + + return resp, nil +} + +func (c *LeaseCache) createCtxInfo(ctx context.Context, token string) *ContextInfo { + if ctx == nil { + ctx = c.baseCtxInfo.Ctx + } + ctxInfo := new(ContextInfo) + ctxInfo.Ctx, ctxInfo.CancelFunc = context.WithCancel(ctx) + ctxInfo.DoneCh = make(chan struct{}) + c.tokenContexts[token] = ctxInfo + return ctxInfo +} + +func (c *LeaseCache) startRenewing(ctx context.Context, index *cachememdb.Index, req *SendRequest, secret *api.Secret) { + defer func() { + id := ctx.Value(contextIndexID).(string) + c.logger.Debug("evicting index from cache", "id", id) + err := c.db.Evict(cachememdb.IndexNameID.String(), id) + if err != nil { + c.logger.Error("failed to evict index", "id", id, "error", err) + return + } + }() + + client, err := api.NewClient(api.DefaultConfig()) + if err != nil { + c.logger.Error("failed to create API client in the renewer", "error", err) + return + } + client.SetToken(req.Token) + client.SetHeaders(req.Request.Header) + + renewer, err := client.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + c.logger.Error("failed to create secret renewer", "error", err) + return + } + + c.logger.Debug("initiating renewal", "path", req.Request.URL.Path) + go renewer.Renew() + defer renewer.Stop() + + for { + select { + case <-ctx.Done(): + c.logger.Debug("shutdown triggered, stopping renewer", "path", req.Request.URL.Path) + return + case err := <-renewer.DoneCh(): + if err != nil { + c.logger.Error("failed to renew secret", "error", err) + return + } + c.logger.Debug("renewal halted; evicting from cache", "path", req.Request.URL.Path) + return + case renewal := <-renewer.RenewCh(): + c.logger.Debug("renewal received; 
updating cache", "path", req.Request.URL.Path) + err = c.updateResponse(ctx, renewal) + if err != nil { + c.logger.Error("failed to handle renewal", "error", err) + return + } + case <-index.RenewCtxInfo.DoneCh: + c.logger.Debug("done channel closed") + return + } + } +} + +func (c *LeaseCache) updateResponse(ctx context.Context, renewal *api.RenewOutput) error { + id := ctx.Value(contextIndexID).(string) + + // Get the cached index using the id in the context + index, err := c.db.Get(cachememdb.IndexNameID.String(), id) + if err != nil { + return err + } + if index == nil { + return fmt.Errorf("missing cache entry for id: %q", id) + } + + // Read the response from the index + resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(index.Response)), nil) + if err != nil { + c.logger.Error("failed to deserialize response", "error", err) + return err + } + + // Update the body in the reponse by the renewed secret + bodyBytes, err := json.Marshal(renewal.Secret) + if err != nil { + return err + } + resp.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + resp.ContentLength = int64(len(bodyBytes)) + + // Serialize the response + var respBytes bytes.Buffer + err = resp.Write(&respBytes) + if err != nil { + c.logger.Error("failed to serialize updated response", "error", err) + return err + } + + // Update the response in the index and set it in the cache + index.Response = respBytes.Bytes() + err = c.db.Set(index) + if err != nil { + c.logger.Error("failed to cache the proxied response", "error", err) + return err + } + + return nil +} + +// computeIndexID results in a value that uniquely identifies a request +// received by the agent. It does so by SHA256 hashing the serialized request +// object containing the request path, query parameters and body parameters. +func computeIndexID(req *SendRequest) (string, error) { + var b bytes.Buffer + + // Serialze the request + if err := req.Request.Write(&b); err != nil { + return "", fmt.Errorf("failed to serialize request: %v", err) + } + + // Reset the request body after it has been closed by Write + req.Request.Body = ioutil.NopCloser(bytes.NewBuffer(req.RequestBody)) + + // Append req.Token into the byte slice. This is needed since auto-auth'ed + // requests sets the token directly into SendRequest.Token + b.Write([]byte(req.Token)) + + sum := sha256.Sum256(b.Bytes()) + return hex.EncodeToString(sum[:]), nil +} + +// HandleCacheClear returns a handlerFunc that can perform cache clearing operations. +func (c *LeaseCache) HandleCacheClear(ctx context.Context) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req := new(cacheClearRequest) + if err := jsonutil.DecodeJSONFromReader(r.Body, req); err != nil { + if err == io.EOF { + err = errors.New("empty JSON provided") + } + respondError(w, http.StatusBadRequest, errwrap.Wrapf("failed to parse JSON input: {{err}}", err)) + return + } + + c.logger.Debug("received cache-clear request", "type", req.Type, "namespace", req.Namespace, "value", req.Value) + + if err := c.handleCacheClear(ctx, req.Type, req.Namespace, req.Value); err != nil { + // Default to 500 on error, unless the user provided an invalid type, + // which would then be a 400. 
+			httpStatus := http.StatusInternalServerError
+			if err == errInvalidType {
+				httpStatus = http.StatusBadRequest
+			}
+			respondError(w, httpStatus, errwrap.Wrapf("failed to clear cache: {{err}}", err))
+			return
+		}
+
+		return
+	})
+}
+
+func (c *LeaseCache) handleCacheClear(ctx context.Context, clearType string, clearValues ...interface{}) error {
+	if len(clearValues) == 0 {
+		return errors.New("no value(s) provided to clear corresponding cache entries")
+	}
+
+	// The value that we want to clear, for most cases, is the last one provided.
+	clearValue, ok := clearValues[len(clearValues)-1].(string)
+	if !ok {
+		return fmt.Errorf("unable to convert %v to type string", clearValues[len(clearValues)-1])
+	}
+
+	switch clearType {
+	case "request_path":
+		// For this particular case, we need to ensure that there are 2 provided
+		// indexers for the proper lookup.
+		if len(clearValues) != 2 {
+			return fmt.Errorf("clearing cache by request path requires 2 indexers, got %d", len(clearValues))
+		}
+
+		// The first value provided for this case will be the namespace, but if it's
+		// an empty value we need to overwrite it with "root/" to ensure proper
+		// cache lookup.
+		if clearValues[0].(string) == "" {
+			clearValues[0] = "root/"
+		}
+
+		// Find all the cached entries that have the given request path and
+		// cancel the contexts of their respective renewers
+		indexes, err := c.db.GetByPrefix(clearType, clearValues...)
+		if err != nil {
+			return err
+		}
+		for _, index := range indexes {
+			index.RenewCtxInfo.CancelFunc()
+		}
+
+	case "token":
+		if clearValue == "" {
+			return nil
+		}
+		// Get the context for the given token and cancel its context
+		tokenCtxInfo := c.tokenContexts[clearValue]
+		if tokenCtxInfo == nil {
+			return nil
+		}
+
+		tokenCtxInfo.CancelFunc()
+
+		// Remove the cancelled context from the map
+		delete(c.tokenContexts, clearValue)
+
+	case "token_accessor", "lease":
+		// Get the cached index and cancel the corresponding renewer context
+		index, err := c.db.Get(clearType, clearValue)
+		if err != nil {
+			return err
+		}
+		if index == nil {
+			return nil
+		}
+		index.RenewCtxInfo.CancelFunc()
+
+	case "all":
+		// Cancel the base context which triggers all the goroutines to
+		// stop and evict entries from cache.
+		c.baseCtxInfo.CancelFunc()
+
+		// Reset the base context
+		baseCtx, baseCancel := context.WithCancel(ctx)
+		c.baseCtxInfo = &ContextInfo{
+			Ctx:        baseCtx,
+			CancelFunc: baseCancel,
+		}
+
+		// Reset the memdb instance
+		if err := c.db.Flush(); err != nil {
+			return err
+		}
+
+	default:
+		return errInvalidType
+	}
+
+	c.logger.Debug("successfully cleared matching cache entries")
+
+	return nil
+}
+
+// handleRevocationRequest checks whether the originating request is a
+// revocation request, and if so, performs the applicable cache cleanups.
+// Returns true if this is a revocation request.
+func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendRequest, resp *SendResponse) (bool, error) {
+	// Lease and token revocations return 204's on success. Fast-path if that's
+	// not the case.
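+	// For example, a successful `PUT /v1/auth/token/revoke` comes back as
+	// 204 No Content; any other status is returned to the caller as-is and
+	// no cache eviction takes place.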
+	if resp.Response.StatusCode != http.StatusNoContent {
+		return false, nil
+	}
+
+	namespace, path := deriveNamespaceAndRevocationPath(req)
+
+	switch {
+	case path == vaultPathTokenRevoke:
+		// Get the token from the request body
+		jsonBody := map[string]interface{}{}
+		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
+			return false, err
+		}
+		token, ok := jsonBody["token"]
+		if !ok {
+			return false, fmt.Errorf("failed to get token from request body")
+		}
+
+		// Clear the cache entry associated with the token and all the other
+		// entries belonging to the leases derived from this token.
+		if err := c.handleCacheClear(ctx, "token", token.(string)); err != nil {
+			return false, err
+		}
+
+	case path == vaultPathTokenRevokeSelf:
+		// Clear the cache entry associated with the token and all the other
+		// entries belonging to the leases derived from this token.
+		if err := c.handleCacheClear(ctx, "token", req.Token); err != nil {
+			return false, err
+		}
+
+	case path == vaultPathTokenRevokeAccessor:
+		jsonBody := map[string]interface{}{}
+		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
+			return false, err
+		}
+		accessor, ok := jsonBody["accessor"]
+		if !ok {
+			return false, fmt.Errorf("failed to get accessor from request body")
+		}
+
+		if err := c.handleCacheClear(ctx, "token_accessor", accessor.(string)); err != nil {
+			return false, err
+		}
+
+	case path == vaultPathTokenRevokeOrphan:
+		jsonBody := map[string]interface{}{}
+		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
+			return false, err
+		}
+		token, ok := jsonBody["token"]
+		if !ok {
+			return false, fmt.Errorf("failed to get token from request body")
+		}
+
+		// Find all the indexes that are directly tied to the revoked token
+		indexes, err := c.db.GetByPrefix(cachememdb.IndexNameToken.String(), token.(string))
+		if err != nil {
+			return false, err
+		}
+
+		// Out of these indexes, one will be for the token itself and the rest
+		// will be for leases of this token. Cancel the contexts of all the
+		// leases and return from the renewer goroutine for the token's index
+		// without cancelling the context. Cancelling the context of the
+		// token's renewer would evict all the child tokens, which is not
+		// desired.
+		for _, index := range indexes {
+			if index.Lease != "" {
+				index.RenewCtxInfo.CancelFunc()
+			} else {
+				close(index.RenewCtxInfo.DoneCh)
+			}
+		}
+
+		// Clear the parent references of the revoked token
+		indexes, err = c.db.GetByPrefix(cachememdb.IndexNameTokenParent.String(), token.(string))
+		if err != nil {
+			return false, err
+		}
+		for _, index := range indexes {
+			index.TokenParent = ""
+			err = c.db.Set(index)
+			if err != nil {
+				c.logger.Error("failed to persist index", "error", err)
+				return false, err
+			}
+		}
+
+	case path == vaultPathLeaseRevoke:
+		// TODO: Should a lease present in the URL itself be considered here?
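+		// (Vault also accepts the lease ID as a URL segment, e.g.
+		// sys/leases/revoke/:url_lease_id; only the body parameter is
+		// handled here.)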
+		// Get the lease from the request body
+		jsonBody := map[string]interface{}{}
+		if err := json.Unmarshal(req.RequestBody, &jsonBody); err != nil {
+			return false, err
+		}
+		leaseID, ok := jsonBody["lease_id"]
+		if !ok {
+			return false, fmt.Errorf("failed to get lease_id from request body")
+		}
+		if err := c.handleCacheClear(ctx, "lease", leaseID.(string)); err != nil {
+			return false, err
+		}
+
+	case strings.HasPrefix(path, vaultPathLeaseRevokeForce):
+		// Trim the URL path to get the request path prefix
+		prefix := strings.TrimPrefix(path, vaultPathLeaseRevokeForce)
+		// Get all the cache indexes that use the request path containing the
+		// prefix and cancel the renewer context of each.
+		indexes, err := c.db.GetByPrefix("request_path", namespace, "/v1"+prefix)
+		if err != nil {
+			return false, err
+		}
+		for _, index := range indexes {
+			index.RenewCtxInfo.CancelFunc()
+		}
+
+	case strings.HasPrefix(path, vaultPathLeaseRevokePrefix):
+		// Trim the URL path to get the request path prefix
+		prefix := strings.TrimPrefix(path, vaultPathLeaseRevokePrefix)
+		// Get all the cache indexes that use the request path containing the
+		// prefix and cancel the renewer context of each.
+		indexes, err := c.db.GetByPrefix("request_path", namespace, "/v1"+prefix)
+		if err != nil {
+			return false, err
+		}
+		for _, index := range indexes {
+			index.RenewCtxInfo.CancelFunc()
+		}
+
+	default:
+		return false, nil
+	}
+
+	c.logger.Debug("triggered cache eviction from revocation request")
+
+	return true, nil
+}
+
+// deriveNamespaceAndRevocationPath returns the namespace and relative path for
+// revocation paths.
+//
+// If the path contains a namespace, but it's not a revocation path, it will be
+// returned as-is, since there's no way to tell where the namespace ends and
+// where the request path begins purely based off a string.
+//
+// Case 1: /v1/ns1/leases/revoke -> ns1/, /v1/leases/revoke
+// Case 2: ns1/ /v1/leases/revoke -> ns1/, /v1/leases/revoke
+// Case 3: /v1/ns1/foo/bar -> root/, /v1/ns1/foo/bar
+// Case 4: ns1/ /v1/foo/bar -> ns1/, /v1/foo/bar
+func deriveNamespaceAndRevocationPath(req *SendRequest) (string, string) {
+	namespace := "root/"
+	nsHeader := req.Request.Header.Get(consts.NamespaceHeaderName)
+	if nsHeader != "" {
+		namespace = nsHeader
+	}
+
+	fullPath := req.Request.URL.Path
+	nonVersionedPath := strings.TrimPrefix(fullPath, "/v1")
+
+	for _, pathToCheck := range revocationPaths {
+		// We use strings.Index here for paths that can contain
+		// vars in the path, e.g. /v1/lease/revoke-prefix/:prefix
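+		// e.g. for nonVersionedPath "/ns1/sys/leases/revoke" and a check path
+		// of "/sys/leases/revoke", the match starts at i = 4, so "/ns1"
+		// becomes the namespace "ns1/" and "/v1/sys/leases/revoke" is
+		// returned (assuming revocationPaths holds the non-versioned paths).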
+		i := strings.Index(nonVersionedPath, pathToCheck)
+		// If there's no match, move on to the next check
+		if i == -1 {
+			continue
+		}
+
+		// If the index is 0, this is a relative path with no namespace prepended,
+		// so we can break early
+		if i == 0 {
+			break
+		}
+
+		// We need to turn /ns1 into ns1/, which makes the namespace handling
+		// below easier
+		namespaceInPath := nshelper.Canonicalize(nonVersionedPath[:i])
+
+		// If it's root, we replace, otherwise we join
+		if namespace == "root/" {
+			namespace = namespaceInPath
+		} else {
+			namespace = namespace + namespaceInPath
+		}
+
+		return namespace, fmt.Sprintf("/v1%s", nonVersionedPath[i:])
+	}
+
+	return namespace, fmt.Sprintf("/v1%s", nonVersionedPath)
+}
diff --git a/command/agent/cache/lease_cache_test.go b/command/agent/cache/lease_cache_test.go
new file mode 100644
index 000000000000..b8d08feab539
--- /dev/null
+++ b/command/agent/cache/lease_cache_test.go
@@ -0,0 +1,507 @@
+package cache
+
+import (
+	"context"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strings"
+	"testing"
+
+	"github.com/go-test/deep"
+	hclog "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/api"
+	"github.com/hashicorp/vault/helper/consts"
+	"github.com/hashicorp/vault/helper/logging"
+)
+
+func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache {
+	t.Helper()
+
+	lc, err := NewLeaseCache(&LeaseCacheConfig{
+		BaseContext: context.Background(),
+		Proxier:     newMockProxier(responses),
+		Logger:      logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"),
+	})
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	return lc
+}
+
+func TestCache_ComputeIndexID(t *testing.T) {
+	tests := []struct {
+		name    string
+		req     *SendRequest
+		want    string
+		wantErr bool
+	}{
+		{
+			"basic",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "test",
+					},
+				},
+			},
+			"2edc7e965c3e1bdce3b1d5f79a52927842569c0734a86544d222753f11ae4847",
+			false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := computeIndexID(tt.req)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("actual_error: %v, expected_error: %v", err, tt.wantErr)
+				return
+			}
+			if got != tt.want {
+				t.Errorf("bad: index id; actual: %q, expected: %q", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestCache_LeaseCache_EmptyToken(t *testing.T) {
+	responses := []*SendResponse{
+		&SendResponse{
+			Response: &api.Response{
+				Response: &http.Response{
+					StatusCode: http.StatusCreated,
+					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "test"}}`)),
+				},
+			},
+			ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "test"}}`),
+		},
+	}
+	lc := testNewLeaseCache(t, responses)
+
+	// Even if the send request doesn't have a token on it, a successful
+	// cacheable response should result in the index properly getting populated
+	// with a token and memdb shouldn't complain while inserting the index.
+	urlPath := "http://example.com/v1/sample/api"
+	sendReq := &SendRequest{
+		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
+	}
+	resp, err := lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if resp == nil {
+		t.Fatalf("expected a non-empty response")
+	}
+}
+
+func TestCache_LeaseCache_SendCacheable(t *testing.T) {
+	// Emulate 2 responses from the api proxy. One returns a new token and the
+	// other returns a lease.
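+	// Expected flow: the first unique request misses the cache and consumes
+	// responses[0] from the mock proxier; an identical follow-up request is
+	// served from the cache without consuming responses[1].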
+	responses := []*SendResponse{
+		&SendResponse{
+			Response: &api.Response{
+				Response: &http.Response{
+					StatusCode: http.StatusCreated,
+					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "test", "renewable": true}}`)),
+				},
+			},
+			ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "test", "renewable": true}}`),
+		},
+		&SendResponse{
+			Response: &api.Response{
+				Response: &http.Response{
+					StatusCode: http.StatusOK,
+					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo", "renewable": true}`)),
+				},
+			},
+			ResponseBody: []byte(`{"value": "output", "lease_id": "foo", "renewable": true}`),
+		},
+	}
+	lc := testNewLeaseCache(t, responses)
+
+	// Make a request. A response with a new token is returned to the lease
+	// cache and that will be cached.
+	urlPath := "http://example.com/v1/sample/api"
+	sendReq := &SendRequest{
+		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
+	}
+	resp, err := lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
+		t.Fatalf("expected getting proxied response: got %v", diff)
+	}
+
+	// Send the same request again to get the cached response
+	sendReq = &SendRequest{
+		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
+	}
+	resp, err = lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
+		t.Fatalf("expected getting proxied response: got %v", diff)
+	}
+
+	// Modify the request a little bit to ensure the second response is
+	// returned to the lease cache. But make sure that the token in the request
+	// is valid.
+	sendReq = &SendRequest{
+		Token:   "test",
+		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
+	}
+	resp, err = lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
+		t.Fatalf("expected getting proxied response: got %v", diff)
+	}
+
+	// Make the same request again and ensure that the same response is
+	// returned again.
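+	// (A hit here also shows the index ID is stable: computeIndexID hashes
+	// the serialized request plus the token, so identical inputs map to the
+	// same cache entry.)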
+	sendReq = &SendRequest{
+		Token:   "test",
+		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)),
+	}
+	resp, err = lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
+		t.Fatalf("expected getting proxied response: got %v", diff)
+	}
+}
+
+func TestCache_LeaseCache_SendNonCacheable(t *testing.T) {
+	responses := []*SendResponse{
+		&SendResponse{
+			Response: &api.Response{
+				Response: &http.Response{
+					StatusCode: http.StatusOK,
+					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "output"}`)),
+				},
+			},
+		},
+		&SendResponse{
+			Response: &api.Response{
+				Response: &http.Response{
+					StatusCode: http.StatusNotFound,
+					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid"}`)),
+				},
+			},
+		},
+	}
+	lc := testNewLeaseCache(t, responses)
+
+	// Send a request through the lease cache which is not cacheable (there is
+	// no lease information or auth information in the response)
+	sendReq := &SendRequest{
+		Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
+	}
+	resp, err := lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
+		t.Fatalf("expected getting proxied response: got %v", diff)
+	}
+
+	// Since the response is non-cacheable, the second response will be
+	// returned.
+	sendReq = &SendRequest{
+		Token:   "foo",
+		Request: httptest.NewRequest("GET", "http://example.com", strings.NewReader(`{"value": "input"}`)),
+	}
+	resp, err = lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[1].Response.StatusCode); diff != nil {
+		t.Fatalf("expected getting proxied response: got %v", diff)
+	}
+}
+
+func TestCache_LeaseCache_SendNonCacheableNonTokenLease(t *testing.T) {
+	// Create the cache
+	responses := []*SendResponse{
+		&SendResponse{
+			Response: &api.Response{
+				Response: &http.Response{
+					StatusCode: http.StatusOK,
+					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "output", "lease_id": "foo"}`)),
+				},
+			},
+			ResponseBody: []byte(`{"value": "output", "lease_id": "foo"}`),
+		},
+		&SendResponse{
+			Response: &api.Response{
+				Response: &http.Response{
+					StatusCode: http.StatusCreated,
+					Body:       ioutil.NopCloser(strings.NewReader(`{"value": "invalid", "auth": {"client_token": "test"}}`)),
+				},
+			},
+			ResponseBody: []byte(`{"value": "invalid", "auth": {"client_token": "test"}}`),
+		},
+	}
+	lc := testNewLeaseCache(t, responses)
+
+	// Send a request through lease cache which returns a response containing
+	// lease_id. Response will not be cached because it doesn't belong to a
+	// token that is managed by the lease cache.
+	urlPath := "http://example.com/v1/sample/api"
+	sendReq := &SendRequest{
+		Token:   "foo",
+		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
+	}
+	resp, err := lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil {
+		t.Fatalf("expected getting proxied response: got %v", diff)
+	}
+
+	// Verify that the response is not cached by sending the same request and
+	// by expecting a different response.
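+	// Since the mock proxier consumes one queued response per Send call,
+	// receiving responses[1] here proves the first response was never cached.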
+	sendReq = &SendRequest{
+		Token:   "foo",
+		Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)),
+	}
+	resp, err = lc.Send(context.Background(), sendReq)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff == nil {
+		t.Fatalf("expected a different response for the second request, got the first one again")
+	}
+}
+
+func TestCache_LeaseCache_HandleCacheClear(t *testing.T) {
+	lc := testNewLeaseCache(t, nil)
+
+	handler := lc.HandleCacheClear(context.Background())
+	ts := httptest.NewServer(handler)
+	defer ts.Close()
+
+	// Test missing body, should return 400
+	resp, err := http.Post(ts.URL, "application/json", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if resp.StatusCode != http.StatusBadRequest {
+		t.Fatalf("status code mismatch: expected = %v, got = %v", http.StatusBadRequest, resp.StatusCode)
+	}
+
+	testCases := []struct {
+		name               string
+		reqType            string
+		reqValue           string
+		expectedStatusCode int
+	}{
+		{
+			"invalid_type",
+			"foo",
+			"",
+			http.StatusBadRequest,
+		},
+		{
+			"invalid_value",
+			"",
+			"bar",
+			http.StatusBadRequest,
+		},
+		{
+			"all",
+			"all",
+			"",
+			http.StatusOK,
+		},
+		{
+			"by_request_path",
+			"request_path",
+			"foo",
+			http.StatusOK,
+		},
+		{
+			"by_token",
+			"token",
+			"foo",
+			http.StatusOK,
+		},
+		{
+			"by_lease",
+			"lease",
+			"foo",
+			http.StatusOK,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			reqBody := fmt.Sprintf("{\"type\": \"%s\", \"value\": \"%s\"}", tc.reqType, tc.reqValue)
+			resp, err := http.Post(ts.URL, "application/json", strings.NewReader(reqBody))
+			if err != nil {
+				t.Fatal(err)
+			}
+			if tc.expectedStatusCode != resp.StatusCode {
+				t.Fatalf("status code mismatch: expected = %v, got = %v", tc.expectedStatusCode, resp.StatusCode)
+			}
+		})
+	}
+}
+
+func TestCache_DeriveNamespaceAndRevocationPath(t *testing.T) {
+	tests := []struct {
+		name             string
+		req              *SendRequest
+		wantNamespace    string
+		wantRelativePath string
+	}{
+		{
+			"non_revocation_full_path",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/ns1/sys/mounts",
+					},
+				},
+			},
+			"root/",
+			"/v1/ns1/sys/mounts",
+		},
+		{
+			"non_revocation_relative_path",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/sys/mounts",
+					},
+					Header: http.Header{
+						consts.NamespaceHeaderName: []string{"ns1/"},
+					},
+				},
+			},
+			"ns1/",
+			"/v1/sys/mounts",
+		},
+		{
+			"non_revocation_relative_path_with_ns_in_path",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/ns2/sys/mounts",
+					},
+					Header: http.Header{
+						consts.NamespaceHeaderName: []string{"ns1/"},
+					},
+				},
+			},
+			"ns1/",
+			"/v1/ns2/sys/mounts",
+		},
+		{
+			"revocation_full_path",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/ns1/sys/leases/revoke",
+					},
+				},
+			},
+			"ns1/",
+			"/v1/sys/leases/revoke",
+		},
+		{
+			"revocation_relative_path",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/sys/leases/revoke",
+					},
+					Header: http.Header{
+						consts.NamespaceHeaderName: []string{"ns1/"},
+					},
+				},
+			},
+			"ns1/",
+			"/v1/sys/leases/revoke",
+		},
+		{
+			"revocation_relative_partial_ns",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/ns2/sys/leases/revoke",
+					},
+					Header: http.Header{
+						consts.NamespaceHeaderName: []string{"ns1/"},
+					},
+				},
+			},
+			"ns1/ns2/",
+			"/v1/sys/leases/revoke",
+		},
+		{
+			"revocation_prefix_full_path",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/ns1/sys/leases/revoke-prefix/foo",
+					},
+				},
+			},
+			"ns1/",
+			"/v1/sys/leases/revoke-prefix/foo",
+		},
+		{
+			"revocation_prefix_relative_path",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/sys/leases/revoke-prefix/foo",
+					},
+					Header: http.Header{
+						consts.NamespaceHeaderName: []string{"ns1/"},
+					},
+				},
+			},
+			"ns1/",
+			"/v1/sys/leases/revoke-prefix/foo",
+		},
+		{
+			"revocation_prefix_partial_ns",
+			&SendRequest{
+				Request: &http.Request{
+					URL: &url.URL{
+						Path: "/v1/ns2/sys/leases/revoke-prefix/foo",
+					},
+					Header: http.Header{
+						consts.NamespaceHeaderName: []string{"ns1/"},
+					},
+				},
+			},
+			"ns1/ns2/",
+			"/v1/sys/leases/revoke-prefix/foo",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			gotNamespace, gotRelativePath := deriveNamespaceAndRevocationPath(tt.req)
+			if gotNamespace != tt.wantNamespace {
+				t.Errorf("deriveNamespaceAndRevocationPath() gotNamespace = %v, want %v", gotNamespace, tt.wantNamespace)
+			}
+			if gotRelativePath != tt.wantRelativePath {
+				t.Errorf("deriveNamespaceAndRevocationPath() gotRelativePath = %v, want %v", gotRelativePath, tt.wantRelativePath)
+			}
+		})
+	}
+}
diff --git a/command/agent/cache/listener.go b/command/agent/cache/listener.go
new file mode 100644
index 000000000000..2515138bcb0f
--- /dev/null
+++ b/command/agent/cache/listener.go
@@ -0,0 +1,120 @@
+package cache
+
+import (
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"strings"
+
+	"github.com/hashicorp/vault/command/agent/config"
+	"github.com/hashicorp/vault/command/server"
+	"github.com/hashicorp/vault/helper/reload"
+	"github.com/mitchellh/cli"
+)
+
+func ServerListeners(lnConfigs []*config.Listener, logger io.Writer, ui cli.Ui) ([]net.Listener, error) {
+	var listeners []net.Listener
+	var listener net.Listener
+	var err error
+	for _, lnConfig := range lnConfigs {
+		switch lnConfig.Type {
+		case "unix":
+			listener, _, _, err = unixSocketListener(lnConfig.Config, logger, ui)
+			if err != nil {
+				return nil, err
+			}
+			listeners = append(listeners, listener)
+		case "tcp":
+			// Plain assignment here; `:=` would shadow the outer listener and
+			// err variables.
+			listener, _, _, err = tcpListener(lnConfig.Config, logger, ui)
+			if err != nil {
+				return nil, err
+			}
+			listeners = append(listeners, listener)
+		default:
+			return nil, fmt.Errorf("unsupported listener type: %q", lnConfig.Type)
+		}
+	}
+
+	return listeners, nil
+}
+
+func unixSocketListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
+	addr, ok := config["address"].(string)
+	if !ok {
+		return nil, nil, nil, fmt.Errorf("invalid address: %v", config["address"])
+	}
+
+	if addr == "" {
+		return nil, nil, nil, fmt.Errorf("address field should point to a socket file path")
+	}
+
+	// Remove the socket file, as it shouldn't exist for the domain socket to
+	// work
+	err := os.Remove(addr)
+	if err != nil && !os.IsNotExist(err) {
+		return nil, nil, nil, fmt.Errorf("failed to remove the socket file: %v", err)
+	}
+
+	listener, err := net.Listen("unix", addr)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
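+	// For reference, such a listener comes from an agent config stanza like
+	// the following (path illustrative):
+	//
+	//   listener "unix" {
+	//     address     = "/tmp/vault-agent.sock"
+	//     tls_disable = true
+	//   }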
+	// Wrap the listener in rmListener so that the Unix domain socket file is
+	// removed on close.
+	listener = &rmListener{
+		Listener: listener,
+		Path:     addr,
+	}
+
+	props := map[string]string{"addr": addr}
+
+	return server.ListenerWrapTLS(listener, props, config, ui)
+}
+
+func tcpListener(config map[string]interface{}, _ io.Writer, ui cli.Ui) (net.Listener, map[string]string, reload.ReloadFunc, error) {
+	bindProto := "tcp"
+	var addr string
+	addrRaw, ok := config["address"]
+	if !ok {
+		addr = "127.0.0.1:8300"
+	} else {
+		addr = addrRaw.(string)
+	}
+
+	// If they've passed 0.0.0.0, we only want to bind on IPv4
+	// rather than golang's dual stack default
+	if strings.HasPrefix(addr, "0.0.0.0:") {
+		bindProto = "tcp4"
+	}
+
+	ln, err := net.Listen(bindProto, addr)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	ln = server.TCPKeepAliveListener{ln.(*net.TCPListener)}
+
+	props := map[string]string{"addr": addr}
+
+	return server.ListenerWrapTLS(ln, props, config, ui)
+}
+
+// rmListener is an implementation of net.Listener that forwards most
+// calls to the listener but also removes a file as part of the close. We
+// use this to clean up the Unix domain socket on close.
+type rmListener struct {
+	net.Listener
+	Path string
+}
+
+func (l *rmListener) Close() error {
+	// Close the listener itself
+	if err := l.Listener.Close(); err != nil {
+		return err
+	}
+
+	// Remove the file
+	return os.Remove(l.Path)
+}
diff --git a/command/agent/cache/proxy.go b/command/agent/cache/proxy.go
new file mode 100644
index 000000000000..4637590917e9
--- /dev/null
+++ b/command/agent/cache/proxy.go
@@ -0,0 +1,28 @@
+package cache
+
+import (
+	"context"
+	"net/http"
+
+	"github.com/hashicorp/vault/api"
+)
+
+// SendRequest is the input for Proxier.Send.
+type SendRequest struct {
+	Token       string
+	Request     *http.Request
+	RequestBody []byte
+}
+
+// SendResponse is the output from Proxier.Send.
+type SendResponse struct {
+	Response     *api.Response
+	ResponseBody []byte
+}
+
+// Proxier is the interface implemented by different components that are
+// responsible for performing specific tasks, such as caching and proxying.
+// Together, these tasks serve the requests received by the agent.
+type Proxier interface {
+	Send(ctx context.Context, req *SendRequest) (*SendResponse, error)
+}
diff --git a/command/agent/cache/testing.go b/command/agent/cache/testing.go
new file mode 100644
index 000000000000..d9de1caadc7d
--- /dev/null
+++ b/command/agent/cache/testing.go
@@ -0,0 +1,36 @@
+package cache
+
+import (
+	"context"
+	"fmt"
+)
+
+// mockProxier is a mock implementation of the Proxier interface, used for
+// testing purposes. It returns the provided responses in order, one per call
+// to Send, and errors once they are exhausted. This lets tests control what
+// the underlying Proxier layer returns.
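+// For example, queuing two responses means the first two Send calls return
+// them in order and a third call fails, letting a test assert exactly how
+// many requests reached the underlying layer.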
+type mockProxier struct {
+	proxiedResponses []*SendResponse
+	responseIndex    int
+}
+
+func newMockProxier(responses []*SendResponse) *mockProxier {
+	return &mockProxier{
+		proxiedResponses: responses,
+	}
+}
+
+func (p *mockProxier) Send(ctx context.Context, req *SendRequest) (*SendResponse, error) {
+	if p.responseIndex >= len(p.proxiedResponses) {
+		return nil, fmt.Errorf("index out of bounds: responseIndex = %d, responses = %d", p.responseIndex, len(p.proxiedResponses))
+	}
+	resp := p.proxiedResponses[p.responseIndex]
+
+	p.responseIndex++
+
+	return resp, nil
+}
+
+func (p *mockProxier) ResponseIndex() int {
+	return p.responseIndex
+}
diff --git a/command/agent/config/config.go b/command/agent/config/config.go
index 3a18b946efac..1ee5a9992be0 100644
--- a/command/agent/config/config.go
+++ b/command/agent/config/config.go
@@ -22,6 +22,17 @@ type Config struct {
 	AutoAuth      *AutoAuth `hcl:"auto_auth"`
 	ExitAfterAuth bool      `hcl:"exit_after_auth"`
 	PidFile       string    `hcl:"pid_file"`
+	Cache         *Cache    `hcl:"cache"`
+}
+
+type Cache struct {
+	UseAutoAuthToken bool        `hcl:"use_auto_auth_token"`
+	Listeners        []*Listener `hcl:"listeners"`
+}
+
+type Listener struct {
+	Type   string
+	Config map[string]interface{}
 }
 
 type AutoAuth struct {
@@ -91,9 +102,90 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
 		return nil, errwrap.Wrapf("error parsing 'auto_auth': {{err}}", err)
 	}
 
+	err = parseCache(&result, list)
+	if err != nil {
+		return nil, errwrap.Wrapf("error parsing 'cache': {{err}}", err)
+	}
+
 	return &result, nil
 }
 
+func parseCache(result *Config, list *ast.ObjectList) error {
+	name := "cache"
+
+	cacheList := list.Filter(name)
+	// The cache stanza is optional; only error when more than one is given.
+	if len(cacheList.Items) == 0 {
+		return nil
+	}
+	if len(cacheList.Items) > 1 {
+		return fmt.Errorf("at most one %q block is allowed", name)
+	}
+
+	item := cacheList.Items[0]
+
+	var c Cache
+	err := hcl.DecodeObject(&c, item.Val)
+	if err != nil {
+		return err
+	}
+
+	result.Cache = &c
+
+	subs, ok := item.Val.(*ast.ObjectType)
+	if !ok {
+		return fmt.Errorf("could not parse %q as an object", name)
+	}
+	subList := subs.List
+
+	err = parseListeners(result, subList)
+	if err != nil {
+		return errwrap.Wrapf("error parsing 'listener' stanzas: {{err}}", err)
+	}
+
+	return nil
+}
+
+func parseListeners(result *Config, list *ast.ObjectList) error {
+	name := "listener"
+
+	listenerList := list.Filter(name)
+	if len(listenerList.Items) < 1 {
+		return fmt.Errorf("at least one %q block is required", name)
+	}
+
+	var listeners []*Listener
+	for _, item := range listenerList.Items {
+		var lnConfig map[string]interface{}
+		err := hcl.DecodeObject(&lnConfig, item.Val)
+		if err != nil {
+			return err
+		}
+
+		var lnType string
+		switch {
+		case lnConfig["type"] != nil:
+			lnType = lnConfig["type"].(string)
+			delete(lnConfig, "type")
+		case len(item.Keys) == 1:
+			lnType = strings.ToLower(item.Keys[0].Token.Value().(string))
+		default:
+			return errors.New("listener type must be specified")
+		}
+
+		switch lnType {
+		case "unix", "tcp":
+		default:
+			return fmt.Errorf("invalid listener type %q", lnType)
+		}
+
+		listeners = append(listeners, &Listener{
+			Type:   lnType,
+			Config: lnConfig,
+		})
+	}
+
+	result.Cache.Listeners = listeners
+
+	return nil
+}
+
 func parseAutoAuth(result *Config, list *ast.ObjectList) error {
 	name := "auto_auth"
diff --git a/command/agent/config/config_test.go b/command/agent/config/config_test.go
index 2f78b4fb04fa..465df76f8009 100644
--- a/command/agent/config/config_test.go
+++ b/command/agent/config/config_test.go
@@ -10,6 +10,83 @@ import (
"github.com/hashicorp/vault/helper/logging" ) +func TestLoadConfigFile_AgentCache(t *testing.T) { + logger := logging.NewVaultLogger(log.Debug) + + os.Setenv("TEST_AAD_ENV", "aad") + defer os.Unsetenv("TEST_AAD_ENV") + + config, err := LoadConfig("./test-fixtures/config-cache.hcl", logger) + if err != nil { + t.Fatal(err) + } + + expected := &Config{ + AutoAuth: &AutoAuth{ + Method: &Method{ + Type: "aws", + WrapTTL: 300 * time.Second, + MountPath: "auth/aws", + Config: map[string]interface{}{ + "role": "foobar", + }, + }, + Sinks: []*Sink{ + &Sink{ + Type: "file", + DHType: "curve25519", + DHPath: "/tmp/file-foo-dhpath", + AAD: "foobar", + Config: map[string]interface{}{ + "path": "/tmp/file-foo", + }, + }, + }, + }, + Cache: &Cache{ + UseAutoAuthToken: true, + Listeners: []*Listener{ + &Listener{ + Type: "unix", + Config: map[string]interface{}{ + "address": "/Users/vishal/go/src/github.com/hashicorp/vault/socket", + "tls_disable": true, + }, + }, + &Listener{ + Type: "tcp", + Config: map[string]interface{}{ + "address": "127.0.0.1:8300", + "tls_disable": true, + }, + }, + &Listener{ + Type: "tcp", + Config: map[string]interface{}{ + "address": "127.0.0.1:8400", + "tls_key_file": "/Users/vishal/go/src/github.com/hashicorp/vault/cakey.pem", + "tls_cert_file": "/Users/vishal/go/src/github.com/hashicorp/vault/cacert.pem", + }, + }, + }, + }, + PidFile: "./pidfile", + } + + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } + + config, err = LoadConfig("./test-fixtures/config-cache-embedded-type.hcl", logger) + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } +} + func TestLoadConfigFile(t *testing.T) { logger := logging.NewVaultLogger(log.Debug) diff --git a/command/agent/config/test-fixtures/config-cache-embedded-type.hcl b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl new file mode 100644 index 000000000000..01728a6597a0 --- /dev/null +++ b/command/agent/config/test-fixtures/config-cache-embedded-type.hcl @@ -0,0 +1,44 @@ +pid_file = "./pidfile" + +auto_auth { + method { + type = "aws" + wrap_ttl = 300 + config = { + role = "foobar" + } + } + + sink { + type = "file" + config = { + path = "/tmp/file-foo" + } + aad = "foobar" + dh_type = "curve25519" + dh_path = "/tmp/file-foo-dhpath" + } +} + +cache { + use_auto_auth_token = true + + listener { + type = "unix" + address = "/Users/vishal/go/src/github.com/hashicorp/vault/socket" + tls_disable = true + } + + listener { + type = "tcp" + address = "127.0.0.1:8300" + tls_disable = true + } + + listener { + type = "tcp" + address = "127.0.0.1:8400" + tls_key_file = "/Users/vishal/go/src/github.com/hashicorp/vault/cakey.pem" + tls_cert_file = "/Users/vishal/go/src/github.com/hashicorp/vault/cacert.pem" + } +} diff --git a/command/agent/config/test-fixtures/config-cache.hcl b/command/agent/config/test-fixtures/config-cache.hcl new file mode 100644 index 000000000000..c9ac5a3694e1 --- /dev/null +++ b/command/agent/config/test-fixtures/config-cache.hcl @@ -0,0 +1,41 @@ +pid_file = "./pidfile" + +auto_auth { + method { + type = "aws" + wrap_ttl = 300 + config = { + role = "foobar" + } + } + + sink { + type = "file" + config = { + path = "/tmp/file-foo" + } + aad = "foobar" + dh_type = "curve25519" + dh_path = "/tmp/file-foo-dhpath" + } +} + +cache { + use_auto_auth_token = true + + listener "unix" { + address = "/Users/vishal/go/src/github.com/hashicorp/vault/socket" + tls_disable = true + } + + listener "tcp" { + 
address = "127.0.0.1:8300" + tls_disable = true + } + + listener "tcp" { + address = "127.0.0.1:8400" + tls_key_file = "/Users/vishal/go/src/github.com/hashicorp/vault/cakey.pem" + tls_cert_file = "/Users/vishal/go/src/github.com/hashicorp/vault/cacert.pem" + } +} diff --git a/command/base.go b/command/base.go index db37fd37c380..144e16435a80 100644 --- a/command/base.go +++ b/command/base.go @@ -39,6 +39,7 @@ type BaseCommand struct { flagsOnce sync.Once flagAddress string + flagAgentAddress string flagCACert string flagCAPath string flagClientCert string @@ -78,6 +79,9 @@ func (c *BaseCommand) Client() (*api.Client, error) { if c.flagAddress != "" { config.Address = c.flagAddress } + if c.flagAgentAddress != "" { + config.Address = c.flagAgentAddress + } if c.flagOutputCurlString { config.OutputCurlString = c.flagOutputCurlString @@ -220,6 +224,15 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { } f.StringVar(addrStringVar) + agentAddrStringVar := &StringVar{ + Name: "agent-address", + Target: &c.flagAgentAddress, + EnvVar: "VAULT_AGENT_ADDR", + Completion: complete.PredictAnything, + Usage: "Address of the Agent.", + } + f.StringVar(agentAddrStringVar) + f.StringVar(&StringVar{ Name: "ca-cert", Target: &c.flagCACert, diff --git a/command/server/listener.go b/command/server/listener.go index a1f2f392684c..6546972260f2 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -72,7 +72,7 @@ func listenerWrapProxy(ln net.Listener, config map[string]interface{}) (net.List return newLn, nil } -func listenerWrapTLS( +func ListenerWrapTLS( ln net.Listener, props map[string]string, config map[string]interface{}, diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index 201e124f3aae..02b7b309fa83 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -35,7 +35,7 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) ( return nil, nil, nil, err } - ln = tcpKeepAliveListener{ln.(*net.TCPListener)} + ln = TCPKeepAliveListener{ln.(*net.TCPListener)} ln, err = listenerWrapProxy(ln, config) if err != nil { @@ -94,20 +94,20 @@ func tcpListenerFactory(config map[string]interface{}, _ io.Writer, ui cli.Ui) ( config["x_forwarded_for_reject_not_authorized"] = true } - return listenerWrapTLS(ln, props, config, ui) + return ListenerWrapTLS(ln, props, config, ui) } -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used by ListenAndServe and ListenAndServeTLS so // dead TCP connections (e.g. closing laptop mid-download) eventually // go away. // // This is copied directly from the Go source code. -type tcpKeepAliveListener struct { +type TCPKeepAliveListener struct { *net.TCPListener } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { +func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) { tc, err := ln.AcceptTCP() if err != nil { return diff --git a/helper/contextutil/context.go b/helper/contextutil/context.go new file mode 100644 index 000000000000..10b2470cf7db --- /dev/null +++ b/helper/contextutil/context.go @@ -0,0 +1,13 @@ +package contextutil + +import ( + "context" + "time" +) + +func BackoffOrQuit(ctx context.Context, backoff time.Duration) { + select { + case <-time.After(backoff): + case <-ctx.Done(): + } +}