From 538859988f2fe3516ccde93efa8923c24402b997 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Sat, 31 Aug 2013 12:01:12 -0700 Subject: [PATCH] Fix data races, simplify cancellation signalling /cc @titanous - Re-integrated your pull and fixed a couple races that were even present in your pull request (specifically: reading the struct channel structs outside of a lock and reading the state outside of a lock). I read state outside of a lock using the atomic package now. --- basic_runner.go | 113 ++++++++++++++++++++++-------------------------- 1 file changed, 51 insertions(+), 62 deletions(-) diff --git a/basic_runner.go b/basic_runner.go index 64a2f44..f50b2d2 100644 --- a/basic_runner.go +++ b/basic_runner.go @@ -1,6 +1,17 @@ package multistep -import "sync" +import ( + "sync" + "sync/atomic" +) + +type runState int32 + +const ( + stateIdle runState = iota + stateRunning + stateCancelling +) // BasicRunner is a Runner that just runs the given slice of steps. type BasicRunner struct { @@ -8,71 +19,50 @@ type BasicRunner struct { // modified. Steps []Step - cancelCond *sync.Cond - cancelChs []chan<- bool - running bool + cancelCh chan struct{} + doneCh chan struct{} + state runState l sync.Mutex } func (b *BasicRunner) Run(state StateBag) { - // Make sure we only run one at a time b.l.Lock() - if b.running { + if b.state != stateIdle { panic("already running") } - b.cancelChs = nil - b.cancelCond = sync.NewCond(&sync.Mutex{}) - b.running = true - b.l.Unlock() - - // cancelReady is used to signal that the cancellation goroutine - // started and is waiting. The cancelEnded channel is used to - // signal the goroutine actually ended. - cancelReady := make(chan bool, 1) - cancelEnded := make(chan bool) - go func() { - b.cancelCond.L.Lock() - cancelReady <- true - b.cancelCond.Wait() - b.cancelCond.L.Unlock() - - if b.cancelChs != nil { - state.Put(StateCancelled, true) - } - cancelEnded <- true - }() + cancelCh := make(chan struct{}) + doneCh := make(chan struct{}) + b.cancelCh = cancelCh + b.doneCh = doneCh + b.state = stateRunning + b.l.Unlock() - // Create the channel that we'll say we're done on in the case of - // interrupts here. We do this here so that this deferred statement - // runs last, so all the Cleanup methods are able to run. defer func() { b.l.Lock() - defer b.l.Unlock() - - // Make sure the cancellation goroutine cleans up properly. This - // is a bit complicated. Basically, we first wait until the goroutine - // waiting for cancellation is actually waiting. Then we broadcast - // to it so it can unlock. Then we wait for it to tell us it finished. - <-cancelReady - b.cancelCond.L.Lock() - b.cancelCond.Broadcast() - b.cancelCond.L.Unlock() - <-cancelEnded + b.cancelCh = nil + b.doneCh = nil + b.state = stateIdle + close(doneCh) + b.l.Unlock() + }() - if b.cancelChs != nil { - for _, doneCh := range b.cancelChs { - doneCh <- true - } + // This goroutine listens for cancels and puts the StateCancelled key + // as quickly as possible into the state bag to mark it. + go func() { + select { + case <-cancelCh: + // Flag cancel and wait for finish + state.Put(StateCancelled, true) + <-doneCh + case <-doneCh: } - - b.running = false }() for _, step := range b.Steps { // We also check for cancellation here since we can't be sure // the goroutine that is running to set it actually ran. - if b.cancelChs != nil { + if runState(atomic.LoadInt32((*int32)(&b.state))) == stateCancelling { state.Put(StateCancelled, true) break } @@ -93,20 +83,19 @@ func (b *BasicRunner) Run(state StateBag) { func (b *BasicRunner) Cancel() { b.l.Lock() - - if !b.running { - b.l.Unlock() + switch b.state { + case stateIdle: + // Not running, so Cancel is... done. return + case stateRunning: + // Running, so mark that we cancelled and set the state + close(b.cancelCh) + b.state = stateCancelling + fallthrough + case stateCancelling: + // Already cancelling, so just wait until we're done + ch := b.doneCh + b.l.Unlock() + <-ch } - - if b.cancelChs == nil { - b.cancelChs = make([]chan<- bool, 0, 5) - } - - done := make(chan bool) - b.cancelChs = append(b.cancelChs, done) - b.cancelCond.Broadcast() - b.l.Unlock() - - <-done }