Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vault: revert #18998 to fix potential deadlock #19963

Merged
merged 2 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 69 additions & 93 deletions client/vaultclient/vaultclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,6 @@ type vaultClient struct {
// vaultClientRenewalRequest is a request object for renewal of both tokens and
// secret's leases.
type vaultClientRenewalRequest struct {
// renewalLoopCh is used to notify listeners every time the token goes
// through the renewal loop. It does not guarantee the renewal was
// successful, so listeners should also read from errCh for renewal errors.
renewalLoopCh chan struct{}

// errCh is the channel into which any renewal error will be sent to
errCh chan error

Expand Down Expand Up @@ -358,15 +353,13 @@ func (c *vaultClient) GetConsulACL(token, path string) (*vaultapi.Secret, error)
return c.client.Logical().Read(path)
}

// RenewToken pushes the supplied token to the min-heap for an immediate
// renewal for a given duration (in seconds) and blocks until the renewal loop
// has processed it. The token is then renewed periodically until Stop() or
// StopRenewToken() is called.
//
// Any error returned during the periodical renewal will be written to a
// buffered channel and the channel is returned instead of an actual error.
// This helps the caller be notified of a renewal failure asynchronously for
// appropriate actions to be taken.
// RenewToken renews the supplied token for a given duration (in seconds) and
// adds it to the min-heap so that it is renewed periodically by the renewal
// loop. Any error returned during renewal will be written to a buffered
// channel and the channel is returned instead of an actual error. This helps
// the caller be notified of a renewal failure asynchronously for appropriate
// actions to be taken. The caller of this function does not need to close the
// error channel.
func (c *vaultClient) RenewToken(token string, increment int) (<-chan error, error) {
if token == "" {
err := fmt.Errorf("missing token")
Expand All @@ -377,84 +370,38 @@ func (c *vaultClient) RenewToken(token string, increment int) (<-chan error, err
return nil, err
}

// Create a buffered error channel
errCh := make(chan error, 1)

// Create a renewal request and indicate that the identifier in the
// request is a token and not a lease
req := &vaultClientRenewalRequest{
renewalLoopCh: make(chan struct{}),
errCh: make(chan error, 1),
id: token,
isToken: true,
increment: increment,
}

// Push an immediate renewal request to the heap and block until a result
// is received.
err := c.pushRenewalRequest(req, time.Now())
if err != nil {
return nil, err
renewalReq := &vaultClientRenewalRequest{
errCh: errCh,
id: token,
isToken: true,
increment: increment,
}

select {
case err := <-req.errCh:
// Perform the renewal of the token and send any error to the dedicated
// error channel.
if err := c.renew(renewalReq); err != nil {
c.logger.Error("error during renewal of token", "error", err)
metrics.IncrCounter([]string{"client", "vault", "renew_token_failure"}, 1)
return nil, err
case <-req.renewalLoopCh:
return req.errCh, nil
}
}

// pushRenewalRequest pushes a renewal request to the heap and triggers the
// renewal loop to re-fetch a new request.
func (c *vaultClient) pushRenewalRequest(req *vaultClientRenewalRequest, next time.Time) error {
c.lock.Lock()
defer c.lock.Unlock()

if !c.running {
return errors.New("token renewal loop is not running")
}

if !c.isTracked(req.id) {
err := c.heap.Push(req, next)
if err != nil {
return fmt.Errorf("failed to push renewal request to heap: %v", err)
}
} else {
err := c.heap.Update(req, next)
if err != nil {
return fmt.Errorf("failed to update renewal request: %v", err)
}
}

// Signal an update for the renewal loop to trigger a fresh computation for
// the next best candidate for renewal.
select {
case c.updateCh <- struct{}{}:
default:
}

return nil
return errCh, nil
}

// renew is a common method to handle renewal of both tokens and secret leases.
// It invokes a token renewal or a secret's lease renewal. If renewal is
// successful, min-heap is updated based on the duration after which it needs
// renewal again. The next renewal time is randomly selected to avoid spikes in
// the number of APIs periodically.
// Only tokens that are present in the heap are renewed.
func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
c.lock.Lock()
defer c.lock.Unlock()

// Always notify listeners that the request has been processed before
// exiting.
defer func() {
select {
case req.renewalLoopCh <- struct{}{}:
default:
}
}()

if req == nil {
return fmt.Errorf("nil renewal request")
}
Expand All @@ -479,12 +426,6 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
return fmt.Errorf("increment cannot be less than 1")
}

// Verify token is still in the heap before proceeding as it may have been
// removed while waiting for the renewal timer to tick.
if !c.isTracked(req.id) {
return nil
}

var renewalErr error
leaseDuration := req.increment
if req.isToken {
Expand Down Expand Up @@ -535,25 +476,60 @@ func (c *vaultClient) renew(req *vaultClientRenewalRequest) error {
"error", renewalErr, "period", next)
}

if fatal {
// If encountered with an error where in a lease or a
// token is not valid at all with vault, and if that
// item is tracked by the renewal loop, stop renewing
// it by removing the corresponding heap entry.
if err := c.heap.Remove(req.id); err != nil {
return fmt.Errorf("failed to remove heap entry: %v", err)
if c.isTracked(req.id) {
if fatal {
// If encountered with an error where in a lease or a
// token is not valid at all with vault, and if that
// item is tracked by the renewal loop, stop renewing
// it by removing the corresponding heap entry.
if err := c.heap.Remove(req.id); err != nil {
return fmt.Errorf("failed to remove heap entry: %v", err)
}

// Report the fatal error to the client
req.errCh <- renewalErr
close(req.errCh)

return renewalErr
}

// Report the fatal error to the client
req.errCh <- renewalErr
close(req.errCh)
// If the identifier is already tracked, this indicates a
// subsequent renewal. In this case, update the existing
// element in the heap with the new renewal time.
if err := c.heap.Update(req, next); err != nil {
return fmt.Errorf("failed to update heap entry. err: %v", err)
}

return renewalErr
}
// There is no need to signal an update to the renewal loop
// here because this case is hit from the renewal loop itself.
} else {
if fatal {
// If encountered with an error where in a lease or a
// token is not valid at all with vault, and if that
// item is not tracked by renewal loop, don't add it.

// Report the fatal error to the client
req.errCh <- renewalErr
close(req.errCh)

// Update the element in the heap with the new renewal time.
if err := c.heap.Update(req, next); err != nil {
return fmt.Errorf("failed to update heap entry. err: %v", err)
return renewalErr
}

// If the identifier is not already tracked, this is a first
// renewal request. In this case, add an entry into the heap
// with the next renewal time.
if err := c.heap.Push(req, next); err != nil {
return fmt.Errorf("failed to push an entry to heap. err: %v", err)
}

// Signal an update for the renewal loop to trigger a fresh
// computation for the next best candidate for renewal.
if c.running {
select {
case c.updateCh <- struct{}{}:
default:
}
}
}

return nil
Expand Down
61 changes: 61 additions & 0 deletions client/vaultclient/vaultclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ import (
josejwt "github.com/go-jose/go-jose/v3/jwt"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/widmgr"
"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/helper/useragent"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
structsc "github.com/hashicorp/nomad/nomad/structs/config"
Expand Down Expand Up @@ -616,3 +618,62 @@ func TestVaultClient_SetUserAgent(t *testing.T) {
ua := c.client.Headers().Get("User-Agent")
must.Eq(t, useragent.String(), ua)
}

// TestVaultClient_RenewalConcurrent issues many RenewToken calls for the same
// token in parallel against a mocked Vault API and asserts that every call
// returns without error before the timeout expires.
func TestVaultClient_RenewalConcurrent(t *testing.T) {
	ci.Parallel(t)

	// Create test server to mock the Vault API. Every request, regardless of
	// path, is answered with a renewable secret carrying a 300s lease.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := vaultapi.Secret{
			RequestID: uuid.Generate(),
			LeaseID:   uuid.Generate(),
			Renewable: true,
			Data:      map[string]any{},
			Auth: &vaultapi.SecretAuth{
				ClientToken:   uuid.Generate(),
				Accessor:      uuid.Generate(),
				LeaseDuration: 300,
			},
		}

		out, err := json.Marshal(resp)
		if err != nil {
			t.Errorf("failed to generate Vault secret json response: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		fmt.Fprintln(w, string(out))
	}))
	defer ts.Close()

	// Start Vault client.
	conf := structsc.DefaultVaultConfig()
	conf.Addr = ts.URL
	conf.Enabled = pointer.Of(true)

	vc, err := NewVaultClient(conf, testlog.HCLogger(t), nil)
	must.NoError(t, err)
	vc.Start()
	// Stop the client when the test ends so its renewal loop does not
	// outlive the test.
	defer vc.Stop()

	// Renew token multiple times in parallel. The result channel is buffered
	// to hold one result per request so that no goroutine stays blocked on
	// its send if the test bails out early on a failure or timeout.
	const requests = 100
	resultCh := make(chan error, requests)
	for i := 0; i < requests; i++ {
		go func() {
			_, err := vc.RenewToken("token", 30)
			resultCh <- err
		}()
	}

	// Collect results with timeout. A single timer bounds the whole
	// collection phase, not each individual result.
	timer, stop := helper.NewSafeTimer(3 * time.Second)
	defer stop()
	for i := 0; i < requests; i++ {
		select {
		case got := <-resultCh:
			must.NoError(t, got, must.Sprintf("token renewal error: %v", got))
		case <-timer.C:
			t.Fatal("timeout waiting for token renewal")
		}
	}
}
Loading