From 14f83956a90c9964a429ccab254bf62078584648 Mon Sep 17 00:00:00 2001 From: John Eikenberry Date: Tue, 27 Aug 2019 14:57:55 -0700 Subject: [PATCH] fix vault retry logic on failed calls The original problem was that, when it had trouble fetching a non-renewable vault secret, it would wait the standard exponential backoff time plus the configured sleep time (the one it uses between successful fetches). What it should do instead is use the sleep time between successful fetches and exponential backoff on failures. While fixing this I cleaned up the code to make the logic clearer. The issue existed in both vault_read and vault_write, and they shared a common chunk of renew logic with each other and with vault_token, so I refactored that out into a common function. Fixes #1224 --- dependency/vault_common.go | 42 ++++++++++++++- dependency/vault_common_test.go | 4 +- dependency/vault_read.go | 93 +++++++++++++++------------------ dependency/vault_read_test.go | 11 +++- dependency/vault_token.go | 45 ++++------------ dependency/vault_write.go | 88 +++++++++++++------------------ dependency/vault_write_test.go | 6 ++- 7 files changed, 145 insertions(+), 144 deletions(-) diff --git a/dependency/vault_common.go b/dependency/vault_common.go index 213c63293..81fe3e086 100644 --- a/dependency/vault_common.go +++ b/dependency/vault_common.go @@ -64,9 +64,47 @@ type SecretWrapInfo struct { WrappedAccessor string } -// vaultRenewDuration accepts a secret and returns the recommended amount of +// +type renewer interface { + Dependency + stopChan() chan struct{} + secrets() (*Secret, *api.Secret) +} + +func renewSecret(clients *ClientSet, d renewer) error { + log.Printf("[TRACE] %s: starting renewer", d) + + secret, vaultSecret := d.secrets() + renewer, err := clients.Vault().NewRenewer(&api.RenewerInput{ + Secret: vaultSecret, + }) + if err != nil { + return err + } + go renewer.Renew() + defer renewer.Stop() + + for { + select { + case err := <-renewer.DoneCh(): + if err != nil { + 
log.Printf("[WARN] %s: failed to renew: %s", d, err) + } + log.Printf("[WARN] %s: renewer done (maybe the lease expired)", d) + return nil + case renewal := <-renewer.RenewCh(): + log.Printf("[TRACE] %s: successfully renewed", d) + printVaultWarnings(d, renewal.Secret.Warnings) + updateSecret(secret, renewal.Secret) + case <-d.stopChan(): + return ErrStopped + } + } +} + +// leaseCheckWait accepts a secret and returns the recommended amount of // time to sleep. -func vaultRenewDuration(s *Secret) time.Duration { +func leaseCheckWait(s *Secret) time.Duration { // Handle whether this is an auth or a regular secret. base := s.LeaseDuration if s.Auth != nil && s.Auth.LeaseDuration > 0 { diff --git a/dependency/vault_common_test.go b/dependency/vault_common_test.go index b20f280be..620bd4398 100644 --- a/dependency/vault_common_test.go +++ b/dependency/vault_common_test.go @@ -8,13 +8,13 @@ func init() { func TestVaultRenewDuration(t *testing.T) { renewable := Secret{LeaseDuration: 100, Renewable: true} - renewableDur := vaultRenewDuration(&renewable).Seconds() + renewableDur := leaseCheckWait(&renewable).Seconds() if renewableDur < 16 || renewableDur >= 34 { t.Fatalf("renewable duration is not within 1/6 to 1/3 of lease duration: %f", renewableDur) } nonRenewable := Secret{LeaseDuration: 100} - nonRenewableDur := vaultRenewDuration(&nonRenewable).Seconds() + nonRenewableDur := leaseCheckWait(&nonRenewable).Seconds() if nonRenewableDur < 85 || nonRenewableDur > 95 { t.Fatalf("renewable duration is not within 85%% to 95%% of lease duration: %f", nonRenewableDur) } diff --git a/dependency/vault_read.go b/dependency/vault_read.go index 5479f90af..729f747d8 100644 --- a/dependency/vault_read.go +++ b/dependency/vault_read.go @@ -18,7 +18,8 @@ var ( // VaultReadQuery is the dependency to Vault for a secret type VaultReadQuery struct { - stopCh chan struct{} + stopCh chan struct{} + sleepCh <-chan time.Time rawPath string queryValues url.Values @@ -45,81 +46,69 @@ func 
NewVaultReadQuery(s string) (*VaultReadQuery, error) { return &VaultReadQuery{ stopCh: make(chan struct{}, 1), + sleepCh: make(chan time.Time, 1), rawPath: secretURL.Path, queryValues: secretURL.Query(), }, nil } // Fetch queries the Vault API -func (d *VaultReadQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interface{}, *ResponseMetadata, error) { +func (d *VaultReadQuery) Fetch(clients *ClientSet, opts *QueryOptions, +) (interface{}, *ResponseMetadata, error) { select { case <-d.stopCh: return nil, nil, ErrStopped default: } + select { + case <-d.sleepCh: + default: + } - opts = opts.Merge(&QueryOptions{}) + firstRun := d.secret == nil - if d.secret != nil { - if vaultSecretRenewable(d.secret) { - log.Printf("[TRACE] %s: starting renewer", d) - - renewer, err := clients.Vault().NewRenewer(&api.RenewerInput{ - Grace: opts.VaultGrace, - Secret: d.vaultSecret, - }) - if err != nil { - return nil, nil, errors.Wrap(err, d.String()) - } - go renewer.Renew() - defer renewer.Stop() - - RENEW: - for { - select { - case err := <-renewer.DoneCh(): - if err != nil { - log.Printf("[WARN] %s: failed to renew: %s", d, err) - } - log.Printf("[WARN] %s: renewer returned (maybe the lease expired)", d) - break RENEW - case renewal := <-renewer.RenewCh(): - log.Printf("[TRACE] %s: successfully renewed", d) - printVaultWarnings(d, renewal.Secret.Warnings) - updateSecret(d.secret, renewal.Secret) - case <-d.stopCh: - return nil, nil, ErrStopped - } - } - } else { - // The secret isn't renewable, probably the generic secret backend. - dur := vaultRenewDuration(d.secret) - log.Printf("[TRACE] %s: secret is not renewable, sleeping for %s", d, dur) - select { - case <-time.After(dur): - // The lease is almost expired, it's time to request a new one. 
- case <-d.stopCh: - return nil, nil, ErrStopped - } + if !firstRun && vaultSecretRenewable(d.secret) { + err := renewSecret(clients, d) + if err != nil { + return nil, nil, errors.Wrap(err, d.String()) } } - // We don't have a secret, or the prior renewal failed - vaultSecret, err := d.readSecret(clients, opts) + err := d.fetchSecret(clients, opts) if err != nil { return nil, nil, errors.Wrap(err, d.String()) } - // Print any warnings - printVaultWarnings(d, vaultSecret.Warnings) - - // Create the cloned secret which will be exposed to the template. - d.vaultSecret = vaultSecret - d.secret = transformSecret(vaultSecret) + if !vaultSecretRenewable(d.secret) { + dur := leaseCheckWait(d.secret) + log.Printf("[TRACE] %s: non-renewable secret, set sleep for %s", d, dur) + d.sleepCh = time.After(dur) + } return respWithMetadata(d.secret) } +func (d *VaultReadQuery) fetchSecret(clients *ClientSet, opts *QueryOptions, +) error { + opts = opts.Merge(&QueryOptions{}) + vaultSecret, err := d.readSecret(clients, opts) + if err == nil { + printVaultWarnings(d, vaultSecret.Warnings) + d.vaultSecret = vaultSecret + // the cloned secret which will be exposed to the template + d.secret = transformSecret(vaultSecret) + } + return err +} + +func (d *VaultReadQuery) stopChan() chan struct{} { + return d.stopCh +} + +func (d *VaultReadQuery) secrets() (*Secret, *api.Secret) { + return d.secret, d.vaultSecret +} + // CanShare returns if this dependency is shareable. 
func (d *VaultReadQuery) CanShare() bool { return false diff --git a/dependency/vault_read_test.go b/dependency/vault_read_test.go index 597ea6db4..8741a2a65 100644 --- a/dependency/vault_read_test.go +++ b/dependency/vault_read_test.go @@ -77,6 +77,7 @@ func TestNewVaultReadQuery(t *testing.T) { if act != nil { act.stopCh = nil + act.sleepCh = nil } assert.Equal(t, tc.exp, act) @@ -170,7 +171,10 @@ func TestVaultReadQuery_Fetch_KVv1(t *testing.T) { errCh <- err return } - dataCh <- data + select { + case dataCh <- data: + case <-d.stopCh: + } } }() @@ -372,7 +376,10 @@ func TestVaultReadQuery_Fetch_KVv2(t *testing.T) { errCh <- err return } - dataCh <- data + select { + case dataCh <- data: + case <-d.stopCh: + } } }() diff --git a/dependency/vault_token.go b/dependency/vault_token.go index a3c7f4c87..61fa29cfa 100644 --- a/dependency/vault_token.go +++ b/dependency/vault_token.go @@ -5,7 +5,6 @@ import ( "time" "github.com/hashicorp/vault/api" - "github.com/pkg/errors" ) var ( @@ -44,44 +43,15 @@ func (d *VaultTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interfa default: } - opts = opts.Merge(&QueryOptions{}) - if vaultSecretRenewable(d.secret) { - log.Printf("[TRACE] %s: starting renewer", d) - - renewer, err := clients.Vault().NewRenewer(&api.RenewerInput{ - Grace: opts.VaultGrace, - Secret: d.vaultSecret, - }) - if err != nil { - return nil, nil, errors.Wrap(err, d.String()) - } - go renewer.Renew() - defer renewer.Stop() - - RENEW: - for { - select { - case err := <-renewer.DoneCh(): - if err != nil { - log.Printf("[WARN] %s: failed to renew: %s", d, err) - } - log.Printf("[WARN] %s: renewer returned (maybe the lease expired)", d) - break RENEW - case renewal := <-renewer.RenewCh(): - log.Printf("[TRACE] %s: successfully renewed", d) - printVaultWarnings(d, renewal.Secret.Warnings) - updateSecret(d.secret, renewal.Secret) - case <-d.stopCh: - return nil, nil, ErrStopped - } - } + renewSecret(clients, d) } // The secret isn't renewable, probably 
the generic secret backend. // TODO This is incorrect when given a non-renewable template. We should // instead to a lookup self to determine the lease duration. - dur := vaultRenewDuration(d.secret) + opts = opts.Merge(&QueryOptions{}) + dur := leaseCheckWait(d.secret) if dur < opts.VaultGrace { dur = opts.VaultGrace } @@ -89,7 +59,6 @@ func (d *VaultTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interfa log.Printf("[TRACE] %s: token is not renewable, sleeping for %s", d, dur) select { case <-time.After(dur): - // The lease is almost expired, it's time to request a new one. case <-d.stopCh: return nil, nil, ErrStopped } @@ -97,6 +66,14 @@ func (d *VaultTokenQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interfa return nil, nil, ErrLeaseExpired } +func (d *VaultTokenQuery) stopChan() chan struct{} { + return d.stopCh +} + +func (d *VaultTokenQuery) secrets() (*Secret, *api.Secret) { + return d.secret, d.vaultSecret +} + // CanShare returns if this dependency is shareable. 
func (d *VaultTokenQuery) CanShare() bool { return false diff --git a/dependency/vault_write.go b/dependency/vault_write.go index 3e0ecd776..44bacd438 100644 --- a/dependency/vault_write.go +++ b/dependency/vault_write.go @@ -21,7 +21,8 @@ var ( // VaultWriteQuery is the dependency to Vault for a secret type VaultWriteQuery struct { - stopCh chan struct{} + stopCh chan struct{} + sleepCh <-chan time.Time path string data map[string]interface{} @@ -42,6 +43,7 @@ func NewVaultWriteQuery(s string, d map[string]interface{}) (*VaultWriteQuery, e return &VaultWriteQuery{ stopCh: make(chan struct{}, 1), + sleepCh: make(chan time.Time, 1), path: s, data: d, dataHash: sha1Map(d), @@ -49,77 +51,61 @@ func NewVaultWriteQuery(s string, d map[string]interface{}) (*VaultWriteQuery, e } // Fetch queries the Vault API -func (d *VaultWriteQuery) Fetch(clients *ClientSet, opts *QueryOptions) (interface{}, *ResponseMetadata, error) { +func (d *VaultWriteQuery) Fetch(clients *ClientSet, opts *QueryOptions, +) (interface{}, *ResponseMetadata, error) { select { case <-d.stopCh: return nil, nil, ErrStopped default: } + select { + case <-d.sleepCh: + default: + } - opts = opts.Merge(&QueryOptions{}) + firstRun := d.secret == nil - if d.secret != nil { - if vaultSecretRenewable(d.secret) { - log.Printf("[TRACE] %s: starting renewer", d) - - renewer, err := clients.Vault().NewRenewer(&api.RenewerInput{ - Grace: opts.VaultGrace, - Secret: d.vaultSecret, - }) - if err != nil { - return nil, nil, errors.Wrap(err, d.String()) - } - go renewer.Renew() - defer renewer.Stop() - - RENEW: - for { - select { - case err := <-renewer.DoneCh(): - if err != nil { - log.Printf("[WARN] %s: failed to renew: %s", d, err) - } - log.Printf("[WARN] %s: renewer returned (maybe the lease expired)", d) - break RENEW - case renewal := <-renewer.RenewCh(): - log.Printf("[TRACE] %s: successfully renewed", d) - printVaultWarnings(d, renewal.Secret.Warnings) - updateSecret(d.secret, renewal.Secret) - case <-d.stopCh: - 
return nil, nil, ErrStopped - } - } - } else { - // The secret isn't renewable, probably the generic secret backend. - dur := vaultRenewDuration(d.secret) - log.Printf("[TRACE] %s: secret is not renewable, sleeping for %s", d, dur) - select { - case <-time.After(dur): - // The lease is almost expired, it's time to request a new one. - case <-d.stopCh: - return nil, nil, ErrStopped - } + if !firstRun && vaultSecretRenewable(d.secret) { + err := renewSecret(clients, d) + if err != nil { + return nil, nil, errors.Wrap(err, d.String()) } } - // We don't have a secret, or the prior renewal failed + opts = opts.Merge(&QueryOptions{}) vaultSecret, err := d.writeSecret(clients, opts) if err != nil { return nil, nil, errors.Wrap(err, d.String()) } // vaultSecret == nil when writing to KVv1 engines - if vaultSecret != nil { - // Print any warnings - printVaultWarnings(d, vaultSecret.Warnings) - // Create the cloned secret which will be exposed to the template. - d.vaultSecret = vaultSecret - d.secret = transformSecret(vaultSecret) + if vaultSecret == nil { + return respWithMetadata(d.secret) + } + + printVaultWarnings(d, vaultSecret.Warnings) + d.vaultSecret = vaultSecret + // cloned secret which will be exposed to the template + d.secret = transformSecret(vaultSecret) + + if !vaultSecretRenewable(d.secret) { + dur := leaseCheckWait(d.secret) + log.Printf("[TRACE] %s: non-renewable secret, set sleep for %s", d, dur) + d.sleepCh = time.After(dur) } return respWithMetadata(d.secret) } +// meet renewer interface +func (d *VaultWriteQuery) stopChan() chan struct{} { + return d.stopCh +} + +func (d *VaultWriteQuery) secrets() (*Secret, *api.Secret) { + return d.secret, d.vaultSecret +} + // CanShare returns if this dependency is shareable. 
func (d *VaultWriteQuery) CanShare() bool { return false diff --git a/dependency/vault_write_test.go b/dependency/vault_write_test.go index 6dbe0c403..aa8f27b32 100644 --- a/dependency/vault_write_test.go +++ b/dependency/vault_write_test.go @@ -64,6 +64,7 @@ func TestNewVaultWriteQuery(t *testing.T) { if act != nil { act.stopCh = nil + act.sleepCh = nil } assert.Equal(t, tc.exp, act) @@ -233,7 +234,10 @@ func TestVaultWriteQuery_Fetch(t *testing.T) { errCh <- err return } - dataCh <- data + select { + case dataCh <- data: + case <-d.stopCh: + } } }()