From b029f764fca43ba3af07fd0d9a08b6d87681315e Mon Sep 17 00:00:00 2001 From: Stepan Pesternikov Date: Mon, 23 Sep 2019 17:48:44 +0300 Subject: [PATCH] feat(ctx) change close policy from chan to context --- pool.go | 21 +++++++++++---------- pool_test.go | 46 +++++++++++++++++++++++----------------------- task_test.go | 4 ++-- worker.go | 10 +++++----- 4 files changed, 41 insertions(+), 40 deletions(-) diff --git a/pool.go b/pool.go index 8fcf5f2..5c45ba3 100644 --- a/pool.go +++ b/pool.go @@ -57,7 +57,8 @@ type Config struct { // as well as processing task. type Pool struct { config Config - close chan struct{} + cancel context.CancelFunc + ctx context.Context syncClose sync.Once taskDeque *customTaskDeque arbiterWg sync.WaitGroup @@ -95,14 +96,14 @@ func NewPool(c *Config) *Pool { p.config = c.withDefaults() - p.close = make(chan struct{}) + p.ctx, p.cancel = context.WithCancel(context.Background()) p.taskDeque = newCustomTaskDeque() p.workers = make([]*customWorker, 0, p.config.UnstoppableWorkers) p.waitChan = make(chan struct{}, 1) p.workerWg.Add(p.config.UnstoppableWorkers) for i := 0; i < p.config.UnstoppableWorkers; i++ { - worker := newCustomWorker(p.close) + worker := newCustomWorker(p.ctx) p.workers = append(p.workers, worker) go worker.Run(func() { p.waitTask() @@ -139,8 +140,8 @@ func (p *Pool) arbiter() { select { case <-p.waitChan: break - case <-p.close: - p.taskDeque.put(t, true) + case <-p.ctx.Done(): + _ = p.taskDeque.put(t, true) p.arbiterWg.Done() return } @@ -174,7 +175,7 @@ func (p *Pool) setUnstoppableWorkers(count int) { for i := 0; i < p.config.UnstoppableWorkers-count; i++ { p.workerWg.Add(p.config.UnstoppableWorkers - count) for i := 0; i < p.config.UnstoppableWorkers; i++ { - worker := newCustomWorker(p.close) + worker := newCustomWorker(p.ctx) p.workers = append(p.workers, worker) go worker.Run(func() { p.waitTask() @@ -199,7 +200,7 @@ func (p *Pool) spawn(t *wrappedTask) bool { } atomic.AddInt64(&p.spawnCount, 1) - worker := newCustomWorker(p.close) + worker := newCustomWorker(p.ctx) go worker.Spawn(t, func() { atomic.AddInt64(&p.spawnCount, -1) p.waitTask() @@ -284,7 +285,7 @@ func (p *Pool) UnstoppableWorkers() int { func (p *Pool) Close() { p.syncClose.Do(func() { p.guard.Lock() - close(p.close) + p.cancel() p.guard.Unlock() p.taskDeque.close() @@ -296,7 +297,7 @@ func (p *Pool) Close() { for _, worker := range workers { worker.Release() for task := range worker.taskChan { - p.taskDeque.put(task, true) + _ = p.taskDeque.put(task, true) } } @@ -313,7 +314,7 @@ func (p *Pool) Close() { // IsClosed returns true if this pool has been closed. func (p *Pool) IsClosed() bool { select { - case <-p.close: + case <-p.ctx.Done(): return true default: return false diff --git a/pool_test.go b/pool_test.go index 14534d3..1b9bb23 100644 --- a/pool_test.go +++ b/pool_test.go @@ -41,7 +41,7 @@ func TestSubmit(t *testing.T) { return nil, nil }) - p.Submit(task) + _ = p.Submit(task) } activeCount := p.ActiveCount() @@ -78,14 +78,14 @@ func TestSubmitClose1(t *testing.T) { select { case completion <- i: break - case <-p.close: + case <-p.ctx.Done(): return nil, nil } atomic.AddUint32(&count, 1) return nil, nil }, i) - p.Submit(task) + _ = p.Submit(task) } resultCount := 0 @@ -127,7 +127,7 @@ func TestSubmitClose2(t *testing.T) { select { case completion <- i: break - case <-p.close: + case <-p.ctx.Done(): return nil, nil } atomic.AddUint32(&count, 1) @@ -169,13 +169,13 @@ func TestSubmitWithCompletionClose1(t *testing.T) { select { case completion1 <- i: break - case <-p.close: + case <-p.ctx.Done(): break } return 1, nil }, i) - p.SubmitWithCompletion(completion2, task) + _ = p.SubmitWithCompletion(completion2, task) } resultCount := 0 @@ -205,13 +205,13 @@ func TestSubmitWithCompletionClose2(t *testing.T) { select { case completion1 <- i: break - case <-p.close: + case <-p.ctx.Done(): break } return 1, nil }, i) - p.SubmitWithCompletion(completion2, task) + _ = p.SubmitWithCompletion(completion2, task) } resultCount := 0 @@ -239,7 +239,7 @@ func TestSubmitWithCompletion(t *testing.T) { return 1, nil }) - p.SubmitWithCompletion(completion, task) + _ = p.SubmitWithCompletion(completion, task) } count := 0 @@ -285,7 +285,7 @@ func TestSubmitWithCompletionPanic(t *testing.T) { return 1, nil }, i) - p.SubmitWithCompletion(completion, task) + _ = p.SubmitWithCompletion(completion, task) } countPanic := 0 @@ -348,7 +348,7 @@ func TestSubmitWithCancel(t *testing.T) { return nil, nil }, i) - p.SubmitWithContext(ctx, task) + _ = p.SubmitWithContext(ctx, task) } wg1.Wait() @@ -381,7 +381,7 @@ func TestSubmitWithCancel(t *testing.T) { return nil, nil }, i) - p.SubmitWithContext(ctx, task) + _ = p.SubmitWithContext(ctx, task) } wg1.Wait() @@ -417,7 +417,7 @@ func TestSubmitCustom1(t *testing.T) { return 1, nil }, i) - p.SubmitCustom(ctx, completion, task) + _ = p.SubmitCustom(ctx, completion, task) } wg1.Wait() @@ -447,7 +447,7 @@ func TestSubmitCustom2(t *testing.T) { return 1, nil }, i) - p.SubmitCustom(ctx, completion, task) + _ = p.SubmitCustom(ctx, completion, task) } wg1.Wait() @@ -569,7 +569,7 @@ func TestWaitClose(t *testing.T) { return nil, nil }) - p.SubmitWithCompletion(completion, task) + _ = p.SubmitWithCompletion(completion, task) } resultCount := 0 @@ -610,22 +610,22 @@ func TestResize1(t *testing.T) { return nil, nil }, i) - p.Submit(task) + _ = p.Submit(task) } wg1.Wait() - p.SetSize(7) + _ = p.SetSize(7) if p.config.Size != 7 { t.Fatalf("invalid pool size %d; want 7", p.Size()) } - p.SetSize(12) + _ = p.SetSize(12) if p.config.Size != 12 { t.Fatalf("invalid pool size %d; want 12", p.Size()) } - p.SetSize(3) + _ = p.SetSize(3) if p.config.Size != 3 { t.Fatalf("invalid pool size %d; want 3", p.Size()) } @@ -665,22 +665,22 @@ func TestResize2(t *testing.T) { return nil, nil }, i) - p.Submit(task) + _ = p.Submit(task) } wg1.Wait() - p.SetSize(7) + _ = p.SetSize(7) if p.config.Size != 7 { t.Fatalf("invalid pool size %d; want 7", p.Size()) } - p.SetSize(12) + _ = p.SetSize(12) if p.config.Size != 12 { t.Fatalf("invalid pool size %d; want 12", p.Size()) } - p.SetSize(3) + _ = p.SetSize(3) if p.config.Size != 3 { t.Fatalf("invalid pool size %d; want 3", p.Size()) } diff --git a/task_test.go b/task_test.go index 2a88204..6bf584c 100644 --- a/task_test.go +++ b/task_test.go @@ -14,7 +14,7 @@ func TestCustomTaskCancel(t *testing.T) { } task.Get() - task.Error() + _ = task.Error() task.Panic() } @@ -44,6 +44,6 @@ func TestCustomTaskCancelDone(t *testing.T) { } task.Get() - task.Error() + _ = task.Error() task.Panic() } diff --git a/worker.go b/worker.go index 9e58fef..32e2f8d 100644 --- a/worker.go +++ b/worker.go @@ -26,19 +26,19 @@ import ( ) type customWorker struct { + ctx context.Context taskChan chan *wrappedTask freeChan chan struct{} syncRelease sync.Once release chan struct{} - kill <-chan struct{} } -func newCustomWorker(kill <-chan struct{}) *customWorker { +func newCustomWorker(ctx context.Context) *customWorker { return &customWorker{ + ctx: ctx, taskChan: make(chan *wrappedTask, 1), release: make(chan struct{}), freeChan: make(chan struct{}, 1), - kill: kill, } } @@ -55,7 +55,7 @@ func (cw *customWorker) completeTask(t *wrappedTask) { break case <-cw.release: break - case <-cw.kill: + case <-cw.ctx.Done(): break case <-ctx.Done(): break @@ -105,7 +105,7 @@ LOOP: cw.runTask(t, onComplete) case <-cw.release: break LOOP - case <-cw.kill: + case <-cw.ctx.Done(): break LOOP } }