diff --git a/internal/linkedbuffer/linkedbuffer.go b/internal/linkedbuffer/linkedbuffer.go
index 41651c6..739318f 100644
--- a/internal/linkedbuffer/linkedbuffer.go
+++ b/internal/linkedbuffer/linkedbuffer.go
@@ -1,6 +1,7 @@
 package linkedbuffer
 
 import (
+	"math"
 	"sync"
 	"sync/atomic"
 )
@@ -125,7 +126,8 @@ func (b *LinkedBuffer[T]) Len() uint64 {
 	readCount := b.readCount.Load()
 
 	if writeCount < readCount {
-		return 0 // Make sure we don't return a negative value
+		// The writeCount counter wrapped around past math.MaxUint64
+		return math.MaxUint64 - readCount + writeCount + 1
 	}
 
 	return writeCount - readCount
diff --git a/internal/linkedbuffer/linkedbuffer_test.go b/internal/linkedbuffer/linkedbuffer_test.go
index 0a7cbc1..db46f4b 100644
--- a/internal/linkedbuffer/linkedbuffer_test.go
+++ b/internal/linkedbuffer/linkedbuffer_test.go
@@ -1,6 +1,7 @@
 package linkedbuffer
 
 import (
+	"math"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -88,8 +89,11 @@ func TestLinkedBufferLen(t *testing.T) {
 
 	assert.Equal(t, uint64(0), buf.Len())
 
-	buf.readCount.Add(1)
-	assert.Equal(t, uint64(0), buf.Len())
+	// Test wrap around: writeCount wraps past math.MaxUint64 while readCount does not
+	buf.writeCount.Add(math.MaxUint64)
+	buf.writeCount.Add(2) // wraps around to 1
+	buf.readCount.Add(math.MaxUint64 - 1)
+	assert.Equal(t, uint64(3), buf.Len())
 }
 
 func TestLinkedBufferWithReusedBuffer(t *testing.T) {
diff --git a/internal/semaphore/semaphore.go b/internal/semaphore/semaphore.go
new file mode 100644
index 0000000..78c735c
--- /dev/null
+++ b/internal/semaphore/semaphore.go
@@ -0,0 +1,142 @@
+package semaphore
+
+import (
+	"context"
+	"fmt"
+	"sync"
+)
+
+type Weighted struct {
+	ctx     context.Context
+	cond    *sync.Cond
+	size    int
+	n       int
+	waiting int
+}
+
+func NewWeighted(ctx context.Context, size int) *Weighted {
+	sem := &Weighted{
+		ctx:  ctx,
+		cond: sync.NewCond(&sync.Mutex{}),
+		size: size,
+		n:    size,
+	}
+
+	// Notify all waiters when the context is done
+	context.AfterFunc(ctx, func() {
+		sem.cond.Broadcast()
+	})
+
+	return sem
+}
+
+func (w *Weighted) Acquire(weight int) error {
+	if weight <= 0 {
+		return fmt.Errorf("semaphore: weight %d cannot be negative or zero", weight)
+	}
+	if weight > w.size {
+		return fmt.Errorf("semaphore: weight %d is greater than semaphore size %d", weight, w.size)
+	}
+
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
+
+	done := w.ctx.Done()
+
+	select {
+	case <-done:
+		return w.ctx.Err()
+	default:
+	}
+
+	for weight > w.n {
+		// Check if the context is done
+		select {
+		case <-done:
+			return w.ctx.Err()
+		default:
+		}
+
+		w.waiting++
+		w.cond.Wait()
+		w.waiting--
+	}
+
+	w.n -= weight
+
+	return nil
+}
+
+func (w *Weighted) TryAcquire(weight int) bool {
+	if weight <= 0 {
+		return false
+	}
+	if weight > w.size {
+		return false
+	}
+
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
+
+	// Check if the context is done
+	select {
+	case <-w.ctx.Done():
+		return false
+	default:
+	}
+
+	if weight > w.n {
+		// Not enough room in the semaphore
+		return false
+	}
+
+	w.n -= weight
+
+	return true
+}
+
+func (w *Weighted) Release(weight int) error {
+	if weight <= 0 {
+		return fmt.Errorf("semaphore: weight %d cannot be negative or zero", weight)
+	}
+	if weight > w.size {
+		return fmt.Errorf("semaphore: weight %d is greater than semaphore size %d", weight, w.size)
+	}
+
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
+
+	if weight > w.size-w.n {
+		return fmt.Errorf("semaphore: trying to release more than acquired: %d > %d", weight, w.size-w.n)
+	}
+
+	w.n += weight
+	w.cond.Broadcast()
+
+	return nil
+}
+
+func (w *Weighted) Size() int {
+	return w.size
+}
+
+func (w *Weighted) Acquired() int {
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
+
+	return w.size - w.n
+}
+
+func (w *Weighted) Available() int {
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
+
+	return w.n
+}
+
+func (w *Weighted) Waiting() int {
+	w.cond.L.Lock()
+	defer w.cond.L.Unlock()
+
+	return w.waiting
+}
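The Weighted semaphore above is the building block for both the bounded task queue and the subpool concurrency limits introduced further down. A minimal usage sketch (illustrative only: internal/semaphore can only be imported from inside the pond module, and the size and weights here are arbitrary):

package main

import (
	"context"
	"fmt"

	"github.com/alitto/pond/v2/internal/semaphore"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Semaphore with 3 permits, tied to ctx
	sem := semaphore.NewWeighted(ctx, 3)

	if err := sem.Acquire(2); err != nil {
		// Acquire only fails on an invalid weight or a canceled context
		fmt.Println("acquire failed:", err)
		return
	}

	// Only 1 permit is left, so TryAcquire(2) returns false without blocking
	fmt.Println(sem.TryAcquire(2), sem.Available()) // false 1

	// Release returns the permits and wakes goroutines blocked in Acquire
	_ = sem.Release(2)
	fmt.Println(sem.Acquired(), sem.Available()) // 0 3
}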
diff --git a/internal/semaphore/semaphore_test.go b/internal/semaphore/semaphore_test.go
new file mode 100644
index 0000000..6800954
--- /dev/null
+++ b/internal/semaphore/semaphore_test.go
@@ -0,0 +1,192 @@
+package semaphore
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/alitto/pond/v2/internal/assert"
+)
+
+func TestWeighted(t *testing.T) {
+	sem := NewWeighted(context.Background(), 10)
+
+	// Acquire 5
+	err := sem.Acquire(5)
+	assert.Equal(t, nil, err)
+
+	// Acquire 4
+	err = sem.Acquire(4)
+	assert.Equal(t, nil, err)
+
+	// Try to acquire 2
+	assert.Equal(t, false, sem.TryAcquire(2))
+
+	// Try to acquire 1
+	assert.Equal(t, true, sem.TryAcquire(1))
+
+	// Release 7
+	sem.Release(7)
+
+	// Try to acquire 7
+	assert.Equal(t, true, sem.TryAcquire(7))
+}
+
+func TestWeightedWithMoreAcquirersThanReleasers(t *testing.T) {
+	sem := NewWeighted(context.Background(), 6)
+
+	goroutines := 12
+	acquire := 2
+	release := 5
+	wg := sync.WaitGroup{}
+	acquireSuccessCount := atomic.Uint64{}
+	acquireFailCount := atomic.Uint64{}
+
+	wg.Add(goroutines)
+
+	// Launch goroutines that try to acquire the semaphore
+	for i := 0; i < goroutines; i++ {
+		go func() {
+			defer wg.Done()
+
+			if err := sem.Acquire(acquire); err != nil {
+				acquireFailCount.Add(1)
+			} else {
+				acquireSuccessCount.Add(1)
+			}
+
+			if sem.Acquired() >= release {
+				sem.Release(release)
+			}
+		}()
+	}
+
+	// Wait for goroutines to finish
+	wg.Wait()
+
+	assert.Equal(t, uint64(12), acquireSuccessCount.Load())
+	assert.Equal(t, uint64(0), acquireFailCount.Load())
+	assert.Equal(t, 4, sem.Acquired())
+}
+
+func TestWeightedAcquireWithInvalidWeights(t *testing.T) {
+	sem := NewWeighted(context.Background(), 10)
+
+	// Acquire 0
+	err := sem.Acquire(0)
+	assert.Equal(t, "semaphore: weight 0 cannot be negative or zero", err.Error())
+
+	// Try to acquire 0
+	res := sem.TryAcquire(0)
+	assert.Equal(t, false, res)
+
+	// Acquire -1
+	err = sem.Acquire(-1)
+	assert.Equal(t, "semaphore: weight -1 cannot be negative or zero", err.Error())
+
+	// Try to acquire -1
+	res = sem.TryAcquire(-1)
+	assert.Equal(t, false, res)
+
+	// Acquire 11
+	err = sem.Acquire(11)
+	assert.Equal(t, "semaphore: weight 11 is greater than semaphore size 10", err.Error())
+
+	// Try to acquire 11
+	res = sem.TryAcquire(11)
+	assert.Equal(t, false, res)
+}
+
+func TestWeightedReleaseWithInvalidWeights(t *testing.T) {
+	sem := NewWeighted(context.Background(), 10)
+
+	// Release 0
+	err := sem.Release(0)
+	assert.Equal(t, "semaphore: weight 0 cannot be negative or zero", err.Error())
+
+	// Release -1
+	err = sem.Release(-1)
+	assert.Equal(t, "semaphore: weight -1 cannot be negative or zero", err.Error())
+
+	// Release 11
+	err = sem.Release(11)
+	assert.Equal(t, "semaphore: weight 11 is greater than semaphore size 10", err.Error())
+
+	// Release 1
+	err = sem.Release(1)
+	assert.Equal(t, "semaphore: trying to release more than acquired: 1 > 0", err.Error())
+}
+
+func TestWeightedWithContextCanceled(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	sem := NewWeighted(ctx, 10)
+
+	// Acquire the semaphore
+	err := sem.Acquire(5)
+	assert.Equal(t, nil, err)
+
+	// Cancel the context
+	cancel()
+
+	// Attempt to acquire the semaphore
+	err = sem.Acquire(5)
+	assert.Equal(t, context.Canceled, err)
+
+	// Try to acquire the semaphore
+	assert.Equal(t, false, sem.TryAcquire(5))
+}
+
+func TestWeightedWithContextCanceledWhileWaiting(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	sem := NewWeighted(ctx, 10)
+
+	writers := 30
+	wg := sync.WaitGroup{}
+	wg.Add(writers)
+
+	assert.Equal(t, 10, sem.Size())
+	assert.Equal(t, 0, sem.Acquired())
+	assert.Equal(t, 10, sem.Available())
+	assert.Equal(t, 0, sem.Waiting())
+
+	// Launch more acquirers than the semaphore has permits
+	for i := 0; i < writers; i++ {
+		go func() {
+			defer wg.Done()
+			sem.Acquire(1)
+		}()
+	}
+
+	// Wait until 10 goroutines have acquired the semaphore
+	for sem.Acquired() < 10 {
+		time.Sleep(1 * time.Millisecond)
+	}
+
+	assert.Equal(t, 10, sem.Acquired())
+	assert.Equal(t, 0, sem.Available())
+
+	// Release 10 permits to let 10 of the waiting goroutines through
+	err := sem.Release(10)
+	assert.Equal(t, nil, err)
+
+	// Wait until the released permits have been re-acquired
+	for sem.Acquired() < 10 {
+		time.Sleep(1 * time.Millisecond)
+	}
+
+	// Cancel the context to unblock the remaining waiters
+	cancel()
+
+	// Wait for goroutines to finish
+	wg.Wait()
+
+	assert.Equal(t, 10, sem.Acquired())
+	assert.Equal(t, 0, sem.Available())
+	assert.Equal(t, 0, sem.Waiting())
+	assert.Equal(t, context.Canceled, sem.Acquire(1))
+	assert.Equal(t, false, sem.TryAcquire(1))
+}
diff --git a/pool.go b/pool.go
index b93dade..038369c 100644
--- a/pool.go
+++ b/pool.go
@@ -10,6 +10,7 @@ import (
 
 	"github.com/alitto/pond/v2/internal/dispatcher"
 	"github.com/alitto/pond/v2/internal/future"
+	"github.com/alitto/pond/v2/internal/semaphore"
 )
 
 var NUM_CPU = runtime.NumCPU()
@@ -17,6 +18,7 @@ var NUM_CPU = runtime.NumCPU()
 var MAX_TASKS_CHAN_LENGTH = NUM_CPU * 128
 
 var ErrPoolStopped = errors.New("pool stopped")
+var ErrQueueFull = errors.New("task queue is full")
 
 var poolStoppedFuture = func() Task {
 	future, resolve := future.NewFuture(context.Background())
@@ -24,6 +26,12 @@ var poolStoppedFuture = func() Task {
 	return future
 }()
 
+var poolQueueFullFuture = func() Task {
+	future, resolve := future.NewFuture(context.Background())
+	resolve(ErrQueueFull)
+	return future
+}()
+
 // basePool is the base interface for all pool types.
 type basePool interface {
 	// Returns the number of worker goroutines that are currently active (executing a task) in the pool.
@@ -47,6 +55,14 @@ type basePool interface {
 	// Returns the maximum concurrency of the pool.
 	MaxConcurrency() int
 
+	// Returns the size of the task queue.
+	QueueSize() int
+
+	// Returns true if the pool is non-blocking, meaning that it will not block when the task queue is full.
+	// In a non-blocking pool, tasks that cannot be submitted to the queue will be dropped.
+	// By default, pools are blocking, meaning that they will block when the task queue is full.
+	NonBlocking() bool
+
 	// Returns the context associated with this pool.
 	Context() context.Context
 
@@ -73,8 +89,8 @@ type Pool interface {
 	// Submits a task to the pool and returns a future that can be used to wait for the task to complete.
 	SubmitErr(task func() error) Task
 
-	// Creates a new subpool with the specified maximum concurrency.
-	NewSubpool(maxConcurrency int) Pool
+	// Creates a new subpool with the specified maximum concurrency and options.
+	NewSubpool(maxConcurrency int, options ...Option) Pool
 
 	// Creates a new task group.
 	NewGroup() TaskGroup
@@ -96,6 +112,9 @@ type pool struct {
 	dispatcherRunning   sync.Mutex
 	successfulTaskCount atomic.Uint64
 	failedTaskCount     atomic.Uint64
+	nonBlocking         bool
+	queueSize           int
+	queueSem            *semaphore.Weighted
 }
 
 func (p *pool) Context() context.Context {
 	return p.ctx
 }
 
@@ -110,6 +129,14 @@ func (p *pool) MaxConcurrency() int {
 	return p.maxConcurrency
 }
 
+func (p *pool) QueueSize() int {
+	return p.queueSize
+}
+
+func (p *pool) NonBlocking() bool {
+	return p.nonBlocking
+}
+
 func (p *pool) RunningWorkers() int64 {
 	return p.workerCount.Load()
 }
@@ -163,6 +190,23 @@ func (p *pool) submit(task any) Task {
 
 	wrapped := wrapTask[struct{}, func(error)](task, resolve)
 
+	if p.queueSem != nil {
+		if p.nonBlocking {
+			if !p.queueSem.TryAcquire(1) {
+				return poolQueueFullFuture
+			}
+		} else {
+			if err := p.queueSem.Acquire(1); err != nil {
+				resolve(err)
+				return future
+			}
+		}
+	}
+
 	if err := p.dispatcher.Write(wrapped); err != nil {
+		if p.queueSem != nil {
+			// Give back the queue slot acquired above
+			p.queueSem.Release(1)
+		}
 		return poolStoppedFuture
 	}
@@ -186,8 +226,8 @@ func (p *pool) StopAndWait() {
 	p.Stop().Wait()
 }
 
-func (p *pool) NewSubpool(maxConcurrency int) Pool {
-	return newSubpool(maxConcurrency, p.ctx, p)
+func (p *pool) NewSubpool(maxConcurrency int, options ...Option) Pool {
+	return newSubpool(maxConcurrency, p.Context(), p, options...)
 }
 
 func (p *pool) NewGroup() TaskGroup {
@@ -297,6 +337,11 @@ func (p *pool) worker() {
 			return
 		}
 
+		// We have a task to execute, release the semaphore since it is no longer in the queue
+		if p.queueSem != nil {
+			p.queueSem.Release(1)
+		}
+
 		// Execute task
 		_, err := invokeTask[any](task)
 
@@ -343,6 +388,10 @@ func newPool(maxConcurrency int, options ...Option) *pool {
 
 	pool.ctx, pool.cancel = context.WithCancelCause(pool.ctx)
 
+	if pool.queueSize > 0 {
+		pool.queueSem = semaphore.NewWeighted(pool.ctx, pool.queueSize)
+	}
+
 	pool.dispatcher = dispatcher.NewDispatcher(pool.ctx, pool.dispatch, tasksLen)
 
 	return pool
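Taken together, the submit and worker hunks above implement the bounded queue: submit acquires one queue permit per task (Acquire in blocking mode, TryAcquire in non-blocking mode, failing fast with ErrQueueFull), and a worker releases the permit the moment it dequeues the task, before executing it. The following stand-alone toy reproduces that hand-off pattern under the same semaphore; submit, queue and the worker goroutine are illustrative names, not library code:

package main

import (
	"context"
	"fmt"

	"github.com/alitto/pond/v2/internal/semaphore"
)

func main() {
	ctx := context.Background()
	queueSem := semaphore.NewWeighted(ctx, 2) // as with WithQueueSize(2)
	queue := make(chan func(), 2)

	// submit mirrors pool.submit in blocking mode: one permit per queued task
	submit := func(task func()) error {
		if err := queueSem.Acquire(1); err != nil {
			return err // context canceled
		}
		queue <- task
		return nil
	}

	// The worker mirrors pool.worker: the permit is released as soon as
	// the task leaves the queue, before the task runs
	go func() {
		for task := range queue {
			queueSem.Release(1)
			task()
		}
	}()

	done := make(chan struct{})
	_ = submit(func() { close(done) })
	<-done
	fmt.Println("executed; available queue slots:", queueSem.Available())
}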
diff --git a/pool_test.go b/pool_test.go
index 8f6b273..fc99cf8 100644
--- a/pool_test.go
+++ b/pool_test.go
@@ -195,3 +195,66 @@ func TestPoolStoppedAfterCancel(t *testing.T) {
 
 	assert.Equal(t, ErrPoolStopped, err)
 }
+
+func TestPoolWithQueueSize(t *testing.T) {
+
+	pool := NewPool(1, WithQueueSize(10))
+
+	assert.Equal(t, 10, pool.QueueSize())
+	assert.Equal(t, false, pool.NonBlocking())
+
+	var taskCount int = 50
+
+	for i := 0; i < taskCount; i++ {
+		pool.Submit(func() {
+			time.Sleep(1 * time.Millisecond)
+		})
+	}
+
+	pool.Stop().Wait()
+
+	assert.Equal(t, uint64(taskCount), pool.SubmittedTasks())
+	assert.Equal(t, uint64(taskCount), pool.CompletedTasks())
+}
+
+func TestPoolWithQueueSizeAndNonBlocking(t *testing.T) {
+
+	pool := NewPool(10, WithQueueSize(10), NonBlocking())
+
+	assert.Equal(t, 10, pool.QueueSize())
+	assert.Equal(t, true, pool.NonBlocking())
+
+	taskStarted := make(chan struct{}, 10)
+	taskWait := make(chan struct{})
+
+	for i := 0; i < 10; i++ {
+		pool.Submit(func() {
+			taskStarted <- struct{}{}
+			<-taskWait
+		})
+	}
+
+	// Wait for 10 tasks to start
+	for i := 0; i < 10; i++ {
+		<-taskStarted
+	}
+
+	assert.Equal(t, int64(10), pool.RunningWorkers())
+	assert.Equal(t, uint64(10), pool.SubmittedTasks())
+	assert.Equal(t, uint64(0), pool.WaitingTasks())
+
+	// Saturate the queue
+	for i := 0; i < 10; i++ {
+		pool.Submit(func() {
+			time.Sleep(10 * time.Millisecond)
+		})
+	}
+
+	// Submit a task that should be rejected
+	task := pool.Submit(func() {})
+
+	// Unblock tasks
+	close(taskWait)
+
+	assert.Equal(t, ErrQueueFull, task.Wait())
+
+	pool.Stop().Wait()
+}
diff --git a/pooloptions.go b/pooloptions.go
index 25478d9..ada9254 100644
--- a/pooloptions.go
+++ b/pooloptions.go
@@ -12,3 +12,17 @@ func WithContext(ctx context.Context) Option {
 		p.ctx = ctx
 	}
 }
+
+// WithQueueSize sets the max number of elements that can be queued in the pool.
+func WithQueueSize(size int) Option {
+	return func(p *pool) {
+		p.queueSize = size
+	}
+}
+
+// NonBlocking sets the pool to be non-blocking when the queue is full.
+func NonBlocking() Option {
+	return func(p *pool) {
+		p.nonBlocking = true
+	}
+}
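With these two options, a bounded pool can be assembled entirely through the public API. A hypothetical end-to-end example (assuming the module import path used throughout the diff; the exact rejection count is timing-dependent):

package main

import (
	"errors"
	"fmt"
	"time"

	"github.com/alitto/pond/v2"
)

func main() {
	// At most 1 task running and 2 waiting in the queue; in non-blocking
	// mode, submissions that find the queue full are rejected
	pool := pond.NewPool(1, pond.WithQueueSize(2), pond.NonBlocking())

	tasks := make([]pond.Task, 0, 5)
	for i := 0; i < 5; i++ {
		tasks = append(tasks, pool.Submit(func() {
			time.Sleep(50 * time.Millisecond)
		}))
	}

	rejected := 0
	for _, task := range tasks {
		if errors.Is(task.Wait(), pond.ErrQueueFull) {
			rejected++
		}
	}
	fmt.Println("rejected:", rejected) // likely 2, depending on timing

	pool.StopAndWait()
}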
diff --git a/result.go b/result.go
index 9d2a7da..87a19de 100644
--- a/result.go
+++ b/result.go
@@ -16,8 +16,8 @@ type ResultPool[R any] interface {
 	// Submits a task to the pool and returns a future that can be used to wait for the task to complete and get the result.
 	SubmitErr(task func() (R, error)) Result[R]
 
-	// Creates a new subpool with the specified maximum concurrency.
-	NewSubpool(maxConcurrency int) ResultPool[R]
+	// Creates a new subpool with the specified maximum concurrency and options.
+	NewSubpool(maxConcurrency int, options ...Option) ResultPool[R]
 
 	// Creates a new task group.
 	NewGroup() ResultTaskGroup[R]
@@ -56,8 +56,8 @@ func (p *resultPool[R]) submit(task any) Result[R] {
 	return future
 }
 
-func (p *resultPool[R]) NewSubpool(maxConcurrency int) ResultPool[R] {
-	return newResultSubpool[R](maxConcurrency, p.Context(), p.pool)
+func (p *resultPool[R]) NewSubpool(maxConcurrency int, options ...Option) ResultPool[R] {
+	return newResultSubpool[R](maxConcurrency, p.Context(), p.pool, options...)
 }
 
 func NewResultPool[R any](maxConcurrency int, options ...Option) ResultPool[R] {
diff --git a/resultsubpool.go b/resultsubpool.go
index 86ea28a..50b37c6 100644
--- a/resultsubpool.go
+++ b/resultsubpool.go
@@ -7,16 +7,17 @@ import (
 	"sync"
 
 	"github.com/alitto/pond/v2/internal/dispatcher"
+	"github.com/alitto/pond/v2/internal/semaphore"
 )
 
 type resultSubpool[R any] struct {
 	*resultPool[R]
 	parent    *pool
 	waitGroup sync.WaitGroup
-	sem       chan struct{}
+	sem       *semaphore.Weighted
 }
 
-func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *pool) ResultPool[R] {
+func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *pool, options ...Option) ResultPool[R] {
 
 	if maxConcurrency == 0 {
 		maxConcurrency = parent.MaxConcurrency()
@@ -35,15 +36,27 @@ func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *po
 		tasksLen = MAX_TASKS_CHAN_LENGTH
 	}
 
-	subpool := &resultSubpool[R]{
-		resultPool: &resultPool[R]{
-			pool: &pool{
-				ctx:            ctx,
-				maxConcurrency: maxConcurrency,
-			},
+	resultPool := &resultPool[R]{
+		pool: &pool{
+			ctx:            ctx,
+			maxConcurrency: maxConcurrency,
 		},
-		parent: parent,
-		sem:    make(chan struct{}, maxConcurrency),
+	}
+
+	for _, option := range options {
+		option(resultPool.pool)
+	}
+
+	ctx = resultPool.Context()
+
+	if resultPool.pool.queueSize > 0 {
+		resultPool.pool.queueSem = semaphore.NewWeighted(ctx, resultPool.pool.queueSize)
+	}
+
+	subpool := &resultSubpool[R]{
+		resultPool: resultPool,
+		parent:     parent,
+		sem:        semaphore.NewWeighted(ctx, maxConcurrency),
 	}
 
 	subpool.pool.dispatcher = dispatcher.NewDispatcher(ctx, subpool.dispatch, tasksLen)
@@ -53,27 +66,40 @@ func newResultSubpool[R any](maxConcurrency int, ctx context.Context, parent *po
 
 func (p *resultSubpool[R]) dispatch(incomingTasks []any) {
 
-	p.waitGroup.Add(len(incomingTasks))
-
 	// Submit tasks
 	for _, task := range incomingTasks {
-		select {
-		case <-p.Context().Done():
-			// Context canceled, exit
-			return
-		case p.sem <- struct{}{}:
-			// Acquired the semaphore, submit another task
+		// Acquire the semaphore to limit concurrency
+		if p.nonBlocking {
+			if ok := p.sem.TryAcquire(1); !ok {
+				// Max concurrency reached or context canceled, exit
+				return
+			}
+		} else {
+			if err := p.sem.Acquire(1); err != nil {
+				// Context canceled, exit
+				return
+			}
 		}
 
 		subpoolTask := subpoolTask[any]{
 			task:          task,
 			sem:           p.sem,
+			queueSem:      p.queueSem,
 			waitGroup:     &p.waitGroup,
 			updateMetrics: p.updateMetrics,
 		}
 
-		p.parent.Go(subpoolTask.Run)
+		p.waitGroup.Add(1)
+
+		if err := p.parent.Go(subpoolTask.Run); err != nil {
+			// Failed to hand the task to the parent pool: release the
+			// queue slot and the concurrency semaphore
+			if p.queueSem != nil {
+				p.queueSem.Release(1)
+			}
+			subpoolTask.Close()
+		}
 	}
 }
 
@@ -82,11 +104,13 @@ func (p *resultSubpool[R]) Stop() Task {
 		p.dispatcher.CloseAndWait()
 
 		p.waitGroup.Wait()
-
-		close(p.sem)
 	})
 }
 
 func (p *resultSubpool[R]) StopAndWait() {
 	p.Stop().Wait()
 }
+
+func (p *resultSubpool[R]) RunningWorkers() int64 {
+	return int64(p.sem.Acquired())
+}
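The same options now thread through result pools and their subpools. A small hypothetical sketch of a bounded result subpool (Submit and Result.Wait signatures assumed from the surrounding API):

package main

import (
	"fmt"
	"strings"

	"github.com/alitto/pond/v2"
)

func main() {
	pool := pond.NewResultPool[string](8)

	// Result subpool capped at 2 concurrent tasks with a bounded queue of 4
	subpool := pool.NewSubpool(2, pond.WithQueueSize(4))

	task := subpool.Submit(func() string {
		return strings.ToUpper("hello")
	})

	out, err := task.Wait()
	fmt.Println(out, err) // HELLO <nil>

	subpool.StopAndWait()
	pool.StopAndWait()
}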
diff --git a/subpool.go b/subpool.go
index baf1bbc..93479ff 100644
--- a/subpool.go
+++ b/subpool.go
@@ -7,6 +7,7 @@ import (
 	"sync"
 
 	"github.com/alitto/pond/v2/internal/dispatcher"
+	"github.com/alitto/pond/v2/internal/semaphore"
 )
 
 // subpool is a pool that is a subpool of another pool
@@ -14,10 +15,10 @@ type subpool struct {
 	*pool
 	parent    *pool
 	waitGroup sync.WaitGroup
-	sem       chan struct{}
+	sem       *semaphore.Weighted
 }
 
-func newSubpool(maxConcurrency int, ctx context.Context, parent *pool) Pool {
+func newSubpool(maxConcurrency int, ctx context.Context, parent *pool, options ...Option) Pool {
 
 	if maxConcurrency == 0 {
 		maxConcurrency = parent.MaxConcurrency()
@@ -36,13 +37,25 @@ func newSubpool(maxConcurrency int, ctx context.Context, parent *pool) Pool {
 		tasksLen = MAX_TASKS_CHAN_LENGTH
 	}
 
+	pool := &pool{
+		ctx:            ctx,
+		maxConcurrency: maxConcurrency,
+	}
+
+	for _, option := range options {
+		option(pool)
+	}
+
+	ctx = pool.Context()
+
+	if pool.queueSize > 0 {
+		pool.queueSem = semaphore.NewWeighted(ctx, pool.queueSize)
+	}
+
 	subpool := &subpool{
-		pool: &pool{
-			ctx:            ctx,
-			maxConcurrency: maxConcurrency,
-		},
+		pool:   pool,
 		parent: parent,
-		sem:    make(chan struct{}, maxConcurrency),
+		sem:    semaphore.NewWeighted(ctx, maxConcurrency),
 	}
 
 	subpool.pool.dispatcher = dispatcher.NewDispatcher(ctx, subpool.dispatch, tasksLen)
@@ -51,28 +64,38 @@ func newSubpool(maxConcurrency int, ctx context.Context, parent *pool) Pool {
 }
 
 func (p *subpool) dispatch(incomingTasks []any) {
-
-	p.waitGroup.Add(len(incomingTasks))
-
 	// Submit tasks
 	for _, task := range incomingTasks {
-		select {
-		case <-p.Context().Done():
-			// Context canceled, exit
-			return
-		case p.sem <- struct{}{}:
-			// Acquired the semaphore, submit another task
+		// Acquire the semaphore to limit concurrency
+		if p.nonBlocking {
+			if ok := p.sem.TryAcquire(1); !ok {
+				// Max concurrency reached or context canceled, exit
+				return
+			}
+		} else {
+			if err := p.sem.Acquire(1); err != nil {
+				// Context canceled, exit
+				return
+			}
 		}
 
 		subpoolTask := subpoolTask[any]{
 			task:          task,
+			queueSem:      p.queueSem,
 			sem:           p.sem,
 			waitGroup:     &p.waitGroup,
 			updateMetrics: p.updateMetrics,
 		}
 
-		p.parent.Go(subpoolTask.Run)
+		p.waitGroup.Add(1)
+
+		if err := p.parent.Go(subpoolTask.Run); err != nil {
+			// Failed to hand the task to the parent pool: release the
+			// queue slot and the concurrency semaphore
+			if p.queueSem != nil {
+				p.queueSem.Release(1)
+			}
+			subpoolTask.Close()
+		}
 	}
 }
 
@@ -81,11 +100,13 @@ func (p *subpool) Stop() Task {
 		p.dispatcher.CloseAndWait()
 
 		p.waitGroup.Wait()
-
-		close(p.sem)
 	})
 }
 
 func (p *subpool) StopAndWait() {
 	p.Stop().Wait()
 }
+
+func (p *subpool) RunningWorkers() int64 {
+	return int64(p.sem.Acquired())
+}
diff --git a/subpool_test.go b/subpool_test.go
index 29c5d4e..4e52f5e 100644
--- a/subpool_test.go
+++ b/subpool_test.go
@@ -191,3 +191,111 @@ func TestSubpoolStoppedAfterCancel(t *testing.T) {
 
 	assert.Equal(t, ErrPoolStopped, err)
 }
+
+func TestSubpoolWithDifferentLimits(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	pool := NewPool(7, WithContext(ctx))
+
+	subpool1 := pool.NewSubpool(1)
+	subpool2 := pool.NewSubpool(2)
+	subpool3 := pool.NewSubpool(3)
+
+	taskStarted := make(chan struct{}, 10)
+	taskWait := make(chan struct{})
+
+	var task = func() func() {
+		return func() {
+			taskStarted <- struct{}{}
+			<-taskWait
+		}
+	}
+
+	// Submit tasks to subpool1 and wait for 1 task to start
+	for i := 0; i < 10; i++ {
+		subpool1.Submit(task())
+	}
+	<-taskStarted
+
+	// Submit tasks to subpool2 and wait for 2 tasks to start
+	for i := 0; i < 10; i++ {
+		subpool2.Submit(task())
+	}
+	<-taskStarted
+	<-taskStarted
+
+	// Submit tasks to subpool3 and wait for 3 tasks to start
+	for i := 0; i < 10; i++ {
+		subpool3.Submit(task())
+	}
+	<-taskStarted
+	<-taskStarted
+	<-taskStarted
+
+	// Submit tasks to the main pool and wait for 1 to start
+	for i := 0; i < 10; i++ {
+		pool.Submit(task())
+	}
+	<-taskStarted
+
+	// Verify concurrency of each pool
+	assert.Equal(t, int64(1), subpool1.RunningWorkers())
+	assert.Equal(t, int64(2), subpool2.RunningWorkers())
+	assert.Equal(t, int64(3), subpool3.RunningWorkers())
+	assert.Equal(t, int64(7), pool.RunningWorkers())
+
+	assert.Equal(t, uint64(0), subpool1.CompletedTasks())
+	assert.Equal(t, uint64(0), subpool2.CompletedTasks())
+	assert.Equal(t, uint64(0), subpool3.CompletedTasks())
+	assert.Equal(t, uint64(0), pool.CompletedTasks())
+
+	// Cancel the context to abort pending tasks
+	cancel()
+
+	// Unblock all running tasks
+	close(taskWait)
+
+	subpool1.StopAndWait()
+	subpool2.StopAndWait()
+	subpool3.StopAndWait()
+	pool.StopAndWait()
+
+	assert.Equal(t, uint64(1), subpool1.CompletedTasks())
+	assert.Equal(t, uint64(2), subpool2.CompletedTasks())
+	assert.Equal(t, uint64(3), subpool3.CompletedTasks())
+	assert.Equal(t, uint64(7), pool.CompletedTasks())
+}
+
+func TestSubpoolWithQueueSizeOverride(t *testing.T) {
+	pool := NewPool(10, WithQueueSize(10))
+
+	subpool := pool.NewSubpool(1, WithQueueSize(2), NonBlocking())
+
+	taskStarted := make(chan struct{}, 10)
+	taskWait := make(chan struct{})
+
+	var task = func() func() {
+		return func() {
+			taskStarted <- struct{}{}
+			<-taskWait
+		}
+	}
+
+	// Submit a task to the subpool and wait for it to start
+	subpool.Submit(task())
+	<-taskStarted
+
+	// Submit more tasks to fill up the queue
+	for i := 0; i < 10; i++ {
+		subpool.Submit(task())
+	}
+
+	// 8 of the 11 submitted tasks should have been discarded (1 running + 2 queued)
+	assert.Equal(t, int64(1), subpool.RunningWorkers())
+	assert.Equal(t, uint64(3), subpool.SubmittedTasks())
+
+	// Unblock all running tasks
+	close(taskWait)
+
+	subpool.StopAndWait()
+	pool.StopAndWait()
+}
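Because subpool concurrency is now tracked by the weighted semaphore, RunningWorkers() simply reports the permits currently held, as exercised by the tests above. A hypothetical sketch of sibling subpools sharing one parent pool:

package main

import (
	"fmt"
	"time"

	"github.com/alitto/pond/v2"
)

func main() {
	parent := pond.NewPool(4)

	// Siblings compete for the parent's 4 workers, but are individually
	// capped at 1 and 2 concurrent tasks by their own semaphores
	sub1 := parent.NewSubpool(1)
	sub2 := parent.NewSubpool(2, pond.WithQueueSize(8))

	for i := 0; i < 4; i++ {
		sub1.Submit(func() { time.Sleep(20 * time.Millisecond) })
		sub2.Submit(func() { time.Sleep(20 * time.Millisecond) })
	}

	time.Sleep(5 * time.Millisecond)
	fmt.Println(sub1.RunningWorkers(), sub2.RunningWorkers()) // likely 1 2

	sub1.StopAndWait()
	sub2.StopAndWait()
	parent.StopAndWait()
}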
diff --git a/task.go b/task.go
index 6814073..518461b 100644
--- a/task.go
+++ b/task.go
@@ -4,24 +4,27 @@ import (
 	"errors"
 	"fmt"
 	"sync"
+
+	"github.com/alitto/pond/v2/internal/semaphore"
 )
 
 var ErrPanic = errors.New("task panicked")
 
 type subpoolTask[R any] struct {
 	task          any
-	sem           chan struct{}
+	queueSem      *semaphore.Weighted
+	sem           *semaphore.Weighted
 	waitGroup     *sync.WaitGroup
 	updateMetrics func(error)
 }
 
 func (t subpoolTask[R]) Run() {
-	defer func() {
-		// Release semaphore
-		<-t.sem
-		// Decrement wait group
-		t.waitGroup.Done()
-	}()
+	defer t.Close()
+
+	// Release task queue semaphore when task is pulled from queue
+	if t.queueSem != nil {
+		t.queueSem.Release(1)
+	}
 
 	_, err := invokeTask[R](t.task)
 
@@ -30,6 +33,14 @@ func (t subpoolTask[R]) Run() {
 	}
 }
 
+func (t subpoolTask[R]) Close() {
+	// Release semaphore
+	t.sem.Release(1)
+
+	// Decrement wait group
+	t.waitGroup.Done()
+}
+
 type wrappedTask[R any, C func(error) | func(R, error)] struct {
 	task     any
 	callback C