Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a locking issue in the Rollback manager #6426

Merged
merged 10 commits into from
Mar 18, 2019
125 changes: 77 additions & 48 deletions vault/rollback.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ type RollbackManager struct {
type rollbackState struct {
lastError error
sync.WaitGroup
once sync.Once
rollbackFunc func(context.Context) error
}

// Run the rollback once, return true if we were the one that ran it. Caller
// should hold the statelock.
func (rs *rollbackState) run(ctx context.Context) (ran bool, err error) {
rs.once.Do(func() {
ran = true
err = rs.rollbackFunc(ctx)
})
return
}

// NewRollbackManager is used to create a new rollback manager
Expand Down Expand Up @@ -132,24 +144,62 @@ func (m *RollbackManager) triggerRollbacks() {
}
fullPath := e.namespace.Path + path

m.inflightLock.RLock()
_, ok := m.inflight[fullPath]
m.inflightLock.RUnlock()
if !ok {
m.startRollback(ctx, fullPath, true)
}
// Start a rollback if necessary
m.startOrLookupRollback(ctx, fullPath, true)
}
}

// startRollback is used to start an async rollback attempt.
briankassouf marked this conversation as resolved.
Show resolved Hide resolved
// This must be called with the inflightLock held.
func (m *RollbackManager) startRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState {
rs := &rollbackState{}
rs.Add(1)
m.inflightAll.Add(1)
func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState {
rs := &rollbackState{
briankassouf marked this conversation as resolved.
Show resolved Hide resolved
rollbackFunc: func(ctx context.Context) error {
ns, err := namespace.FromContext(ctx)
if err != nil {
return err
}
if ns == nil {
return namespace.ErrNoNamespace
}

// Invoke a RollbackOperation
req := &logical.Request{
Operation: logical.RollbackOperation,
Path: ns.TrimmedPath(fullPath),
}

var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithTimeout(ctx, DefaultMaxRequestDuration)
_, err = m.router.Route(ctx, req)
cancelFunc()

// If the error is an unsupported operation, then it doesn't
// matter, the backend doesn't support it.
if err == logical.ErrUnsupportedOperation {
err = nil
}
// If we failed due to read-only storage, we can't do anything; ignore
if err != nil && strings.Contains(err.Error(), logical.ErrReadOnly.Error()) {
err = nil
}
if err != nil {
m.logger.Error("error rolling back", "path", fullPath, "error", err)
}
return nil
},
}

m.inflightLock.Lock()
defer m.inflightLock.Unlock()
rsInflight, ok := m.inflight[fullPath]
if ok {
return rsInflight
}

briankassouf marked this conversation as resolved.
Show resolved Hide resolved
// If no inflight rollback is already running, kick one off
m.inflight[fullPath] = rs
m.inflightLock.Unlock()
rs.Add(1)
m.inflightAll.Add(1)
go m.attemptRollback(ctx, fullPath, rs, grabStatelock)
return rs
}
Expand All @@ -170,47 +220,20 @@ func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string,
m.inflightLock.Unlock()
}()

ns, err := namespace.FromContext(ctx)
if err != nil {
return err
}
if ns == nil {
return namespace.ErrNoNamespace
}

// Invoke a RollbackOperation
req := &logical.Request{
Operation: logical.RollbackOperation,
Path: ns.TrimmedPath(fullPath),
}

if grabStatelock {
// Grab the statelock or stop
if stopped := grabLockOrStop(m.core.stateLock.RLock, m.core.stateLock.RUnlock, m.shutdownCh); stopped {
return errors.New("rollback shutting down")
}
}

var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithTimeout(ctx, DefaultMaxRequestDuration)
_, err = m.router.Route(ctx, req)
// Run the rollback
_, err = rs.run(ctx)

if grabStatelock {
m.core.stateLock.RUnlock()
}
cancelFunc()

// If the error is an unsupported operation, then it doesn't
// matter, the backend doesn't support it.
if err == logical.ErrUnsupportedOperation {
err = nil
}
// If we failed due to read-only storage, we can't do anything; ignore
if err != nil && strings.Contains(err.Error(), logical.ErrReadOnly.Error()) {
err = nil
}
if err != nil {
m.logger.Error("error rolling back", "path", fullPath, "error", err)
}
return
}

Expand All @@ -224,15 +247,21 @@ func (m *RollbackManager) Rollback(ctx context.Context, path string) error {
}
fullPath := ns.Path + path

// Check for an existing attempt and start one if none
m.inflightLock.RLock()
rs, ok := m.inflight[fullPath]
m.inflightLock.RUnlock()
if !ok {
rs = m.startRollback(ctx, fullPath, false)
// Check for an existing attempt or start one if none
rs := m.startOrLookupRollback(ctx, fullPath, false)

// Do a run here in the event an already inflight rollback is blocked on
// grabbing the statelock. This prevents a deadlock in some cases where the
// caller of this function holds the write statelock.
ran, err := rs.run(ctx)
// If we were the runner, return the error
if ran {
return err
}

// Wait for the attempt to finish
// If we weren't the runner, wait for the inflight attempt to finish. It's
// safe to do this, since if the other thread starts the run they are
// already in possession of the statelock and we are not deadlocked.
rs.Wait()

// Return the last error
Expand Down