From 2cce8d8c365e7553e11a36d26bdda8aaeb75b5f7 Mon Sep 17 00:00:00 2001
From: ccoVeille <3875889+ccoVeille@users.noreply.github.com>
Date: Wed, 3 Jul 2024 00:03:13 +0200
Subject: [PATCH 1/5] chore: fix test timeout helper

using os.Exit(1) kills everything, tests statuses are not always displayed
---
 lo_test.go | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/lo_test.go b/lo_test.go
index 26db4c25..bc92dc35 100644
--- a/lo_test.go
+++ b/lo_test.go
@@ -1,7 +1,6 @@
 package lo
 
 import (
-	"os"
 	"testing"
 	"time"
 )
@@ -18,7 +17,7 @@ func testWithTimeout(t *testing.T, timeout time.Duration) {
 		case <-testFinished:
 		case <-time.After(timeout):
 			t.Errorf("test timed out after %s", timeout)
-			os.Exit(1)
+			t.FailNow()
 		}
 	}()
 }

From e26f1b43796eee77974b4b9e8d20d7636c3bfab2 Mon Sep 17 00:00:00 2001
From: ccoVeille <3875889+ccoVeille@users.noreply.github.com>
Date: Wed, 3 Jul 2024 00:07:14 +0200
Subject: [PATCH 2/5] chore: refactor WaitFor unit tests

zero-code changes
---
 concurrency_test.go | 109 ++++++++++++++++++++++++++++++--------------
 1 file changed, 74 insertions(+), 35 deletions(-)

diff --git a/concurrency_test.go b/concurrency_test.go
index 0ee70dbf..ee339a90 100644
--- a/concurrency_test.go
+++ b/concurrency_test.go
@@ -215,44 +215,83 @@ func TestAsyncX(t *testing.T) {
 
 func TestWaitFor(t *testing.T) {
 	t.Parallel()
-	testWithTimeout(t, 100*time.Millisecond)
-	is := assert.New(t)
 
-	alwaysTrue := func(i int) bool { return true }
-	alwaysFalse := func(i int) bool { return false }
+	testTimeout := 100 * time.Millisecond
+	longTimeout := 2 * testTimeout
+	shortTimeout := 4 * time.Millisecond
 
-	iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
-	is.Equal(1, iter)
-	is.Equal(time.Duration(0), duration)
-	is.True(ok)
-	iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 4*time.Millisecond)
-	is.Equal(3, iter)
-	is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
-	is.False(ok)
+	t.Run("exist condition works", func(t *testing.T) {
+		t.Parallel()
 
-	laterTrue := func(i int) bool {
-		return i >= 5
-	}
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
 
-	iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, time.Millisecond)
-	is.Equal(6, iter)
-	is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
-	is.True(ok)
-	iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, 5*time.Millisecond)
-	is.Equal(2, iter)
-	is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
-	is.False(ok)
-
-	counter := 0
-
-	alwaysFalse = func(i int) bool {
-		is.Equal(counter, i)
-		counter++
-		return false
-	}
+		laterTrue := func(i int) bool {
+			return i >= 5
+		}
+
+		iter, duration, ok := WaitFor(laterTrue, longTimeout, time.Millisecond)
+		is.Equal(6, iter, "unexpected iteration count")
+		is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.True(ok)
+	})
+
+	t.Run("counter is incremented", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		counter := 0
+		alwaysFalse := func(i int) bool {
+			is.Equal(counter, i)
+			counter++
+			return false
+		}
+
+		iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 1050*time.Microsecond)
+		is.Equal(counter, iter, "unexpected iteration count")
+		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.False(ok)
+	})
 
-	iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 1050*time.Microsecond)
-	is.Equal(10, iter)
-	is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
-	is.False(ok)
+	alwaysTrue := func(_ int) bool { return true }
+	alwaysFalse := func(_ int) bool { return false }
+
+	t.Run("short timeout works", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
+		is.Equal(1, iter, "unexpected iteration count")
+		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.False(ok)
+	})
+
+	t.Run("timeout works", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		shortTimeout := 4 * time.Millisecond
+		iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
+		is.Equal(1, iter, "unexpected iteration count")
+		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.False(ok)
+	})
+
+	t.Run("exist on first condition", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
+		is.Equal(1, iter, "unexpected iteration count")
+		is.Zero(duration)
+		is.True(ok)
+	})
 }

From 071a746f6331fe338a895f47b98173e9445292f8 Mon Sep 17 00:00:00 2001
From: ccoVeille <3875889+ccoVeille@users.noreply.github.com>
Date: Wed, 3 Jul 2024 00:15:15 +0200
Subject: [PATCH 3/5] fix: WaitFor on first condition

duration must be non-zero if first conditions is true
---
 README.md           | 4 ++--
 concurrency.go      | 6 +-----
 concurrency_test.go | 6 +++---
 3 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index efa83fb2..bdd73086 100644
--- a/README.md
+++ b/README.md
@@ -3068,9 +3068,9 @@ laterTrue := func(i int) bool {
     return i > 5
 }
 
-iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
+iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
 // 1
-// 0ms
+// 1ms
 // true
 
 iterations, duration, ok := lo.WaitFor(alwaysFalse, 10*time.Millisecond, time.Millisecond)
diff --git a/concurrency.go b/concurrency.go
index 95580661..d907a74a 100644
--- a/concurrency.go
+++ b/concurrency.go
@@ -99,10 +99,6 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A,
 
 // WaitFor runs periodically until a condition is validated.
 func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
-	if condition(0) {
-		return 1, 0, true
-	}
-
 	start := time.Now()
 
 	timer := time.NewTimer(maxDuration)
@@ -113,7 +109,7 @@ func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Du
 		ticker.Stop()
 	}()
 
-	i := 1
+	i := 0
 
 	for {
 		select {
diff --git a/concurrency_test.go b/concurrency_test.go
index ee339a90..8a8c9e5d 100644
--- a/concurrency_test.go
+++ b/concurrency_test.go
@@ -265,7 +265,7 @@ func TestWaitFor(t *testing.T) {
 		is := assert.New(t)
 
 		iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
-		is.Equal(1, iter, "unexpected iteration count")
+		is.Equal(0, iter, "unexpected iteration count")
 		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
 		is.False(ok)
 	})
@@ -278,7 +278,7 @@ func TestWaitFor(t *testing.T) {
 
 		shortTimeout := 4 * time.Millisecond
 		iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
-		is.Equal(1, iter, "unexpected iteration count")
+		is.Equal(0, iter, "unexpected iteration count")
 		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
 		is.False(ok)
 	})
@@ -291,7 +291,7 @@ func TestWaitFor(t *testing.T) {
 
 		iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
 		is.Equal(1, iter, "unexpected iteration count")
-		is.Zero(duration)
+		is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
 		is.True(ok)
 	})
 }

From e1d8c98673fa352825f3fb5d9fd4de31e54053fc Mon Sep 17 00:00:00 2001
From: ccoVeille <3875889+ccoVeille@users.noreply.github.com>
Date: Sun, 30 Jun 2024 01:42:12 +0200
Subject: [PATCH 4/5] feat: add WaitForWithContext

---
 README.md           |  44 +++++++++++++++++
 concurrency.go      |  29 +++++++----
 concurrency_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 180 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index bdd73086..a4cd9478 100644
--- a/README.md
+++ b/README.md
@@ -276,6 +276,7 @@ Concurrency helpers:
 - [Async](#async)
 - [Transaction](#transaction)
 - [WaitFor](#waitfor)
+- [WaitForWithContext](#waitforwithcontext)
 
 Error handling:
 
@@ -3089,6 +3090,49 @@ iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, 5*time.Mi
 // false
 ```
 
+
+### WaitForWithContext
+
+Runs periodically until a condition is validated or context is invalid.
+
+The condition receives also the context, so it can invalidate the process in the condition checker
+
+```go
+ctx := context.Background()
+
+alwaysTrue := func(_ context.Context, i int) bool { return true }
+alwaysFalse := func(_ context.Context, i int) bool { return false }
+laterTrue := func(_ context.Context, i int) bool {
+    return i >= 5
+}
+
+iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
+// 1
+// 1ms
+// true
+
+iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysFalse, 10*time.Millisecond, time.Millisecond)
+// 10
+// 10ms
+// false
+
+iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, time.Millisecond)
+// 5
+// 5ms
+// true
+
+iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, 5*time.Millisecond)
+// 2
+// 10ms
+// false
+
+expiringCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
+iterations, duration, ok := lo.WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, time.Millisecond)
+// 5
+// 5.1ms
+// false
+```
+
 ### Validate
 
 Helper function that creates an error when a condition is not met.
diff --git a/concurrency.go b/concurrency.go
index d907a74a..dc16f8df 100644
--- a/concurrency.go
+++ b/concurrency.go
@@ -1,6 +1,7 @@
 package lo
 
 import (
+	"context"
 	"sync"
 	"time"
 )
@@ -99,28 +100,38 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A,
 
 // WaitFor runs periodically until a condition is validated.
 func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
+	conditionWithContext := func(_ context.Context, i int) bool {
+		return condition(i)
+	}
+	return WaitForWithContext(context.Background(), conditionWithContext, maxDuration, tick)
+}
+
+// WaitForWithContext runs periodically until a condition is validated or context is canceled.
+func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
 	start := time.Now()
 
-	timer := time.NewTimer(maxDuration)
+	i := 0
+	if ctx.Err() != nil {
+		return i, time.Since(start), false
+	}
+
+	ctx, cleanCtx := context.WithTimeout(ctx, maxDuration)
 	ticker := time.NewTicker(tick)
 
 	defer func() {
-		timer.Stop()
+		cleanCtx()
 		ticker.Stop()
 	}()
 
-	i := 0
-
 	for {
 		select {
-		case <-timer.C:
+		case <-ctx.Done():
 			return i, time.Since(start), false
 		case <-ticker.C:
-			if condition(i) {
-				return i + 1, time.Since(start), true
-			}
-
 			i++
+			if condition(ctx, i-1) {
+				return i, time.Since(start), true
+			}
 		}
 	}
 }
diff --git a/concurrency_test.go b/concurrency_test.go
index 8a8c9e5d..61f3dd61 100644
--- a/concurrency_test.go
+++ b/concurrency_test.go
@@ -1,6 +1,7 @@
 package lo
 
 import (
+	"context"
 	"sync"
 	"testing"
 	"time"
@@ -295,3 +296,118 @@ func TestWaitFor(t *testing.T) {
 		is.True(ok)
 	})
 }
+
+func TestWaitForWithContext(t *testing.T) {
+	t.Parallel()
+
+	testTimeout := 100 * time.Millisecond
+	longTimeout := 2 * testTimeout
+	shortTimeout := 4 * time.Millisecond
+
+	t.Run("exist condition works", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		laterTrue := func(_ context.Context, i int) bool {
+			return i >= 5
+		}
+
+		iter, duration, ok := WaitForWithContext(context.Background(), laterTrue, longTimeout, time.Millisecond)
+		is.Equal(6, iter, "unexpected iteration count")
+		is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.True(ok)
+	})
+
+	t.Run("counter is incremented", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		counter := 0
+		alwaysFalse := func(_ context.Context, i int) bool {
+			is.Equal(counter, i)
+			counter++
+			return false
+		}
+
+		iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 1050*time.Microsecond)
+		is.Equal(counter, iter, "unexpected iteration count")
+		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.False(ok)
+	})
+
+	alwaysTrue := func(_ context.Context, _ int) bool { return true }
+	alwaysFalse := func(_ context.Context, _ int) bool { return false }
+
+	t.Run("short timeout works", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
+		is.Equal(0, iter, "unexpected iteration count")
+		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.False(ok)
+	})
+
+	t.Run("timeout works", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		shortTimeout := 4 * time.Millisecond
+		iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
+		is.Equal(0, iter, "unexpected iteration count")
+		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.False(ok)
+	})
+
+	t.Run("exist on first condition", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		iter, duration, ok := WaitForWithContext(context.Background(), alwaysTrue, 10*time.Millisecond, time.Millisecond)
+		is.Equal(1, iter, "unexpected iteration count")
+		is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
+		is.True(ok)
+	})
+
+	t.Run("context cancellation stops everything", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		expiringCtx, clean := context.WithTimeout(context.Background(), 8*time.Millisecond)
+		t.Cleanup(func() {
+			clean()
+		})
+
+		iter, duration, ok := WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, 3*time.Millisecond)
+		is.Equal(2, iter, "unexpected iteration count")
+		is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
+		is.False(ok)
+	})
+
+	t.Run("canceled context stops everything", func(t *testing.T) {
+		t.Parallel()
+
+		testWithTimeout(t, testTimeout)
+		is := assert.New(t)
+
+		canceledCtx, cancel := context.WithCancel(context.Background())
+		cancel()
+
+		iter, duration, ok := WaitForWithContext(canceledCtx, alwaysFalse, 100*time.Millisecond, 1050*time.Microsecond)
+		is.Equal(0, iter, "unexpected iteration count")
+		is.InEpsilon(1*time.Millisecond, duration, float64(5*time.Microsecond))
+		is.False(ok)
+	})
+}

From 74eb4206786d25ec1df911b29394d60ecbeec287 Mon Sep 17 00:00:00 2001
From: ccoVeille <3875889+ccoVeille@users.noreply.github.com>
Date: Wed, 3 Jul 2024 00:35:19 +0200
Subject: [PATCH 5/5] chore: provide meaningful returned values for WaitFor and
 WaitForWithContext

---
 concurrency.go | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/concurrency.go b/concurrency.go
index dc16f8df..a2ebbce2 100644
--- a/concurrency.go
+++ b/concurrency.go
@@ -99,24 +99,23 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A,
 }
 
 // WaitFor runs periodically until a condition is validated.
-func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
-	conditionWithContext := func(_ context.Context, i int) bool {
-		return condition(i)
+func WaitFor(condition func(i int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
+	conditionWithContext := func(_ context.Context, currentIteration int) bool {
+		return condition(currentIteration)
 	}
-	return WaitForWithContext(context.Background(), conditionWithContext, maxDuration, tick)
+	return WaitForWithContext(context.Background(), conditionWithContext, timeout, heartbeatDelay)
 }
 
 // WaitForWithContext runs periodically until a condition is validated or context is canceled.
-func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
+func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, currentIteration int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
 	start := time.Now()
 
-	i := 0
 	if ctx.Err() != nil {
-		return i, time.Since(start), false
+		return totalIterations, time.Since(start), false
 	}
 
-	ctx, cleanCtx := context.WithTimeout(ctx, maxDuration)
-	ticker := time.NewTicker(tick)
+	ctx, cleanCtx := context.WithTimeout(ctx, timeout)
+	ticker := time.NewTicker(heartbeatDelay)
 
 	defer func() {
 		cleanCtx()
@@ -126,11 +125,11 @@ func WaitForWithContext(ctx context.Context, condition func(ctx context.Context,
 	for {
 		select {
 		case <-ctx.Done():
-			return i, time.Since(start), false
+			return totalIterations, time.Since(start), false
 		case <-ticker.C:
-			i++
-			if condition(ctx, i-1) {
-				return i, time.Since(start), true
+			totalIterations++
+			if condition(ctx, totalIterations-1) {
+				return totalIterations, time.Since(start), true
 			}
 		}
 	}