diff --git a/api/api_integration_test.go b/api/api_integration_test.go new file mode 100644 index 000000000000..90a90f68c178 --- /dev/null +++ b/api/api_integration_test.go @@ -0,0 +1,96 @@ +package api_test + +import ( + "database/sql" + "fmt" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/pki" + "github.com/hashicorp/vault/builtin/logical/transit" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" + + vaulthttp "github.com/hashicorp/vault/http" + logxi "github.com/mgutz/logxi/v1" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var testVaultServerDefaultBackends = map[string]logical.Factory{ + "transit": transit.Factory, + "pki": pki.Factory, +} + +func testVaultServer(t testing.TB) (*api.Client, func()) { + return testVaultServerBackends(t, testVaultServerDefaultBackends) +} + +func testVaultServerBackends(t testing.TB, backends map[string]logical.Factory) (*api.Client, func()) { + coreConfig := &vault.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Logger: logxi.NullLog, + LogicalBackends: backends, + } + + cluster := vault.NewTestCluster(t, coreConfig, true) + cluster.StartListeners() + for _, core := range cluster.Cores { + core.Handler.Handle("/", vaulthttp.Handler(core.Core)) + } + + // make it easy to get access to the active + core := cluster.Cores[0].Core + vault.TestWaitActive(t, core) + + // Grab the root token + rootToken := cluster.Cores[0].Root + + client := cluster.Cores[0].Client + client.SetToken(rootToken) + + // Sanity check + secret, err := client.Auth().Token().LookupSelf() + if err != nil { + t.Fatal(err) + } + if secret == nil || secret.Data["id"].(string) != rootToken { + t.Fatalf("token mismatch: %#v vs %q", secret, rootToken) + } + return client, func() { defer cluster.CloseListeners() } +} + +// testPostgresDB creates a testing postgres database in a Docker container, +// returning the connection URL and the associated closer function. +func testPostgresDB(t testing.TB) (string, func()) { + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("postgresdb: failed to connect to docker: %s", err) + } + + resource, err := pool.Run("postgres", "latest", []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_DB=database", + }) + if err != nil { + t.Fatalf("postgresdb: could not start container: %s", err) + } + + addr := fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp")) + + if err := pool.Retry(func() error { + db, err := sql.Open("postgres", addr) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("postgresdb: could not connect: %s", err) + } + + return addr, func() { + if err := pool.Purge(resource); err != nil { + t.Fatalf("postgresdb: failed to cleanup container: %s", err) + } + } +} diff --git a/api/auth_token.go b/api/auth_token.go index aff10f4109cb..4f74f61fe5f2 100644 --- a/api/auth_token.go +++ b/api/auth_token.go @@ -135,6 +135,26 @@ func (c *TokenAuth) RenewSelf(increment int) (*Secret, error) { return ParseSecret(resp.Body) } +// RenewTokenAsSelf behaves like renew-self, but authenticates using a provided +// token instead of the token attached to the client. +func (c *TokenAuth) RenewTokenAsSelf(token string, increment int) (*Secret, error) { + r := c.c.NewRequest("PUT", "/v1/auth/token/renew-self") + r.ClientToken = token + + body := map[string]interface{}{"increment": increment} + if err := r.SetJSONBody(body); err != nil { + return nil, err + } + + resp, err := c.c.RawRequest(r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return ParseSecret(resp.Body) +} + // RevokeAccessor revokes a token associated with the given accessor // along with all the child tokens. func (c *TokenAuth) RevokeAccessor(accessor string) error { diff --git a/api/renewer.go b/api/renewer.go new file mode 100644 index 000000000000..3f9f23bc3e61 --- /dev/null +++ b/api/renewer.go @@ -0,0 +1,302 @@ +package api + +import ( + "errors" + "math/rand" + "sync" + "time" +) + +var ( + ErrRenewerMissingInput = errors.New("missing input to renewer") + ErrRenewerMissingSecret = errors.New("missing secret to renew") + ErrRenewerNotRenewable = errors.New("secret is not renewable") + ErrRenewerNoSecretData = errors.New("returned empty secret data") + + // DefaultRenewerGrace is the default grace period + DefaultRenewerGrace = 15 * time.Second + + // DefaultRenewerRenewBuffer is the default size of the buffer for renew + // messages on the channel. + DefaultRenewerRenewBuffer = 5 +) + +// Renewer is a process for renewing a secret. +// +// renewer, err := client.NewRenewer(&RenewerInput{ +// Secret: mySecret, +// }) +// go renewer.Renew() +// defer renewer.Stop() +// +// for { +// select { +// case err := <-renewer.DoneCh(): +// if err != nil { +// log.Fatal(err) +// } +// +// // Renewal is now over +// case renewal := <-renewer.RenewCh(): +// log.Printf("Successfully renewed: %#v", renewal) +// } +// } +// +// +// The `DoneCh` will return if renewal fails or if the remaining lease duration +// after a renewal is less than or equal to the grace (in number of seconds). In +// both cases, the caller should attempt a re-read of the secret. Clients should +// check the return value of the channel to see if renewal was successful. +type Renewer struct { + l sync.Mutex + + client *Client + secret *Secret + grace time.Duration + random *rand.Rand + doneCh chan error + renewCh chan *RenewOutput + + stopped bool + stopCh chan struct{} +} + +// RenewerInput is used as input to the renew function. +type RenewerInput struct { + // Secret is the secret to renew + Secret *Secret + + // Grace is a minimum renewal before returning so the upstream client + // can do a re-read. This can be used to prevent clients from waiting + // too long to read a new credential and incur downtime. + Grace time.Duration + + // Rand is the randomizer to use for underlying randomization. If not + // provided, one will be generated and seeded automatically. If provided, it + // is assumed to have already been seeded. + Rand *rand.Rand + + // RenewBuffer is the size of the buffered channel where renew messages are + // dispatched. + RenewBuffer int +} + +// RenewOutput is the metadata returned to the client (if it's listening) to +// renew messages. +type RenewOutput struct { + // RenewedAt is the timestamp when the renewal took place (UTC). + RenewedAt time.Time + + // Secret is the underlying renewal data. It's the same struct as all data + // that is returned from Vault, but since this is renewal data, it will not + // usually include the secret itself. + Secret *Secret +} + +// NewRenewer creates a new renewer from the given input. +func (c *Client) NewRenewer(i *RenewerInput) (*Renewer, error) { + if i == nil { + return nil, ErrRenewerMissingInput + } + + secret := i.Secret + if secret == nil { + return nil, ErrRenewerMissingSecret + } + + grace := i.Grace + if grace == 0 { + grace = DefaultRenewerGrace + } + + random := i.Rand + if random == nil { + random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + } + + renewBuffer := i.RenewBuffer + if renewBuffer == 0 { + renewBuffer = DefaultRenewerRenewBuffer + } + + return &Renewer{ + client: c, + secret: secret, + grace: grace, + random: random, + doneCh: make(chan error, 1), + renewCh: make(chan *RenewOutput, renewBuffer), + + stopped: false, + stopCh: make(chan struct{}), + }, nil +} + +// DoneCh returns the channel where the renewer will publish when renewal stops. +// If there is an error, this will be an error. +func (r *Renewer) DoneCh() <-chan error { + return r.doneCh +} + +// RenewCh is a channel that receives a message when a successful renewal takes +// place and includes metadata about the renewal. +func (r *Renewer) RenewCh() <-chan *RenewOutput { + return r.renewCh +} + +// Stop stops the renewer. +func (r *Renewer) Stop() { + r.l.Lock() + if !r.stopped { + close(r.stopCh) + r.stopped = true + } + r.l.Unlock() +} + +// Renew starts a background process for renewing this secret. When the secret +// is has auth data, this attempts to renew the auth (token). When the secret +// has a lease, this attempts to renew the lease. +func (r *Renewer) Renew() { + var result error + if r.secret.Auth != nil { + result = r.renewAuth() + } else { + result = r.renewLease() + } + + select { + case r.doneCh <- result: + case <-r.stopCh: + } +} + +// renewAuth is a helper for renewing authentication. +func (r *Renewer) renewAuth() error { + if !r.secret.Auth.Renewable || r.secret.Auth.ClientToken == "" { + return ErrRenewerNotRenewable + } + + client, token := r.client, r.secret.Auth.ClientToken + + for { + // Check if we are stopped. + select { + case <-r.stopCh: + return nil + default: + } + + // Renew the auth. + renewal, err := client.Auth().Token().RenewTokenAsSelf(token, 0) + if err != nil { + return err + } + + // Push a message that a renewal took place. + select { + case r.renewCh <- &RenewOutput{time.Now().UTC(), renewal}: + default: + } + + // Somehow, sometimes, this happens. + if renewal == nil || renewal.Auth == nil { + return ErrRenewerNoSecretData + } + + // Do nothing if we are not renewable + if !renewal.Auth.Renewable { + return ErrRenewerNotRenewable + } + + // Grab the lease duration and sleep duration - note that we grab the auth + // lease duration, not the secret lease duration. + leaseDuration := time.Duration(renewal.Auth.LeaseDuration) * time.Second + sleepDuration := r.sleepDuration(leaseDuration) + + // If we are within grace, return now. + if leaseDuration <= r.grace || sleepDuration <= r.grace { + return nil + } + + select { + case <-r.stopCh: + return nil + case <-time.After(sleepDuration): + continue + } + } +} + +// renewLease is a helper for renewing a lease. +func (r *Renewer) renewLease() error { + if !r.secret.Renewable || r.secret.LeaseID == "" { + return ErrRenewerNotRenewable + } + + client, leaseID := r.client, r.secret.LeaseID + + for { + // Check if we are stopped. + select { + case <-r.stopCh: + return nil + default: + } + + // Renew the lease. + renewal, err := client.Sys().Renew(leaseID, 0) + if err != nil { + return err + } + + // Push a message that a renewal took place. + select { + case r.renewCh <- &RenewOutput{time.Now().UTC(), renewal}: + default: + } + + // Somehow, sometimes, this happens. + if renewal == nil { + return ErrRenewerNoSecretData + } + + // Do nothing if we are not renewable + if !renewal.Renewable { + return ErrRenewerNotRenewable + } + + // Grab the lease duration and sleep duration + leaseDuration := time.Duration(renewal.LeaseDuration) * time.Second + sleepDuration := r.sleepDuration(leaseDuration) + + // If we are within grace, return now. + if leaseDuration <= r.grace || sleepDuration <= r.grace { + return nil + } + + select { + case <-r.stopCh: + return nil + case <-time.After(sleepDuration): + continue + } + } +} + +// sleepDuration calculates the time to sleep given the base lease duration. The +// base is the resulting lease duration. It will be reduced to 1/3 and +// multiplied by a random float between 0.0 and 1.0. This extra randomness +// prevents multiple clients from all trying to renew simultaneously. +func (r *Renewer) sleepDuration(base time.Duration) time.Duration { + sleep := float64(base) + + // Renew at 1/3 the remaining lease. This will give us an opportunity to retry + // at least one more time should the first renewal fail. + sleep = sleep / 3.0 + + // Use a randomness so many clients do not hit Vault simultaneously. + sleep = sleep * r.random.Float64() + + return time.Duration(sleep) * time.Second +} diff --git a/api/renewer_integration_test.go b/api/renewer_integration_test.go new file mode 100644 index 000000000000..82ffe508d176 --- /dev/null +++ b/api/renewer_integration_test.go @@ -0,0 +1,228 @@ +package api_test + +import ( + "testing" + "time" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/database" + "github.com/hashicorp/vault/builtin/logical/pki" + "github.com/hashicorp/vault/builtin/logical/transit" + "github.com/hashicorp/vault/logical" +) + +func TestRenewer_Renew(t *testing.T) { + t.Parallel() + + client, vaultDone := testVaultServerBackends(t, map[string]logical.Factory{ + "database": database.Factory, + "pki": pki.Factory, + "transit": transit.Factory, + }) + defer vaultDone() + + pgURL, pgDone := testPostgresDB(t) + defer pgDone() + + t.Run("group", func(t *testing.T) { + t.Run("generic", func(t *testing.T) { + t.Parallel() + + if _, err := client.Logical().Write("secret/value", map[string]interface{}{ + "foo": "bar", + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Read("secret/value") + if err != nil { + t.Fatal(err) + } + + v, err := client.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + t.Fatal(err) + } + go v.Renew() + defer v.Stop() + + select { + case err := <-v.DoneCh(): + if err != api.ErrRenewerNotRenewable { + t.Fatal(err) + } + case renew := <-v.RenewCh(): + t.Errorf("received renew, but should have been nil: %#v", renew) + case <-time.After(500 * time.Millisecond): + t.Error("should have been non-renewable") + } + }) + + t.Run("transit", func(t *testing.T) { + t.Parallel() + + if err := client.Sys().Mount("transit", &api.MountInput{ + Type: "transit", + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Write("transit/encrypt/my-app", map[string]interface{}{ + "plaintext": "Zm9vCg==", + }) + if err != nil { + t.Fatal(err) + } + + v, err := client.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + t.Fatal(err) + } + go v.Renew() + defer v.Stop() + + select { + case err := <-v.DoneCh(): + if err != api.ErrRenewerNotRenewable { + t.Fatal(err) + } + case renew := <-v.RenewCh(): + t.Errorf("received renew, but should have been nil: %#v", renew) + case <-time.After(500 * time.Millisecond): + t.Error("should have been non-renewable") + } + }) + + t.Run("database", func(t *testing.T) { + t.Parallel() + + if err := client.Sys().Mount("database", &api.MountInput{ + Type: "database", + }); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("database/config/postgresql", map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_url": pgURL, + "allowed_roles": "readonly", + }); err != nil { + t.Fatal(err) + } + if _, err := client.Logical().Write("database/roles/readonly", map[string]interface{}{ + "db_name": "postgresql", + "creation_statements": `` + + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';` + + `GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`, + "default_ttl": "1s", + "max_ttl": "3s", + }); err != nil { + t.Fatal(err) + } + + secret, err := client.Logical().Read("database/creds/readonly") + if err != nil { + t.Fatal(err) + } + + v, err := client.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + t.Fatal(err) + } + go v.Renew() + defer v.Stop() + + select { + case err := <-v.DoneCh(): + t.Errorf("should have renewed once before returning: %s", err) + case renew := <-v.RenewCh(): + if renew == nil { + t.Fatal("renew is nil") + } + if !renew.Secret.Renewable { + t.Errorf("expected lease to be renewable: %#v", renew) + } + if renew.Secret.LeaseDuration > 2 { + t.Errorf("expected lease to < 2s: %#v", renew) + } + case <-time.After(3 * time.Second): + t.Errorf("no renewal") + } + + select { + case err := <-v.DoneCh(): + if err != nil { + t.Fatal(err) + } + case renew := <-v.RenewCh(): + t.Fatalf("should not have renewed (lease should be up): %#v", renew) + case <-time.After(3 * time.Second): + t.Errorf("no data") + } + }) + + t.Run("auth", func(t *testing.T) { + t.Parallel() + + secret, err := client.Auth().Token().Create(&api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "1s", + ExplicitMaxTTL: "3s", + }) + if err != nil { + t.Fatal(err) + } + + v, err := client.NewRenewer(&api.RenewerInput{ + Secret: secret, + }) + if err != nil { + t.Fatal(err) + } + go v.Renew() + defer v.Stop() + + select { + case err := <-v.DoneCh(): + t.Errorf("should have renewed once before returning: %s", err) + case renew := <-v.RenewCh(): + if renew == nil { + t.Fatal("renew is nil") + } + if renew.Secret.Auth == nil { + t.Fatal("renew auth is nil") + } + if !renew.Secret.Auth.Renewable { + t.Errorf("expected lease to be renewable: %#v", renew) + } + if renew.Secret.Auth.LeaseDuration > 2 { + t.Errorf("expected lease to < 2s: %#v", renew) + } + if renew.Secret.Auth.ClientToken == "" { + t.Error("expected a client token") + } + if renew.Secret.Auth.Accessor == "" { + t.Error("expected an accessor") + } + case <-time.After(3 * time.Second): + t.Errorf("no renewal") + } + + select { + case err := <-v.DoneCh(): + if err != nil { + t.Fatal(err) + } + case renew := <-v.RenewCh(): + t.Fatalf("should not have renewed (lease should be up): %#v", renew) + case <-time.After(3 * time.Second): + t.Errorf("no data") + } + }) + }) +} diff --git a/api/renewer_test.go b/api/renewer_test.go new file mode 100644 index 000000000000..262484e0fa01 --- /dev/null +++ b/api/renewer_test.go @@ -0,0 +1,85 @@ +package api + +import ( + "reflect" + "testing" + "time" +) + +func TestRenewer_NewRenewer(t *testing.T) { + t.Parallel() + + client, err := NewClient(DefaultConfig()) + if err != nil { + t.Fatal(err) + } + + cases := []struct { + name string + i *RenewerInput + e *Renewer + err bool + }{ + { + "nil", + nil, + nil, + true, + }, + { + "missing_secret", + &RenewerInput{ + Secret: nil, + }, + nil, + true, + }, + { + "default_grace", + &RenewerInput{ + Secret: &Secret{}, + }, + &Renewer{ + secret: &Secret{}, + grace: DefaultRenewerGrace, + }, + false, + }, + { + "custom_grace", + &RenewerInput{ + Secret: &Secret{}, + Grace: 30 * time.Second, + }, + &Renewer{ + secret: &Secret{}, + grace: 30 * time.Second, + }, + false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + v, err := client.NewRenewer(tc.i) + if (err != nil) != tc.err { + t.Fatal(err) + } + + if v == nil { + return + } + + // Zero-out channels because reflect + v.client = nil + v.random = nil + v.doneCh = nil + v.renewCh = nil + v.stopCh = nil + + if !reflect.DeepEqual(tc.e, v) { + t.Errorf("not equal\nexp: %#v\nact: %#v", tc.e, v) + } + }) + } +}